Branch data Line data Source code
1 : : #include "rust-tyty-variance-analysis-private.h"
2 : : #include "rust-hir-type-check.h"
3 : :
4 : : namespace Rust {
5 : : namespace TyTy {
6 : :
7 : : BaseType *
8 : 6998 : lookup_type (HirId ref)
9 : : {
10 : 6998 : BaseType *ty = nullptr;
11 : 6998 : bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty);
12 : 6998 : rust_assert (ok);
13 : 6998 : return ty;
14 : : }
15 : :
16 : : namespace VarianceAnalysis {
17 : :
18 : 4290 : CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {}
19 : :
20 : : // Must be here because of incomplete type.
21 : 0 : CrateCtx::~CrateCtx () = default;
22 : :
23 : : void
24 : 2975 : CrateCtx::add_type_constraints (ADTType &type)
25 : : {
26 : 2975 : private_ctx->process_type (type);
27 : 2975 : }
28 : :
29 : : void
30 : 4117 : CrateCtx::solve ()
31 : : {
32 : 4117 : private_ctx->solve ();
33 : 4117 : private_ctx->debug_print_solutions ();
34 : 4117 : }
35 : :
36 : : std::vector<Variance>
37 : 0 : CrateCtx::query_generic_variance (const ADTType &type)
38 : : {
39 : 0 : return private_ctx->query_generic_variance (type);
40 : : }
41 : :
42 : : std::vector<Variance>
43 : 516 : CrateCtx::query_type_variances (BaseType *type)
44 : : {
45 : 516 : TyVisitorCtx ctx (*private_ctx);
46 : 516 : return ctx.collect_variances (*type);
47 : 516 : }
48 : :
49 : : std::vector<Region>
50 : 76 : CrateCtx::query_type_regions (BaseType *type)
51 : : {
52 : 76 : return private_ctx->query_type_regions (type);
53 : : }
54 : :
55 : : FreeRegions
56 : 5 : CrateCtx::query_field_regions (const ADTType *parent, size_t variant_index,
57 : : size_t field_index,
58 : : const FreeRegions &parent_regions)
59 : : {
60 : 5 : return private_ctx->query_field_regions (parent, variant_index, field_index,
61 : 5 : parent_regions);
62 : : }
63 : :
64 : : Variance
65 : 0 : Variance::reverse () const
66 : : {
67 : 0 : switch (kind)
68 : : {
69 : 0 : case BIVARIANT:
70 : 0 : return bivariant ();
71 : 0 : case COVARIANT:
72 : 0 : return contravariant ();
73 : 0 : case CONTRAVARIANT:
74 : 0 : return covariant ();
75 : 0 : case INVARIANT:
76 : 0 : return invariant ();
77 : : }
78 : :
79 : 0 : rust_unreachable ();
80 : : }
81 : :
82 : : Variance
83 : 1480 : Variance::join (Variance lhs, Variance rhs)
84 : : {
85 : 1480 : return {Kind (lhs.kind | rhs.kind)};
86 : : }
87 : :
88 : : void
89 : 1319 : Variance::join (Variance rhs)
90 : : {
91 : 1319 : *this = join (*this, rhs);
92 : 1319 : }
93 : :
94 : : Variance
95 : 212 : Variance::transform (Variance lhs, Variance rhs)
96 : : {
97 : 212 : switch (lhs.kind)
98 : : {
99 : 0 : case BIVARIANT:
100 : 0 : return bivariant ();
101 : 210 : case COVARIANT:
102 : 210 : return rhs;
103 : 0 : case CONTRAVARIANT:
104 : 0 : return rhs.reverse ();
105 : 2 : case INVARIANT:
106 : 2 : return invariant ();
107 : : }
108 : 0 : rust_unreachable ();
109 : : }
110 : :
111 : : std::string
112 : 8007 : Variance::as_string () const
113 : : {
114 : 8007 : switch (kind)
115 : : {
116 : 218 : case BIVARIANT:
117 : 218 : return "o";
118 : 7487 : case COVARIANT:
119 : 7487 : return "+";
120 : 7 : case CONTRAVARIANT:
121 : 7 : return "-";
122 : 295 : case INVARIANT:
123 : 295 : return "*";
124 : : }
125 : 0 : rust_unreachable ();
126 : : }
127 : :
128 : : void
129 : 2975 : GenericTyPerCrateCtx::process_type (ADTType &type)
130 : : {
131 : 2975 : GenericTyVisitorCtx (*this).process_type (type);
132 : 2975 : }
133 : :
134 : : void
135 : 4117 : GenericTyPerCrateCtx::solve ()
136 : : {
137 : 4117 : rust_debug ("Variance analysis solving started:");
138 : :
139 : : // Fix point iteration
140 : 4117 : bool changed = true;
141 : 12375 : while (changed)
142 : : {
143 : 4141 : changed = false;
144 : 4302 : for (auto constraint : constraints)
145 : : {
146 : 161 : rust_debug ("\tapplying constraint: %s <= %s",
147 : : to_string (constraint.target_index).c_str (),
148 : : to_string (*constraint.term).c_str ());
149 : :
150 : 161 : auto old_solution = solutions[constraint.target_index];
151 : 161 : auto new_solution
152 : 161 : = Variance::join (old_solution, evaluate (constraint.term));
153 : :
154 : 161 : if (old_solution != new_solution)
155 : : {
156 : 24 : rust_debug ("\t\tsolution changed: %s => %s",
157 : : old_solution.as_string ().c_str (),
158 : : new_solution.as_string ().c_str ());
159 : :
160 : 24 : changed = true;
161 : 24 : solutions[constraint.target_index] = new_solution;
162 : : }
163 : : }
164 : : }
165 : :
166 : 4117 : constraints.clear ();
167 : 4117 : constraints.shrink_to_fit ();
168 : 4117 : }
169 : :
170 : : void
171 : 4117 : GenericTyPerCrateCtx::debug_print_solutions ()
172 : : {
173 : 4117 : rust_debug ("Variance analysis results:");
174 : :
175 : 7092 : for (auto type : map_from_ty_orig_ref)
176 : : {
177 : 2975 : auto solution_index = type.second;
178 : 2975 : auto ref = type.first;
179 : :
180 : 2975 : BaseType *ty = lookup_type (ref);
181 : :
182 : 2975 : std::string result = "\t";
183 : :
184 : 2975 : if (auto adt = ty->try_as<ADTType> ())
185 : : {
186 : 8925 : result += adt->get_identifier ();
187 : 2975 : result += "<";
188 : :
189 : 2975 : size_t i = solution_index;
190 : 2994 : for (auto ®ion : adt->get_used_arguments ().get_regions ())
191 : : {
192 : 19 : (void) region;
193 : 19 : if (i > solution_index)
194 : 1 : result += ", ";
195 : 38 : result += solutions[i].as_string ();
196 : 19 : i++;
197 : : }
198 : 4123 : for (auto ¶m : adt->get_substs ())
199 : : {
200 : 1148 : if (i > solution_index)
201 : 129 : result += ", ";
202 : 2296 : result += param.get_type_representation ().as_string ();
203 : 1148 : result += "=";
204 : 2296 : result += solutions[i].as_string ();
205 : 1148 : i++;
206 : : }
207 : :
208 : 2975 : result += ">";
209 : : }
210 : : else
211 : : {
212 : 0 : rust_sorry_at (
213 : : ty->get_ref (),
214 : : "This is a compiler bug: Unhandled type in variance analysis");
215 : : }
216 : 2975 : rust_debug ("%s", result.c_str ());
217 : 2975 : }
218 : 4117 : }
219 : :
220 : : tl::optional<SolutionIndex>
221 : 3399 : GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref)
222 : : {
223 : 3399 : auto it = map_from_ty_orig_ref.find (orig_ref);
224 : 3399 : if (it != map_from_ty_orig_ref.end ())
225 : : {
226 : 360 : return it->second;
227 : : }
228 : 3039 : return tl::nullopt;
229 : : }
230 : :
231 : : void
232 : 2975 : GenericTyVisitorCtx::process_type (ADTType &ty)
233 : : {
234 : 2975 : rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ());
235 : :
236 : 2975 : first_lifetime = lookup_or_add_type (ty.get_orig_ref ());
237 : 2975 : first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size ();
238 : :
239 : 4123 : for (auto ¶m : ty.get_substs ())
240 : 1148 : param_names.push_back (param.get_type_representation ().as_string ());
241 : :
242 : 6635 : for (const auto &variant : ty.get_variants ())
243 : : {
244 : 3660 : if (variant->get_variant_type () != VariantDef::NUM)
245 : : {
246 : 7166 : for (const auto &field : variant->get_fields ())
247 : 4203 : add_constraints_from_ty (field->get_field_type (),
248 : 4203 : Variance::covariant ());
249 : : }
250 : : }
251 : 2975 : }
252 : :
253 : : std::string
254 : 7708 : GenericTyPerCrateCtx::to_string (const Term &term) const
255 : : {
256 : 7708 : switch (term.kind)
257 : : {
258 : 6792 : case Term::CONST:
259 : 6792 : return term.const_val.as_string ();
260 : 458 : case Term::REF:
261 : 916 : return "v(" + to_string (term.ref) + ")";
262 : 458 : case Term::TRANSFORM:
263 : 916 : return "(" + to_string (*term.transform.lhs) + " x "
264 : 1832 : + to_string (*term.transform.rhs) + ")";
265 : : }
266 : 0 : rust_unreachable ();
267 : : }
268 : :
269 : : std::string
270 : 619 : GenericTyPerCrateCtx::to_string (SolutionIndex index) const
271 : : {
272 : : // Search all values in def_id_to_solution_index_start and find key for
273 : : // largest value smaller than index
274 : 619 : std::pair<HirId, SolutionIndex> best = {0, 0};
275 : :
276 : 2798 : for (const auto &ty_map : map_from_ty_orig_ref)
277 : : {
278 : 2179 : if (ty_map.second <= index && ty_map.first > best.first)
279 : : best = ty_map;
280 : : }
281 : 619 : rust_assert (best.first != 0);
282 : :
283 : 619 : BaseType *ty = lookup_type (best.first);
284 : :
285 : 619 : std::string result = "";
286 : 619 : if (auto adt = ty->try_as<ADTType> ())
287 : : {
288 : 1857 : result += (adt->get_identifier ());
289 : : }
290 : : else
291 : : {
292 : 0 : result += ty->as_string ();
293 : : }
294 : :
295 : 1857 : result += "[" + std::to_string (index - best.first) + "]";
296 : 619 : return result;
297 : : }
298 : :
299 : : Variance
300 : 483 : GenericTyPerCrateCtx::evaluate (Term *term)
301 : : {
302 : 483 : switch (term->kind)
303 : : {
304 : 161 : case Term::CONST:
305 : 161 : return term->const_val;
306 : 161 : case Term::REF:
307 : 161 : return solutions[term->ref];
308 : 161 : case Term::TRANSFORM:
309 : 161 : return Variance::transform (evaluate (term->transform.lhs),
310 : 161 : evaluate (term->transform.rhs));
311 : : }
312 : 0 : rust_unreachable ();
313 : : }
314 : :
315 : : std::vector<Variance>
316 : 78 : GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
317 : : {
318 : 78 : auto solution_index = lookup_type_index (type.get_orig_ref ());
319 : 78 : rust_assert (solution_index.has_value ());
320 : 78 : auto num_lifetimes = type.get_num_lifetime_params ();
321 : 78 : auto num_types = type.get_num_type_params ();
322 : :
323 : 78 : std::vector<Variance> result;
324 : 122 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
325 : : {
326 : 44 : result.push_back (solutions[solution_index.value () + i]);
327 : : }
328 : :
329 : 78 : return result;
330 : : }
331 : :
332 : : FreeRegions
333 : 5 : GenericTyPerCrateCtx::query_field_regions (const ADTType *parent,
334 : : size_t variant_index,
335 : : size_t field_index,
336 : : const FreeRegions &parent_regions)
337 : : {
338 : 5 : auto orig = lookup_type (parent->get_orig_ref ());
339 : 5 : FieldVisitorCtx ctx (*this, *parent->as<const SubstitutionRef> (),
340 : 5 : parent_regions);
341 : 5 : return ctx.collect_regions (*orig->as<const ADTType> ()
342 : 5 : ->get_variants ()
343 : 5 : .at (variant_index)
344 : 5 : ->get_fields ()
345 : 5 : .at (field_index)
346 : 5 : ->get_field_type ());
347 : 5 : }
348 : : std::vector<Region>
349 : 76 : GenericTyPerCrateCtx::query_type_regions (BaseType *type)
350 : : {
351 : 76 : TyVisitorCtx ctx (*this);
352 : 76 : return ctx.collect_regions (*type);
353 : 76 : }
354 : :
355 : : SolutionIndex
356 : 3321 : GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
357 : : {
358 : 3321 : BaseType *ty = lookup_type (hir_id);
359 : 3321 : auto index = ctx.lookup_type_index (hir_id);
360 : 3321 : if (index.has_value ())
361 : : {
362 : 282 : return index.value ();
363 : : }
364 : :
365 : 3039 : SubstitutionRef *subst = nullptr;
366 : 3039 : if (auto adt = ty->try_as<ADTType> ())
367 : : {
368 : 3039 : subst = adt;
369 : : }
370 : : else
371 : : {
372 : 0 : rust_sorry_at (
373 : : ty->get_locus (),
374 : : "This is a compiler bug: Unhandled type in variance analysis");
375 : : }
376 : 0 : rust_assert (subst != nullptr);
377 : :
378 : 3039 : auto solution_index = ctx.solutions.size ();
379 : 3039 : ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index);
380 : :
381 : 3039 : auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size ();
382 : 3039 : auto num_type_param = subst->get_num_substitutions ();
383 : :
384 : 4207 : for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i)
385 : 1168 : ctx.solutions.emplace_back (Variance::bivariant ());
386 : :
387 : 3039 : return solution_index;
388 : : }
389 : :
390 : : void
391 : 5175 : GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance)
392 : : {
393 : 5175 : rust_debug ("\tadd_constraint_from_ty: %s with v=%s",
394 : : type->as_string ().c_str (), ctx.to_string (variance).c_str ());
395 : :
396 : 5175 : Visitor visitor (*this, variance);
397 : 5175 : type->accept_vis (visitor);
398 : 5175 : }
399 : :
400 : : void
401 : 1456 : GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term)
402 : : {
403 : 1456 : rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ());
404 : :
405 : 1456 : if (term.kind == Term::CONST)
406 : : {
407 : : // Constant terms do not depend on other solutions, so we can
408 : : // immediately apply them.
409 : 1319 : ctx.solutions[index].join (term.const_val);
410 : : }
411 : : else
412 : : {
413 : 137 : ctx.constraints.push_back ({index, new Term (term)});
414 : : }
415 : 1456 : }
416 : :
417 : : void
418 : 23 : GenericTyVisitorCtx::add_constraints_from_region (const Region ®ion,
419 : : Term term)
420 : : {
421 : 23 : if (region.is_early_bound ())
422 : : {
423 : 14 : add_constraint (first_lifetime + region.get_index (), term);
424 : : }
425 : 23 : }
426 : :
427 : : void
428 : 346 : GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref,
429 : : SubstitutionRef &subst,
430 : : Term variance,
431 : : bool invariant_args)
432 : : {
433 : 346 : SolutionIndex solution_index = lookup_or_add_type (ref);
434 : :
435 : 346 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
436 : 346 : size_t num_types = subst.get_substs ().size ();
437 : :
438 : 513 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
439 : : {
440 : : // TODO: What about variance from other crates?
441 : 167 : auto variance_i
442 : : = invariant_args
443 : 167 : ? Term::make_transform (variance, Variance::invariant ())
444 : 167 : : Term::make_transform (variance,
445 : : Term::make_ref (solution_index + i));
446 : :
447 : 167 : if (i < num_lifetimes)
448 : : {
449 : 1 : auto region_i = i;
450 : 1 : auto ®ion
451 : 1 : = subst.get_substitution_arguments ().get_mut_regions ()[region_i];
452 : 1 : add_constraints_from_region (region, variance_i);
453 : : }
454 : : else
455 : : {
456 : 166 : auto type_i = i - num_lifetimes;
457 : 166 : auto arg = subst.get_arg_at (type_i);
458 : 166 : if (arg.has_value ())
459 : : {
460 : 158 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
461 : : }
462 : : }
463 : : }
464 : 346 : }
465 : : void
466 : 1442 : GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance)
467 : : {
468 : 1442 : auto it
469 : 1442 : = std::find (param_names.begin (), param_names.end (), type.get_name ());
470 : 1442 : rust_assert (it != param_names.end ());
471 : :
472 : 1442 : auto index = first_type + std::distance (param_names.begin (), it);
473 : :
474 : 1442 : add_constraint (index, variance);
475 : 1442 : }
476 : :
477 : : Term
478 : 7 : GenericTyVisitorCtx::contra (Term variance)
479 : : {
480 : 7 : return Term::make_transform (variance, Variance::contravariant ());
481 : : }
482 : :
483 : : void
484 : 820 : TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
485 : : {
486 : 820 : Visitor visitor (*this, variance);
487 : 820 : ty->accept_vis (visitor);
488 : 820 : }
489 : :
490 : : void
491 : 284 : TyVisitorCtx::add_constraints_from_region (const Region ®ion,
492 : : Variance variance)
493 : : {
494 : 284 : variances.push_back (variance);
495 : 284 : regions.push_back (region);
496 : 284 : }
497 : :
498 : : void
499 : 78 : TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
500 : : SubstitutionRef &subst,
501 : : Variance variance,
502 : : bool invariant_args)
503 : : {
504 : : // Handle function
505 : 78 : auto variances
506 : 78 : = ctx.query_generic_variance (*lookup_type (ref)->as<ADTType> ());
507 : :
508 : 78 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
509 : 78 : size_t num_types = subst.get_substs ().size ();
510 : :
511 : 122 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
512 : : {
513 : : // TODO: What about variance from other crates?
514 : 44 : auto variance_i
515 : : = invariant_args
516 : 44 : ? Variance::transform (variance, Variance::invariant ())
517 : 44 : : Variance::transform (variance, variances[i]);
518 : :
519 : 44 : if (i < num_lifetimes)
520 : : {
521 : 44 : auto region_i = i;
522 : 44 : auto ®ion = subst.get_used_arguments ().get_regions ()[region_i];
523 : 44 : add_constraints_from_region (region, variance_i);
524 : : }
525 : : else
526 : : {
527 : 0 : auto type_i = i - num_lifetimes;
528 : 0 : auto arg = subst.get_arg_at (type_i);
529 : 0 : if (arg.has_value ())
530 : : {
531 : 0 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
532 : : }
533 : : }
534 : : }
535 : 78 : }
536 : :
537 : : FreeRegions
538 : 5 : FieldVisitorCtx::collect_regions (BaseType &ty)
539 : : {
540 : : // Segment the regions into ranges for each type parameter. Type parameter
541 : : // at index i contains regions from type_param_ranges[i] to
542 : : // type_param_ranges[i+1] (exclusive).;
543 : 5 : type_param_ranges.push_back (subst.get_num_lifetime_params ());
544 : :
545 : 5 : for (size_t i = 0; i < subst.get_num_type_params (); i++)
546 : : {
547 : 0 : auto arg = subst.get_arg_at (i);
548 : 0 : rust_assert (arg.has_value ());
549 : 0 : type_param_ranges.push_back (
550 : 0 : ctx.query_type_regions (arg.value ().get_tyty ()).size ());
551 : : }
552 : :
553 : 5 : add_constraints_from_ty (&ty, Variance::covariant ());
554 : 5 : return regions;
555 : : }
556 : :
557 : : void
558 : 8 : FieldVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
559 : : {
560 : 8 : Visitor visitor (*this, variance);
561 : 8 : ty->accept_vis (visitor);
562 : 8 : }
563 : :
564 : : void
565 : 3 : FieldVisitorCtx::add_constraints_from_region (const Region ®ion,
566 : : Variance variance)
567 : : {
568 : 3 : if (region.is_early_bound ())
569 : : {
570 : 3 : regions.push_back (parent_regions[region.get_index ()]);
571 : : }
572 : 0 : else if (region.is_late_bound ())
573 : : {
574 : 0 : rust_debug ("Ignoring late bound region");
575 : : }
576 : 3 : }
577 : :
578 : : void
579 : 0 : FieldVisitorCtx::add_constrints_from_param (ParamType ¶m, Variance variance)
580 : : {
581 : 0 : size_t param_i = subst.get_used_arguments ().find_symbol (param).value ();
582 : 0 : for (size_t i = type_param_ranges[param_i];
583 : 0 : i < type_param_ranges[param_i + 1]; i++)
584 : : {
585 : 0 : regions.push_back (parent_regions[i]);
586 : : }
587 : 0 : }
588 : :
589 : : Variance
590 : 0 : TyVisitorCtx::contra (Variance variance)
591 : : {
592 : 0 : return Variance::transform (variance, Variance::contravariant ());
593 : : }
594 : :
595 : : Term
596 : 167 : Term::make_ref (SolutionIndex index)
597 : : {
598 : 167 : Term term;
599 : 167 : term.kind = REF;
600 : 167 : term.ref = index;
601 : 167 : return term;
602 : : }
603 : :
604 : : Term
605 : 174 : Term::make_transform (Term lhs, Term rhs)
606 : : {
607 : 174 : if (lhs.is_const () && rhs.is_const ())
608 : : {
609 : 7 : return Variance::transform (lhs.const_val, rhs.const_val);
610 : : }
611 : :
612 : 167 : Term term;
613 : 167 : term.kind = TRANSFORM;
614 : 167 : term.transform.lhs = new Term (lhs);
615 : 167 : term.transform.rhs = new Term (rhs);
616 : 167 : return term;
617 : : }
618 : :
619 : : } // namespace VarianceAnalysis
620 : : } // namespace TyTy
621 : : } // namespace Rust
|