Line data Source code
1 : #include "rust-tyty-variance-analysis-private.h"
2 : #include "rust-hir-type-check.h"
3 :
4 : namespace Rust {
5 : namespace TyTy {
6 :
7 : BaseType *
8 7229 : lookup_type (HirId ref)
9 : {
10 7229 : BaseType *ty = nullptr;
11 7229 : bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty);
12 7229 : rust_assert (ok);
13 7229 : return ty;
14 : }
15 :
16 : namespace VarianceAnalysis {
17 :
18 4680 : CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {}
19 :
20 : // Must be here because of incomplete type.
21 0 : CrateCtx::~CrateCtx () = default;
22 :
23 : void
24 3085 : CrateCtx::add_type_constraints (ADTType &type)
25 : {
26 3085 : private_ctx->process_type (type);
27 3085 : }
28 :
29 : void
30 4463 : CrateCtx::solve ()
31 : {
32 4463 : private_ctx->solve ();
33 4463 : private_ctx->debug_print_solutions ();
34 4463 : }
35 :
36 : std::vector<Variance>
37 0 : CrateCtx::query_generic_variance (const ADTType &type)
38 : {
39 0 : return private_ctx->query_generic_variance (type);
40 : }
41 :
42 : std::vector<Variance>
43 516 : CrateCtx::query_type_variances (BaseType *type)
44 : {
45 516 : TyVisitorCtx ctx (*private_ctx);
46 516 : return ctx.collect_variances (*type);
47 516 : }
48 :
49 : std::vector<Region>
50 76 : CrateCtx::query_type_regions (BaseType *type)
51 : {
52 76 : return private_ctx->query_type_regions (type);
53 : }
54 :
55 : FreeRegions
56 5 : CrateCtx::query_field_regions (const ADTType *parent, size_t variant_index,
57 : size_t field_index,
58 : const FreeRegions &parent_regions)
59 : {
60 5 : return private_ctx->query_field_regions (parent, variant_index, field_index,
61 5 : parent_regions);
62 : }
63 :
64 : Variance
65 0 : Variance::reverse () const
66 : {
67 0 : switch (kind)
68 : {
69 0 : case BIVARIANT:
70 0 : return bivariant ();
71 0 : case COVARIANT:
72 0 : return contravariant ();
73 0 : case CONTRAVARIANT:
74 0 : return covariant ();
75 0 : case INVARIANT:
76 0 : return invariant ();
77 : }
78 :
79 0 : rust_unreachable ();
80 : }
81 :
82 : Variance
83 1483 : Variance::join (Variance lhs, Variance rhs)
84 : {
85 1483 : return {Kind (lhs.kind | rhs.kind)};
86 : }
87 :
88 : void
89 1320 : Variance::join (Variance rhs)
90 : {
91 1320 : *this = join (*this, rhs);
92 1320 : }
93 :
94 : Variance
95 216 : Variance::transform (Variance lhs, Variance rhs)
96 : {
97 216 : switch (lhs.kind)
98 : {
99 0 : case BIVARIANT:
100 0 : return bivariant ();
101 214 : case COVARIANT:
102 214 : return rhs;
103 0 : case CONTRAVARIANT:
104 0 : return rhs.reverse ();
105 2 : case INVARIANT:
106 2 : return invariant ();
107 : }
108 0 : rust_unreachable ();
109 : }
110 :
111 : std::string
112 8156 : Variance::as_string () const
113 : {
114 8156 : switch (kind)
115 : {
116 255 : case BIVARIANT:
117 255 : return "o";
118 7599 : case COVARIANT:
119 7599 : return "+";
120 7 : case CONTRAVARIANT:
121 7 : return "-";
122 295 : case INVARIANT:
123 295 : return "*";
124 : }
125 0 : rust_unreachable ();
126 : }
127 :
128 : void
129 3085 : GenericTyPerCrateCtx::process_type (ADTType &type)
130 : {
131 3085 : GenericTyVisitorCtx (*this).process_type (type);
132 3085 : }
133 :
134 : void
135 4463 : GenericTyPerCrateCtx::solve ()
136 : {
137 4463 : rust_debug ("Variance analysis solving started:");
138 :
139 : // Fix point iteration
140 4463 : bool changed = true;
141 13413 : while (changed)
142 : {
143 4487 : changed = false;
144 4650 : for (auto constraint : constraints)
145 : {
146 163 : rust_debug ("\tapplying constraint: %s <= %s",
147 : to_string (constraint.target_index).c_str (),
148 : to_string (*constraint.term).c_str ());
149 :
150 163 : auto old_solution = solutions[constraint.target_index];
151 163 : auto new_solution
152 163 : = Variance::join (old_solution, evaluate (constraint.term));
153 :
154 163 : if (old_solution != new_solution)
155 : {
156 24 : rust_debug ("\t\tsolution changed: %s => %s",
157 : old_solution.as_string ().c_str (),
158 : new_solution.as_string ().c_str ());
159 :
160 24 : changed = true;
161 24 : solutions[constraint.target_index] = new_solution;
162 : }
163 : }
164 : }
165 :
166 4463 : constraints.clear ();
167 4463 : constraints.shrink_to_fit ();
168 4463 : }
169 :
170 : void
171 4463 : GenericTyPerCrateCtx::debug_print_solutions ()
172 : {
173 4463 : rust_debug ("Variance analysis results:");
174 :
175 7548 : for (auto type : map_from_ty_orig_ref)
176 : {
177 3085 : auto solution_index = type.second;
178 3085 : auto ref = type.first;
179 :
180 3085 : BaseType *ty = lookup_type (ref);
181 :
182 3085 : std::string result = "\t";
183 :
184 3085 : if (auto adt = ty->try_as<ADTType> ())
185 : {
186 9255 : result += adt->get_identifier ();
187 3085 : result += "<";
188 :
189 3085 : size_t i = solution_index;
190 3106 : for (auto ®ion : adt->get_used_arguments ().get_regions ())
191 : {
192 21 : (void) region;
193 21 : if (i > solution_index)
194 1 : result += ", ";
195 42 : result += solutions[i].as_string ();
196 21 : i++;
197 : }
198 4269 : for (auto ¶m : adt->get_substs ())
199 : {
200 1184 : if (i > solution_index)
201 129 : result += ", ";
202 2368 : result += param.get_type_representation ().as_string ();
203 1184 : result += "=";
204 2368 : result += solutions[i].as_string ();
205 1184 : i++;
206 : }
207 :
208 3085 : result += ">";
209 : }
210 : else
211 : {
212 0 : rust_sorry_at (
213 : ty->get_ref (),
214 : "This is a compiler bug: Unhandled type in variance analysis");
215 : }
216 3085 : rust_debug ("%s", result.c_str ());
217 3085 : }
218 4463 : }
219 :
220 : tl::optional<SolutionIndex>
221 3512 : GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref)
222 : {
223 3512 : auto it = map_from_ty_orig_ref.find (orig_ref);
224 3512 : if (it != map_from_ty_orig_ref.end ())
225 : {
226 363 : return it->second;
227 : }
228 3149 : return tl::nullopt;
229 : }
230 :
231 : void
232 3085 : GenericTyVisitorCtx::process_type (ADTType &ty)
233 : {
234 3085 : rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ());
235 :
236 3085 : first_lifetime = lookup_or_add_type (ty.get_orig_ref ());
237 3085 : first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size ();
238 :
239 4269 : for (auto ¶m : ty.get_substs ())
240 1184 : param_names.push_back (param.get_type_representation ().as_string ());
241 :
242 6864 : for (const auto &variant : ty.get_variants ())
243 : {
244 3779 : if (variant->get_variant_type () != VariantDef::NUM
245 3779 : && variant->get_variant_type () != VariantDef::UNIT)
246 : {
247 6828 : for (const auto &field : variant->get_fields ())
248 4300 : add_constraints_from_ty (field->get_field_type (),
249 4300 : Variance::covariant ());
250 : }
251 : }
252 3085 : }
253 :
254 : std::string
255 7831 : GenericTyPerCrateCtx::to_string (const Term &term) const
256 : {
257 7831 : switch (term.kind)
258 : {
259 6903 : case Term::CONST:
260 6903 : return term.const_val.as_string ();
261 464 : case Term::REF:
262 928 : return "v(" + to_string (term.ref) + ")";
263 464 : case Term::TRANSFORM:
264 928 : return "(" + to_string (*term.transform.lhs) + " x "
265 1856 : + to_string (*term.transform.rhs) + ")";
266 : }
267 0 : rust_unreachable ();
268 : }
269 :
270 : std::string
271 627 : GenericTyPerCrateCtx::to_string (SolutionIndex index) const
272 : {
273 : // Search all values in def_id_to_solution_index_start and find key for
274 : // largest value smaller than index
275 627 : std::pair<HirId, SolutionIndex> best = {0, 0};
276 :
277 2830 : for (const auto &ty_map : map_from_ty_orig_ref)
278 : {
279 2203 : if (ty_map.second <= index && ty_map.first > best.first)
280 : best = ty_map;
281 : }
282 627 : rust_assert (best.first != 0);
283 :
284 627 : BaseType *ty = lookup_type (best.first);
285 :
286 627 : std::string result = "";
287 627 : if (auto adt = ty->try_as<ADTType> ())
288 : {
289 1881 : result += (adt->get_identifier ());
290 : }
291 : else
292 : {
293 0 : result += ty->as_string ();
294 : }
295 :
296 1881 : result += "[" + std::to_string (index - best.first) + "]";
297 627 : return result;
298 : }
299 :
300 : Variance
301 489 : GenericTyPerCrateCtx::evaluate (Term *term)
302 : {
303 489 : switch (term->kind)
304 : {
305 163 : case Term::CONST:
306 163 : return term->const_val;
307 163 : case Term::REF:
308 163 : return solutions[term->ref];
309 163 : case Term::TRANSFORM:
310 163 : return Variance::transform (evaluate (term->transform.lhs),
311 163 : evaluate (term->transform.rhs));
312 : }
313 0 : rust_unreachable ();
314 : }
315 :
316 : std::vector<Variance>
317 78 : GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
318 : {
319 78 : auto solution_index = lookup_type_index (type.get_orig_ref ());
320 78 : rust_assert (solution_index.has_value ());
321 78 : auto num_lifetimes = type.get_num_lifetime_params ();
322 78 : auto num_types = type.get_num_type_params ();
323 :
324 78 : std::vector<Variance> result;
325 78 : result.reserve (num_lifetimes + num_types);
326 :
327 122 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
328 : {
329 44 : result.push_back (solutions[solution_index.value () + i]);
330 : }
331 :
332 78 : return result;
333 : }
334 :
335 : FreeRegions
336 5 : GenericTyPerCrateCtx::query_field_regions (const ADTType *parent,
337 : size_t variant_index,
338 : size_t field_index,
339 : const FreeRegions &parent_regions)
340 : {
341 5 : auto orig = lookup_type (parent->get_orig_ref ());
342 5 : FieldVisitorCtx ctx (*this, *parent->as<const SubstitutionRef> (),
343 5 : parent_regions);
344 5 : return ctx.collect_regions (*orig->as<const ADTType> ()
345 5 : ->get_variants ()
346 5 : .at (variant_index)
347 5 : ->get_fields ()
348 5 : .at (field_index)
349 5 : ->get_field_type ());
350 5 : }
351 : std::vector<Region>
352 76 : GenericTyPerCrateCtx::query_type_regions (BaseType *type)
353 : {
354 76 : TyVisitorCtx ctx (*this);
355 76 : return ctx.collect_regions (*type);
356 76 : }
357 :
358 : SolutionIndex
359 3434 : GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
360 : {
361 3434 : BaseType *ty = lookup_type (hir_id);
362 3434 : auto index = ctx.lookup_type_index (hir_id);
363 3434 : if (index.has_value ())
364 : {
365 285 : return index.value ();
366 : }
367 :
368 3149 : SubstitutionRef *subst = nullptr;
369 3149 : if (auto adt = ty->try_as<ADTType> ())
370 : {
371 3149 : subst = adt;
372 : }
373 : else
374 : {
375 0 : rust_sorry_at (
376 : ty->get_locus (),
377 : "This is a compiler bug: Unhandled type in variance analysis");
378 : }
379 0 : rust_assert (subst != nullptr);
380 :
381 3149 : auto solution_index = ctx.solutions.size ();
382 3149 : ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index);
383 :
384 3149 : auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size ();
385 3149 : auto num_type_param = subst->get_num_substitutions ();
386 :
387 4355 : for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i)
388 1206 : ctx.solutions.emplace_back (Variance::bivariant ());
389 :
390 3149 : return solution_index;
391 : }
392 :
393 : void
394 5281 : GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance)
395 : {
396 5281 : rust_debug ("\tadd_constraint_from_ty: %s with v=%s",
397 : type->as_string ().c_str (), ctx.to_string (variance).c_str ());
398 :
399 5281 : Visitor visitor (*this, variance);
400 5281 : type->accept_vis (visitor);
401 5281 : }
402 :
403 : void
404 1459 : GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term)
405 : {
406 1459 : rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ());
407 :
408 1459 : if (term.kind == Term::CONST)
409 : {
410 : // Constant terms do not depend on other solutions, so we can
411 : // immediately apply them.
412 1320 : ctx.solutions[index].join (term.const_val);
413 : }
414 : else
415 : {
416 139 : ctx.constraints.emplace_back (index, new Term (term));
417 : }
418 1459 : }
419 :
420 : void
421 23 : GenericTyVisitorCtx::add_constraints_from_region (const Region ®ion,
422 : Term term)
423 : {
424 23 : if (region.is_early_bound ())
425 : {
426 14 : add_constraint (first_lifetime + region.get_index (), term);
427 : }
428 23 : }
429 :
430 : void
431 349 : GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref,
432 : SubstitutionRef &subst,
433 : Term variance,
434 : bool invariant_args)
435 : {
436 349 : SolutionIndex solution_index = lookup_or_add_type (ref);
437 :
438 349 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
439 349 : size_t num_types = subst.get_substs ().size ();
440 :
441 518 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
442 : {
443 : // TODO: What about variance from other crates?
444 169 : auto variance_i
445 : = invariant_args
446 169 : ? Term::make_transform (variance, Variance::invariant ())
447 169 : : Term::make_transform (variance,
448 : Term::make_ref (solution_index + i));
449 :
450 169 : if (i < num_lifetimes)
451 : {
452 1 : auto region_i = i;
453 1 : auto ®ion
454 1 : = subst.get_substitution_arguments ().get_mut_regions ()[region_i];
455 1 : add_constraints_from_region (region, variance_i);
456 : }
457 : else
458 : {
459 168 : auto type_i = i - num_lifetimes;
460 168 : auto arg = subst.get_arg_at (type_i);
461 168 : if (arg.has_value ())
462 : {
463 160 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
464 : }
465 : }
466 : }
467 349 : }
468 : void
469 1445 : GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance)
470 : {
471 1445 : auto it
472 1445 : = std::find (param_names.begin (), param_names.end (), type.get_name ());
473 1445 : rust_assert (it != param_names.end ());
474 :
475 1445 : auto index = first_type + std::distance (param_names.begin (), it);
476 :
477 1445 : add_constraint (index, variance);
478 1445 : }
479 :
480 : Term
481 9 : GenericTyVisitorCtx::contra (Term variance)
482 : {
483 9 : return Term::make_transform (variance, Variance::contravariant ());
484 : }
485 :
486 : void
487 820 : TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
488 : {
489 820 : Visitor visitor (*this, variance);
490 820 : ty->accept_vis (visitor);
491 820 : }
492 :
493 : void
494 284 : TyVisitorCtx::add_constraints_from_region (const Region ®ion,
495 : Variance variance)
496 : {
497 284 : variances.push_back (variance);
498 284 : regions.push_back (region);
499 284 : }
500 :
501 : void
502 78 : TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
503 : SubstitutionRef &subst,
504 : Variance variance,
505 : bool invariant_args)
506 : {
507 : // Handle function
508 78 : auto variances
509 78 : = ctx.query_generic_variance (*lookup_type (ref)->as<ADTType> ());
510 :
511 78 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
512 78 : size_t num_types = subst.get_substs ().size ();
513 :
514 122 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
515 : {
516 : // TODO: What about variance from other crates?
517 44 : auto variance_i
518 : = invariant_args
519 44 : ? Variance::transform (variance, Variance::invariant ())
520 44 : : Variance::transform (variance, variances[i]);
521 :
522 44 : if (i < num_lifetimes)
523 : {
524 44 : auto region_i = i;
525 44 : auto ®ion = subst.get_used_arguments ().get_regions ()[region_i];
526 44 : add_constraints_from_region (region, variance_i);
527 : }
528 : else
529 : {
530 0 : auto type_i = i - num_lifetimes;
531 0 : auto arg = subst.get_arg_at (type_i);
532 0 : if (arg.has_value ())
533 : {
534 0 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
535 : }
536 : }
537 : }
538 78 : }
539 :
540 : FreeRegions
541 5 : FieldVisitorCtx::collect_regions (BaseType &ty)
542 : {
543 : // Segment the regions into ranges for each type parameter. Type parameter
544 : // at index i contains regions from type_param_ranges[i] to
545 : // type_param_ranges[i+1] (exclusive).;
546 5 : type_param_ranges.push_back (subst.get_num_lifetime_params ());
547 :
548 5 : for (size_t i = 0; i < subst.get_num_type_params (); i++)
549 : {
550 0 : auto arg = subst.get_arg_at (i);
551 0 : rust_assert (arg.has_value ());
552 0 : type_param_ranges.push_back (
553 0 : ctx.query_type_regions (arg.value ().get_tyty ()).size ());
554 : }
555 :
556 5 : add_constraints_from_ty (&ty, Variance::covariant ());
557 5 : return regions;
558 : }
559 :
560 : void
561 8 : FieldVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
562 : {
563 8 : Visitor visitor (*this, variance);
564 8 : ty->accept_vis (visitor);
565 8 : }
566 :
567 : void
568 3 : FieldVisitorCtx::add_constraints_from_region (const Region ®ion,
569 : Variance variance)
570 : {
571 3 : if (region.is_early_bound ())
572 : {
573 3 : regions.push_back (parent_regions[region.get_index ()]);
574 : }
575 0 : else if (region.is_late_bound ())
576 : {
577 0 : rust_debug ("Ignoring late bound region");
578 : }
579 3 : }
580 :
581 : void
582 0 : FieldVisitorCtx::add_constrints_from_param (ParamType ¶m, Variance variance)
583 : {
584 0 : size_t param_i = subst.get_used_arguments ().find_symbol (param).value ();
585 0 : for (size_t i = type_param_ranges[param_i];
586 0 : i < type_param_ranges[param_i + 1]; i++)
587 : {
588 0 : regions.push_back (parent_regions[i]);
589 : }
590 0 : }
591 :
592 : Variance
593 0 : TyVisitorCtx::contra (Variance variance)
594 : {
595 0 : return Variance::transform (variance, Variance::contravariant ());
596 : }
597 :
598 : Term
599 169 : Term::make_ref (SolutionIndex index)
600 : {
601 169 : Term term;
602 169 : term.kind = REF;
603 169 : term.ref = index;
604 169 : return term;
605 : }
606 :
607 : Term
608 178 : Term::make_transform (Term lhs, Term rhs)
609 : {
610 178 : if (lhs.is_const () && rhs.is_const ())
611 : {
612 9 : return Variance::transform (lhs.const_val, rhs.const_val);
613 : }
614 :
615 169 : Term term;
616 169 : term.kind = TRANSFORM;
617 169 : term.transform.lhs = new Term (lhs);
618 169 : term.transform.rhs = new Term (rhs);
619 169 : return term;
620 : }
621 :
622 : } // namespace VarianceAnalysis
623 : } // namespace TyTy
624 : } // namespace Rust
|