Line data Source code
1 : #include "rust-tyty-variance-analysis-private.h"
2 : #include "rust-hir-type-check.h"
3 :
4 : namespace Rust {
5 : namespace TyTy {
6 :
7 : BaseType *
8 7195 : lookup_type (HirId ref)
9 : {
10 7195 : BaseType *ty = nullptr;
11 7195 : bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty);
12 7195 : rust_assert (ok);
13 7195 : return ty;
14 : }
15 :
16 : namespace VarianceAnalysis {
17 :
18 4510 : CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {}
19 :
20 : // Must be here because of incomplete type.
21 0 : CrateCtx::~CrateCtx () = default;
22 :
23 : void
24 3068 : CrateCtx::add_type_constraints (ADTType &type)
25 : {
26 3068 : private_ctx->process_type (type);
27 3068 : }
28 :
29 : void
30 4300 : CrateCtx::solve ()
31 : {
32 4300 : private_ctx->solve ();
33 4300 : private_ctx->debug_print_solutions ();
34 4300 : }
35 :
36 : std::vector<Variance>
37 0 : CrateCtx::query_generic_variance (const ADTType &type)
38 : {
39 0 : return private_ctx->query_generic_variance (type);
40 : }
41 :
42 : std::vector<Variance>
43 516 : CrateCtx::query_type_variances (BaseType *type)
44 : {
45 516 : TyVisitorCtx ctx (*private_ctx);
46 516 : return ctx.collect_variances (*type);
47 516 : }
48 :
49 : std::vector<Region>
50 76 : CrateCtx::query_type_regions (BaseType *type)
51 : {
52 76 : return private_ctx->query_type_regions (type);
53 : }
54 :
55 : FreeRegions
56 5 : CrateCtx::query_field_regions (const ADTType *parent, size_t variant_index,
57 : size_t field_index,
58 : const FreeRegions &parent_regions)
59 : {
60 5 : return private_ctx->query_field_regions (parent, variant_index, field_index,
61 5 : parent_regions);
62 : }
63 :
64 : Variance
65 0 : Variance::reverse () const
66 : {
67 0 : switch (kind)
68 : {
69 0 : case BIVARIANT:
70 0 : return bivariant ();
71 0 : case COVARIANT:
72 0 : return contravariant ();
73 0 : case CONTRAVARIANT:
74 0 : return covariant ();
75 0 : case INVARIANT:
76 0 : return invariant ();
77 : }
78 :
79 0 : rust_unreachable ();
80 : }
81 :
82 : Variance
83 1483 : Variance::join (Variance lhs, Variance rhs)
84 : {
85 1483 : return {Kind (lhs.kind | rhs.kind)};
86 : }
87 :
88 : void
89 1320 : Variance::join (Variance rhs)
90 : {
91 1320 : *this = join (*this, rhs);
92 1320 : }
93 :
94 : Variance
95 216 : Variance::transform (Variance lhs, Variance rhs)
96 : {
97 216 : switch (lhs.kind)
98 : {
99 0 : case BIVARIANT:
100 0 : return bivariant ();
101 214 : case COVARIANT:
102 214 : return rhs;
103 0 : case CONTRAVARIANT:
104 0 : return rhs.reverse ();
105 2 : case INVARIANT:
106 2 : return invariant ();
107 : }
108 0 : rust_unreachable ();
109 : }
110 :
111 : std::string
112 8139 : Variance::as_string () const
113 : {
114 8139 : switch (kind)
115 : {
116 254 : case BIVARIANT:
117 254 : return "o";
118 7583 : case COVARIANT:
119 7583 : return "+";
120 7 : case CONTRAVARIANT:
121 7 : return "-";
122 295 : case INVARIANT:
123 295 : return "*";
124 : }
125 0 : rust_unreachable ();
126 : }
127 :
128 : void
129 3068 : GenericTyPerCrateCtx::process_type (ADTType &type)
130 : {
131 3068 : GenericTyVisitorCtx (*this).process_type (type);
132 3068 : }
133 :
134 : void
135 4300 : GenericTyPerCrateCtx::solve ()
136 : {
137 4300 : rust_debug ("Variance analysis solving started:");
138 :
139 : // Fix point iteration
140 4300 : bool changed = true;
141 12924 : while (changed)
142 : {
143 4324 : changed = false;
144 4487 : for (auto constraint : constraints)
145 : {
146 163 : rust_debug ("\tapplying constraint: %s <= %s",
147 : to_string (constraint.target_index).c_str (),
148 : to_string (*constraint.term).c_str ());
149 :
150 163 : auto old_solution = solutions[constraint.target_index];
151 163 : auto new_solution
152 163 : = Variance::join (old_solution, evaluate (constraint.term));
153 :
154 163 : if (old_solution != new_solution)
155 : {
156 24 : rust_debug ("\t\tsolution changed: %s => %s",
157 : old_solution.as_string ().c_str (),
158 : new_solution.as_string ().c_str ());
159 :
160 24 : changed = true;
161 24 : solutions[constraint.target_index] = new_solution;
162 : }
163 : }
164 : }
165 :
166 4300 : constraints.clear ();
167 4300 : constraints.shrink_to_fit ();
168 4300 : }
169 :
170 : void
171 4300 : GenericTyPerCrateCtx::debug_print_solutions ()
172 : {
173 4300 : rust_debug ("Variance analysis results:");
174 :
175 7368 : for (auto type : map_from_ty_orig_ref)
176 : {
177 3068 : auto solution_index = type.second;
178 3068 : auto ref = type.first;
179 :
180 3068 : BaseType *ty = lookup_type (ref);
181 :
182 3068 : std::string result = "\t";
183 :
184 3068 : if (auto adt = ty->try_as<ADTType> ())
185 : {
186 9204 : result += adt->get_identifier ();
187 3068 : result += "<";
188 :
189 3068 : size_t i = solution_index;
190 3089 : for (auto ®ion : adt->get_used_arguments ().get_regions ())
191 : {
192 21 : (void) region;
193 21 : if (i > solution_index)
194 1 : result += ", ";
195 42 : result += solutions[i].as_string ();
196 21 : i++;
197 : }
198 4251 : for (auto ¶m : adt->get_substs ())
199 : {
200 1183 : if (i > solution_index)
201 129 : result += ", ";
202 2366 : result += param.get_type_representation ().as_string ();
203 1183 : result += "=";
204 2366 : result += solutions[i].as_string ();
205 1183 : i++;
206 : }
207 :
208 3068 : result += ">";
209 : }
210 : else
211 : {
212 0 : rust_sorry_at (
213 : ty->get_ref (),
214 : "This is a compiler bug: Unhandled type in variance analysis");
215 : }
216 3068 : rust_debug ("%s", result.c_str ());
217 3068 : }
218 4300 : }
219 :
220 : tl::optional<SolutionIndex>
221 3495 : GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref)
222 : {
223 3495 : auto it = map_from_ty_orig_ref.find (orig_ref);
224 3495 : if (it != map_from_ty_orig_ref.end ())
225 : {
226 363 : return it->second;
227 : }
228 3132 : return tl::nullopt;
229 : }
230 :
231 : void
232 3068 : GenericTyVisitorCtx::process_type (ADTType &ty)
233 : {
234 3068 : rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ());
235 :
236 3068 : first_lifetime = lookup_or_add_type (ty.get_orig_ref ());
237 3068 : first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size ();
238 :
239 4251 : for (auto ¶m : ty.get_substs ())
240 1183 : param_names.push_back (param.get_type_representation ().as_string ());
241 :
242 6826 : for (const auto &variant : ty.get_variants ())
243 : {
244 3758 : if (variant->get_variant_type () != VariantDef::NUM)
245 : {
246 7343 : for (const auto &field : variant->get_fields ())
247 4288 : add_constraints_from_ty (field->get_field_type (),
248 4288 : Variance::covariant ());
249 : }
250 : }
251 3068 : }
252 :
253 : std::string
254 7815 : GenericTyPerCrateCtx::to_string (const Term &term) const
255 : {
256 7815 : switch (term.kind)
257 : {
258 6887 : case Term::CONST:
259 6887 : return term.const_val.as_string ();
260 464 : case Term::REF:
261 928 : return "v(" + to_string (term.ref) + ")";
262 464 : case Term::TRANSFORM:
263 928 : return "(" + to_string (*term.transform.lhs) + " x "
264 1856 : + to_string (*term.transform.rhs) + ")";
265 : }
266 0 : rust_unreachable ();
267 : }
268 :
269 : std::string
270 627 : GenericTyPerCrateCtx::to_string (SolutionIndex index) const
271 : {
272 : // Search all values in def_id_to_solution_index_start and find key for
273 : // largest value smaller than index
274 627 : std::pair<HirId, SolutionIndex> best = {0, 0};
275 :
276 2830 : for (const auto &ty_map : map_from_ty_orig_ref)
277 : {
278 2203 : if (ty_map.second <= index && ty_map.first > best.first)
279 : best = ty_map;
280 : }
281 627 : rust_assert (best.first != 0);
282 :
283 627 : BaseType *ty = lookup_type (best.first);
284 :
285 627 : std::string result = "";
286 627 : if (auto adt = ty->try_as<ADTType> ())
287 : {
288 1881 : result += (adt->get_identifier ());
289 : }
290 : else
291 : {
292 0 : result += ty->as_string ();
293 : }
294 :
295 1881 : result += "[" + std::to_string (index - best.first) + "]";
296 627 : return result;
297 : }
298 :
299 : Variance
300 489 : GenericTyPerCrateCtx::evaluate (Term *term)
301 : {
302 489 : switch (term->kind)
303 : {
304 163 : case Term::CONST:
305 163 : return term->const_val;
306 163 : case Term::REF:
307 163 : return solutions[term->ref];
308 163 : case Term::TRANSFORM:
309 163 : return Variance::transform (evaluate (term->transform.lhs),
310 163 : evaluate (term->transform.rhs));
311 : }
312 0 : rust_unreachable ();
313 : }
314 :
315 : std::vector<Variance>
316 78 : GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
317 : {
318 78 : auto solution_index = lookup_type_index (type.get_orig_ref ());
319 78 : rust_assert (solution_index.has_value ());
320 78 : auto num_lifetimes = type.get_num_lifetime_params ();
321 78 : auto num_types = type.get_num_type_params ();
322 :
323 78 : std::vector<Variance> result;
324 78 : result.reserve (num_lifetimes + num_types);
325 :
326 122 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
327 : {
328 44 : result.push_back (solutions[solution_index.value () + i]);
329 : }
330 :
331 78 : return result;
332 : }
333 :
334 : FreeRegions
335 5 : GenericTyPerCrateCtx::query_field_regions (const ADTType *parent,
336 : size_t variant_index,
337 : size_t field_index,
338 : const FreeRegions &parent_regions)
339 : {
340 5 : auto orig = lookup_type (parent->get_orig_ref ());
341 5 : FieldVisitorCtx ctx (*this, *parent->as<const SubstitutionRef> (),
342 5 : parent_regions);
343 5 : return ctx.collect_regions (*orig->as<const ADTType> ()
344 5 : ->get_variants ()
345 5 : .at (variant_index)
346 5 : ->get_fields ()
347 5 : .at (field_index)
348 5 : ->get_field_type ());
349 5 : }
350 : std::vector<Region>
351 76 : GenericTyPerCrateCtx::query_type_regions (BaseType *type)
352 : {
353 76 : TyVisitorCtx ctx (*this);
354 76 : return ctx.collect_regions (*type);
355 76 : }
356 :
357 : SolutionIndex
358 3417 : GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
359 : {
360 3417 : BaseType *ty = lookup_type (hir_id);
361 3417 : auto index = ctx.lookup_type_index (hir_id);
362 3417 : if (index.has_value ())
363 : {
364 285 : return index.value ();
365 : }
366 :
367 3132 : SubstitutionRef *subst = nullptr;
368 3132 : if (auto adt = ty->try_as<ADTType> ())
369 : {
370 3132 : subst = adt;
371 : }
372 : else
373 : {
374 0 : rust_sorry_at (
375 : ty->get_locus (),
376 : "This is a compiler bug: Unhandled type in variance analysis");
377 : }
378 0 : rust_assert (subst != nullptr);
379 :
380 3132 : auto solution_index = ctx.solutions.size ();
381 3132 : ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index);
382 :
383 3132 : auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size ();
384 3132 : auto num_type_param = subst->get_num_substitutions ();
385 :
386 4337 : for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i)
387 1205 : ctx.solutions.emplace_back (Variance::bivariant ());
388 :
389 3132 : return solution_index;
390 : }
391 :
392 : void
393 5265 : GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance)
394 : {
395 5265 : rust_debug ("\tadd_constraint_from_ty: %s with v=%s",
396 : type->as_string ().c_str (), ctx.to_string (variance).c_str ());
397 :
398 5265 : Visitor visitor (*this, variance);
399 5265 : type->accept_vis (visitor);
400 5265 : }
401 :
402 : void
403 1459 : GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term)
404 : {
405 1459 : rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ());
406 :
407 1459 : if (term.kind == Term::CONST)
408 : {
409 : // Constant terms do not depend on other solutions, so we can
410 : // immediately apply them.
411 1320 : ctx.solutions[index].join (term.const_val);
412 : }
413 : else
414 : {
415 139 : ctx.constraints.emplace_back (index, new Term (term));
416 : }
417 1459 : }
418 :
419 : void
420 23 : GenericTyVisitorCtx::add_constraints_from_region (const Region ®ion,
421 : Term term)
422 : {
423 23 : if (region.is_early_bound ())
424 : {
425 14 : add_constraint (first_lifetime + region.get_index (), term);
426 : }
427 23 : }
428 :
429 : void
430 349 : GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref,
431 : SubstitutionRef &subst,
432 : Term variance,
433 : bool invariant_args)
434 : {
435 349 : SolutionIndex solution_index = lookup_or_add_type (ref);
436 :
437 349 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
438 349 : size_t num_types = subst.get_substs ().size ();
439 :
440 518 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
441 : {
442 : // TODO: What about variance from other crates?
443 169 : auto variance_i
444 : = invariant_args
445 169 : ? Term::make_transform (variance, Variance::invariant ())
446 169 : : Term::make_transform (variance,
447 : Term::make_ref (solution_index + i));
448 :
449 169 : if (i < num_lifetimes)
450 : {
451 1 : auto region_i = i;
452 1 : auto ®ion
453 1 : = subst.get_substitution_arguments ().get_mut_regions ()[region_i];
454 1 : add_constraints_from_region (region, variance_i);
455 : }
456 : else
457 : {
458 168 : auto type_i = i - num_lifetimes;
459 168 : auto arg = subst.get_arg_at (type_i);
460 168 : if (arg.has_value ())
461 : {
462 160 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
463 : }
464 : }
465 : }
466 349 : }
467 : void
468 1445 : GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance)
469 : {
470 1445 : auto it
471 1445 : = std::find (param_names.begin (), param_names.end (), type.get_name ());
472 1445 : rust_assert (it != param_names.end ());
473 :
474 1445 : auto index = first_type + std::distance (param_names.begin (), it);
475 :
476 1445 : add_constraint (index, variance);
477 1445 : }
478 :
479 : Term
480 9 : GenericTyVisitorCtx::contra (Term variance)
481 : {
482 9 : return Term::make_transform (variance, Variance::contravariant ());
483 : }
484 :
485 : void
486 820 : TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
487 : {
488 820 : Visitor visitor (*this, variance);
489 820 : ty->accept_vis (visitor);
490 820 : }
491 :
492 : void
493 284 : TyVisitorCtx::add_constraints_from_region (const Region ®ion,
494 : Variance variance)
495 : {
496 284 : variances.push_back (variance);
497 284 : regions.push_back (region);
498 284 : }
499 :
500 : void
501 78 : TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
502 : SubstitutionRef &subst,
503 : Variance variance,
504 : bool invariant_args)
505 : {
506 : // Handle function
507 78 : auto variances
508 78 : = ctx.query_generic_variance (*lookup_type (ref)->as<ADTType> ());
509 :
510 78 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
511 78 : size_t num_types = subst.get_substs ().size ();
512 :
513 122 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
514 : {
515 : // TODO: What about variance from other crates?
516 44 : auto variance_i
517 : = invariant_args
518 44 : ? Variance::transform (variance, Variance::invariant ())
519 44 : : Variance::transform (variance, variances[i]);
520 :
521 44 : if (i < num_lifetimes)
522 : {
523 44 : auto region_i = i;
524 44 : auto ®ion = subst.get_used_arguments ().get_regions ()[region_i];
525 44 : add_constraints_from_region (region, variance_i);
526 : }
527 : else
528 : {
529 0 : auto type_i = i - num_lifetimes;
530 0 : auto arg = subst.get_arg_at (type_i);
531 0 : if (arg.has_value ())
532 : {
533 0 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
534 : }
535 : }
536 : }
537 78 : }
538 :
539 : FreeRegions
540 5 : FieldVisitorCtx::collect_regions (BaseType &ty)
541 : {
542 : // Segment the regions into ranges for each type parameter. Type parameter
543 : // at index i contains regions from type_param_ranges[i] to
544 : // type_param_ranges[i+1] (exclusive).;
545 5 : type_param_ranges.push_back (subst.get_num_lifetime_params ());
546 :
547 5 : for (size_t i = 0; i < subst.get_num_type_params (); i++)
548 : {
549 0 : auto arg = subst.get_arg_at (i);
550 0 : rust_assert (arg.has_value ());
551 0 : type_param_ranges.push_back (
552 0 : ctx.query_type_regions (arg.value ().get_tyty ()).size ());
553 : }
554 :
555 5 : add_constraints_from_ty (&ty, Variance::covariant ());
556 5 : return regions;
557 : }
558 :
559 : void
560 8 : FieldVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
561 : {
562 8 : Visitor visitor (*this, variance);
563 8 : ty->accept_vis (visitor);
564 8 : }
565 :
566 : void
567 3 : FieldVisitorCtx::add_constraints_from_region (const Region ®ion,
568 : Variance variance)
569 : {
570 3 : if (region.is_early_bound ())
571 : {
572 3 : regions.push_back (parent_regions[region.get_index ()]);
573 : }
574 0 : else if (region.is_late_bound ())
575 : {
576 0 : rust_debug ("Ignoring late bound region");
577 : }
578 3 : }
579 :
580 : void
581 0 : FieldVisitorCtx::add_constrints_from_param (ParamType ¶m, Variance variance)
582 : {
583 0 : size_t param_i = subst.get_used_arguments ().find_symbol (param).value ();
584 0 : for (size_t i = type_param_ranges[param_i];
585 0 : i < type_param_ranges[param_i + 1]; i++)
586 : {
587 0 : regions.push_back (parent_regions[i]);
588 : }
589 0 : }
590 :
591 : Variance
592 0 : TyVisitorCtx::contra (Variance variance)
593 : {
594 0 : return Variance::transform (variance, Variance::contravariant ());
595 : }
596 :
597 : Term
598 169 : Term::make_ref (SolutionIndex index)
599 : {
600 169 : Term term;
601 169 : term.kind = REF;
602 169 : term.ref = index;
603 169 : return term;
604 : }
605 :
606 : Term
607 178 : Term::make_transform (Term lhs, Term rhs)
608 : {
609 178 : if (lhs.is_const () && rhs.is_const ())
610 : {
611 9 : return Variance::transform (lhs.const_val, rhs.const_val);
612 : }
613 :
614 169 : Term term;
615 169 : term.kind = TRANSFORM;
616 169 : term.transform.lhs = new Term (lhs);
617 169 : term.transform.rhs = new Term (rhs);
618 169 : return term;
619 : }
620 :
621 : } // namespace VarianceAnalysis
622 : } // namespace TyTy
623 : } // namespace Rust
|