Branch data Line data Source code
1 : : #include "rust-tyty-variance-analysis-private.h"
2 : : #include "rust-hir-type-check.h"
3 : :
4 : : namespace Rust {
5 : : namespace TyTy {
6 : :
7 : : BaseType *
8 : 7151 : lookup_type (HirId ref)
9 : : {
10 : 7151 : BaseType *ty = nullptr;
11 : 7151 : bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty);
12 : 7151 : rust_assert (ok);
13 : 7151 : return ty;
14 : : }
15 : :
16 : : namespace VarianceAnalysis {
17 : :
18 : 4449 : CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {}
19 : :
20 : : // Must be here because of incomplete type.
21 : 0 : CrateCtx::~CrateCtx () = default;
22 : :
23 : : void
24 : 3046 : CrateCtx::add_type_constraints (ADTType &type)
25 : : {
26 : 3046 : private_ctx->process_type (type);
27 : 3046 : }
28 : :
29 : : void
30 : 4259 : CrateCtx::solve ()
31 : : {
32 : 4259 : private_ctx->solve ();
33 : 4259 : private_ctx->debug_print_solutions ();
34 : 4259 : }
35 : :
36 : : std::vector<Variance>
37 : 0 : CrateCtx::query_generic_variance (const ADTType &type)
38 : : {
39 : 0 : return private_ctx->query_generic_variance (type);
40 : : }
41 : :
42 : : std::vector<Variance>
43 : 516 : CrateCtx::query_type_variances (BaseType *type)
44 : : {
45 : 516 : TyVisitorCtx ctx (*private_ctx);
46 : 516 : return ctx.collect_variances (*type);
47 : 516 : }
48 : :
49 : : std::vector<Region>
50 : 76 : CrateCtx::query_type_regions (BaseType *type)
51 : : {
52 : 76 : return private_ctx->query_type_regions (type);
53 : : }
54 : :
55 : : FreeRegions
56 : 5 : CrateCtx::query_field_regions (const ADTType *parent, size_t variant_index,
57 : : size_t field_index,
58 : : const FreeRegions &parent_regions)
59 : : {
60 : 5 : return private_ctx->query_field_regions (parent, variant_index, field_index,
61 : 5 : parent_regions);
62 : : }
63 : :
64 : : Variance
65 : 0 : Variance::reverse () const
66 : : {
67 : 0 : switch (kind)
68 : : {
69 : 0 : case BIVARIANT:
70 : 0 : return bivariant ();
71 : 0 : case COVARIANT:
72 : 0 : return contravariant ();
73 : 0 : case CONTRAVARIANT:
74 : 0 : return covariant ();
75 : 0 : case INVARIANT:
76 : 0 : return invariant ();
77 : : }
78 : :
79 : 0 : rust_unreachable ();
80 : : }
81 : :
82 : : Variance
83 : 1483 : Variance::join (Variance lhs, Variance rhs)
84 : : {
85 : 1483 : return {Kind (lhs.kind | rhs.kind)};
86 : : }
87 : :
88 : : void
89 : 1320 : Variance::join (Variance rhs)
90 : : {
91 : 1320 : *this = join (*this, rhs);
92 : 1320 : }
93 : :
94 : : Variance
95 : 216 : Variance::transform (Variance lhs, Variance rhs)
96 : : {
97 : 216 : switch (lhs.kind)
98 : : {
99 : 0 : case BIVARIANT:
100 : 0 : return bivariant ();
101 : 214 : case COVARIANT:
102 : 214 : return rhs;
103 : 0 : case CONTRAVARIANT:
104 : 0 : return rhs.reverse ();
105 : 2 : case INVARIANT:
106 : 2 : return invariant ();
107 : : }
108 : 0 : rust_unreachable ();
109 : : }
110 : :
111 : : std::string
112 : 8116 : Variance::as_string () const
113 : : {
114 : 8116 : switch (kind)
115 : : {
116 : 254 : case BIVARIANT:
117 : 254 : return "o";
118 : 7560 : case COVARIANT:
119 : 7560 : return "+";
120 : 7 : case CONTRAVARIANT:
121 : 7 : return "-";
122 : 295 : case INVARIANT:
123 : 295 : return "*";
124 : : }
125 : 0 : rust_unreachable ();
126 : : }
127 : :
128 : : void
129 : 3046 : GenericTyPerCrateCtx::process_type (ADTType &type)
130 : : {
131 : 3046 : GenericTyVisitorCtx (*this).process_type (type);
132 : 3046 : }
133 : :
134 : : void
135 : 4259 : GenericTyPerCrateCtx::solve ()
136 : : {
137 : 4259 : rust_debug ("Variance analysis solving started:");
138 : :
139 : : // Fix point iteration
140 : 4259 : bool changed = true;
141 : 12801 : while (changed)
142 : : {
143 : 4283 : changed = false;
144 : 4446 : for (auto constraint : constraints)
145 : : {
146 : 163 : rust_debug ("\tapplying constraint: %s <= %s",
147 : : to_string (constraint.target_index).c_str (),
148 : : to_string (*constraint.term).c_str ());
149 : :
150 : 163 : auto old_solution = solutions[constraint.target_index];
151 : 163 : auto new_solution
152 : 163 : = Variance::join (old_solution, evaluate (constraint.term));
153 : :
154 : 163 : if (old_solution != new_solution)
155 : : {
156 : 24 : rust_debug ("\t\tsolution changed: %s => %s",
157 : : old_solution.as_string ().c_str (),
158 : : new_solution.as_string ().c_str ());
159 : :
160 : 24 : changed = true;
161 : 24 : solutions[constraint.target_index] = new_solution;
162 : : }
163 : : }
164 : : }
165 : :
166 : 4259 : constraints.clear ();
167 : 4259 : constraints.shrink_to_fit ();
168 : 4259 : }
169 : :
170 : : void
171 : 4259 : GenericTyPerCrateCtx::debug_print_solutions ()
172 : : {
173 : 4259 : rust_debug ("Variance analysis results:");
174 : :
175 : 7305 : for (auto type : map_from_ty_orig_ref)
176 : : {
177 : 3046 : auto solution_index = type.second;
178 : 3046 : auto ref = type.first;
179 : :
180 : 3046 : BaseType *ty = lookup_type (ref);
181 : :
182 : 3046 : std::string result = "\t";
183 : :
184 : 3046 : if (auto adt = ty->try_as<ADTType> ())
185 : : {
186 : 9138 : result += adt->get_identifier ();
187 : 3046 : result += "<";
188 : :
189 : 3046 : size_t i = solution_index;
190 : 3067 : for (auto ®ion : adt->get_used_arguments ().get_regions ())
191 : : {
192 : 21 : (void) region;
193 : 21 : if (i > solution_index)
194 : 1 : result += ", ";
195 : 42 : result += solutions[i].as_string ();
196 : 21 : i++;
197 : : }
198 : 4229 : for (auto ¶m : adt->get_substs ())
199 : : {
200 : 1183 : if (i > solution_index)
201 : 129 : result += ", ";
202 : 2366 : result += param.get_type_representation ().as_string ();
203 : 1183 : result += "=";
204 : 2366 : result += solutions[i].as_string ();
205 : 1183 : i++;
206 : : }
207 : :
208 : 3046 : result += ">";
209 : : }
210 : : else
211 : : {
212 : 0 : rust_sorry_at (
213 : : ty->get_ref (),
214 : : "This is a compiler bug: Unhandled type in variance analysis");
215 : : }
216 : 3046 : rust_debug ("%s", result.c_str ());
217 : 3046 : }
218 : 4259 : }
219 : :
220 : : tl::optional<SolutionIndex>
221 : 3473 : GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref)
222 : : {
223 : 3473 : auto it = map_from_ty_orig_ref.find (orig_ref);
224 : 3473 : if (it != map_from_ty_orig_ref.end ())
225 : : {
226 : 363 : return it->second;
227 : : }
228 : 3110 : return tl::nullopt;
229 : : }
230 : :
231 : : void
232 : 3046 : GenericTyVisitorCtx::process_type (ADTType &ty)
233 : : {
234 : 3046 : rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ());
235 : :
236 : 3046 : first_lifetime = lookup_or_add_type (ty.get_orig_ref ());
237 : 3046 : first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size ();
238 : :
239 : 4229 : for (auto ¶m : ty.get_substs ())
240 : 1183 : param_names.push_back (param.get_type_representation ().as_string ());
241 : :
242 : 6777 : for (const auto &variant : ty.get_variants ())
243 : : {
244 : 3731 : if (variant->get_variant_type () != VariantDef::NUM)
245 : : {
246 : 7296 : for (const auto &field : variant->get_fields ())
247 : 4265 : add_constraints_from_ty (field->get_field_type (),
248 : 4265 : Variance::covariant ());
249 : : }
250 : : }
251 : 3046 : }
252 : :
253 : : std::string
254 : 7792 : GenericTyPerCrateCtx::to_string (const Term &term) const
255 : : {
256 : 7792 : switch (term.kind)
257 : : {
258 : 6864 : case Term::CONST:
259 : 6864 : return term.const_val.as_string ();
260 : 464 : case Term::REF:
261 : 928 : return "v(" + to_string (term.ref) + ")";
262 : 464 : case Term::TRANSFORM:
263 : 928 : return "(" + to_string (*term.transform.lhs) + " x "
264 : 1856 : + to_string (*term.transform.rhs) + ")";
265 : : }
266 : 0 : rust_unreachable ();
267 : : }
268 : :
269 : : std::string
270 : 627 : GenericTyPerCrateCtx::to_string (SolutionIndex index) const
271 : : {
272 : : // Search all values in def_id_to_solution_index_start and find key for
273 : : // largest value smaller than index
274 : 627 : std::pair<HirId, SolutionIndex> best = {0, 0};
275 : :
276 : 2830 : for (const auto &ty_map : map_from_ty_orig_ref)
277 : : {
278 : 2203 : if (ty_map.second <= index && ty_map.first > best.first)
279 : : best = ty_map;
280 : : }
281 : 627 : rust_assert (best.first != 0);
282 : :
283 : 627 : BaseType *ty = lookup_type (best.first);
284 : :
285 : 627 : std::string result = "";
286 : 627 : if (auto adt = ty->try_as<ADTType> ())
287 : : {
288 : 1881 : result += (adt->get_identifier ());
289 : : }
290 : : else
291 : : {
292 : 0 : result += ty->as_string ();
293 : : }
294 : :
295 : 1881 : result += "[" + std::to_string (index - best.first) + "]";
296 : 627 : return result;
297 : : }
298 : :
299 : : Variance
300 : 489 : GenericTyPerCrateCtx::evaluate (Term *term)
301 : : {
302 : 489 : switch (term->kind)
303 : : {
304 : 163 : case Term::CONST:
305 : 163 : return term->const_val;
306 : 163 : case Term::REF:
307 : 163 : return solutions[term->ref];
308 : 163 : case Term::TRANSFORM:
309 : 163 : return Variance::transform (evaluate (term->transform.lhs),
310 : 163 : evaluate (term->transform.rhs));
311 : : }
312 : 0 : rust_unreachable ();
313 : : }
314 : :
315 : : std::vector<Variance>
316 : 78 : GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
317 : : {
318 : 78 : auto solution_index = lookup_type_index (type.get_orig_ref ());
319 : 78 : rust_assert (solution_index.has_value ());
320 : 78 : auto num_lifetimes = type.get_num_lifetime_params ();
321 : 78 : auto num_types = type.get_num_type_params ();
322 : :
323 : 78 : std::vector<Variance> result;
324 : 78 : result.reserve (num_lifetimes + num_types);
325 : :
326 : 122 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
327 : : {
328 : 44 : result.push_back (solutions[solution_index.value () + i]);
329 : : }
330 : :
331 : 78 : return result;
332 : : }
333 : :
334 : : FreeRegions
335 : 5 : GenericTyPerCrateCtx::query_field_regions (const ADTType *parent,
336 : : size_t variant_index,
337 : : size_t field_index,
338 : : const FreeRegions &parent_regions)
339 : : {
340 : 5 : auto orig = lookup_type (parent->get_orig_ref ());
341 : 5 : FieldVisitorCtx ctx (*this, *parent->as<const SubstitutionRef> (),
342 : 5 : parent_regions);
343 : 5 : return ctx.collect_regions (*orig->as<const ADTType> ()
344 : 5 : ->get_variants ()
345 : 5 : .at (variant_index)
346 : 5 : ->get_fields ()
347 : 5 : .at (field_index)
348 : 5 : ->get_field_type ());
349 : 5 : }
350 : : std::vector<Region>
351 : 76 : GenericTyPerCrateCtx::query_type_regions (BaseType *type)
352 : : {
353 : 76 : TyVisitorCtx ctx (*this);
354 : 76 : return ctx.collect_regions (*type);
355 : 76 : }
356 : :
357 : : SolutionIndex
358 : 3395 : GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
359 : : {
360 : 3395 : BaseType *ty = lookup_type (hir_id);
361 : 3395 : auto index = ctx.lookup_type_index (hir_id);
362 : 3395 : if (index.has_value ())
363 : : {
364 : 285 : return index.value ();
365 : : }
366 : :
367 : 3110 : SubstitutionRef *subst = nullptr;
368 : 3110 : if (auto adt = ty->try_as<ADTType> ())
369 : : {
370 : 3110 : subst = adt;
371 : : }
372 : : else
373 : : {
374 : 0 : rust_sorry_at (
375 : : ty->get_locus (),
376 : : "This is a compiler bug: Unhandled type in variance analysis");
377 : : }
378 : 0 : rust_assert (subst != nullptr);
379 : :
380 : 3110 : auto solution_index = ctx.solutions.size ();
381 : 3110 : ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index);
382 : :
383 : 3110 : auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size ();
384 : 3110 : auto num_type_param = subst->get_num_substitutions ();
385 : :
386 : 4315 : for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i)
387 : 1205 : ctx.solutions.emplace_back (Variance::bivariant ());
388 : :
389 : 3110 : return solution_index;
390 : : }
391 : :
392 : : void
393 : 5242 : GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance)
394 : : {
395 : 5242 : rust_debug ("\tadd_constraint_from_ty: %s with v=%s",
396 : : type->as_string ().c_str (), ctx.to_string (variance).c_str ());
397 : :
398 : 5242 : Visitor visitor (*this, variance);
399 : 5242 : type->accept_vis (visitor);
400 : 5242 : }
401 : :
402 : : void
403 : 1459 : GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term)
404 : : {
405 : 1459 : rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ());
406 : :
407 : 1459 : if (term.kind == Term::CONST)
408 : : {
409 : : // Constant terms do not depend on other solutions, so we can
410 : : // immediately apply them.
411 : 1320 : ctx.solutions[index].join (term.const_val);
412 : : }
413 : : else
414 : : {
415 : 139 : ctx.constraints.emplace_back (index, new Term (term));
416 : : }
417 : 1459 : }
418 : :
419 : : void
420 : 23 : GenericTyVisitorCtx::add_constraints_from_region (const Region ®ion,
421 : : Term term)
422 : : {
423 : 23 : if (region.is_early_bound ())
424 : : {
425 : 14 : add_constraint (first_lifetime + region.get_index (), term);
426 : : }
427 : 23 : }
428 : :
429 : : void
430 : 349 : GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref,
431 : : SubstitutionRef &subst,
432 : : Term variance,
433 : : bool invariant_args)
434 : : {
435 : 349 : SolutionIndex solution_index = lookup_or_add_type (ref);
436 : :
437 : 349 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
438 : 349 : size_t num_types = subst.get_substs ().size ();
439 : :
440 : 518 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
441 : : {
442 : : // TODO: What about variance from other crates?
443 : 169 : auto variance_i
444 : : = invariant_args
445 : 169 : ? Term::make_transform (variance, Variance::invariant ())
446 : 169 : : Term::make_transform (variance,
447 : : Term::make_ref (solution_index + i));
448 : :
449 : 169 : if (i < num_lifetimes)
450 : : {
451 : 1 : auto region_i = i;
452 : 1 : auto ®ion
453 : 1 : = subst.get_substitution_arguments ().get_mut_regions ()[region_i];
454 : 1 : add_constraints_from_region (region, variance_i);
455 : : }
456 : : else
457 : : {
458 : 168 : auto type_i = i - num_lifetimes;
459 : 168 : auto arg = subst.get_arg_at (type_i);
460 : 168 : if (arg.has_value ())
461 : : {
462 : 160 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
463 : : }
464 : : }
465 : : }
466 : 349 : }
467 : : void
468 : 1445 : GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance)
469 : : {
470 : 1445 : auto it
471 : 1445 : = std::find (param_names.begin (), param_names.end (), type.get_name ());
472 : 1445 : rust_assert (it != param_names.end ());
473 : :
474 : 1445 : auto index = first_type + std::distance (param_names.begin (), it);
475 : :
476 : 1445 : add_constraint (index, variance);
477 : 1445 : }
478 : :
479 : : Term
480 : 9 : GenericTyVisitorCtx::contra (Term variance)
481 : : {
482 : 9 : return Term::make_transform (variance, Variance::contravariant ());
483 : : }
484 : :
485 : : void
486 : 820 : TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
487 : : {
488 : 820 : Visitor visitor (*this, variance);
489 : 820 : ty->accept_vis (visitor);
490 : 820 : }
491 : :
492 : : void
493 : 284 : TyVisitorCtx::add_constraints_from_region (const Region ®ion,
494 : : Variance variance)
495 : : {
496 : 284 : variances.push_back (variance);
497 : 284 : regions.push_back (region);
498 : 284 : }
499 : :
500 : : void
501 : 78 : TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
502 : : SubstitutionRef &subst,
503 : : Variance variance,
504 : : bool invariant_args)
505 : : {
506 : : // Handle function
507 : 78 : auto variances
508 : 78 : = ctx.query_generic_variance (*lookup_type (ref)->as<ADTType> ());
509 : :
510 : 78 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
511 : 78 : size_t num_types = subst.get_substs ().size ();
512 : :
513 : 122 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
514 : : {
515 : : // TODO: What about variance from other crates?
516 : 44 : auto variance_i
517 : : = invariant_args
518 : 44 : ? Variance::transform (variance, Variance::invariant ())
519 : 44 : : Variance::transform (variance, variances[i]);
520 : :
521 : 44 : if (i < num_lifetimes)
522 : : {
523 : 44 : auto region_i = i;
524 : 44 : auto ®ion = subst.get_used_arguments ().get_regions ()[region_i];
525 : 44 : add_constraints_from_region (region, variance_i);
526 : : }
527 : : else
528 : : {
529 : 0 : auto type_i = i - num_lifetimes;
530 : 0 : auto arg = subst.get_arg_at (type_i);
531 : 0 : if (arg.has_value ())
532 : : {
533 : 0 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
534 : : }
535 : : }
536 : : }
537 : 78 : }
538 : :
539 : : FreeRegions
540 : 5 : FieldVisitorCtx::collect_regions (BaseType &ty)
541 : : {
542 : : // Segment the regions into ranges for each type parameter. Type parameter
543 : : // at index i contains regions from type_param_ranges[i] to
544 : : // type_param_ranges[i+1] (exclusive).;
545 : 5 : type_param_ranges.push_back (subst.get_num_lifetime_params ());
546 : :
547 : 5 : for (size_t i = 0; i < subst.get_num_type_params (); i++)
548 : : {
549 : 0 : auto arg = subst.get_arg_at (i);
550 : 0 : rust_assert (arg.has_value ());
551 : 0 : type_param_ranges.push_back (
552 : 0 : ctx.query_type_regions (arg.value ().get_tyty ()).size ());
553 : : }
554 : :
555 : 5 : add_constraints_from_ty (&ty, Variance::covariant ());
556 : 5 : return regions;
557 : : }
558 : :
559 : : void
560 : 8 : FieldVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
561 : : {
562 : 8 : Visitor visitor (*this, variance);
563 : 8 : ty->accept_vis (visitor);
564 : 8 : }
565 : :
566 : : void
567 : 3 : FieldVisitorCtx::add_constraints_from_region (const Region ®ion,
568 : : Variance variance)
569 : : {
570 : 3 : if (region.is_early_bound ())
571 : : {
572 : 3 : regions.push_back (parent_regions[region.get_index ()]);
573 : : }
574 : 0 : else if (region.is_late_bound ())
575 : : {
576 : 0 : rust_debug ("Ignoring late bound region");
577 : : }
578 : 3 : }
579 : :
580 : : void
581 : 0 : FieldVisitorCtx::add_constrints_from_param (ParamType ¶m, Variance variance)
582 : : {
583 : 0 : size_t param_i = subst.get_used_arguments ().find_symbol (param).value ();
584 : 0 : for (size_t i = type_param_ranges[param_i];
585 : 0 : i < type_param_ranges[param_i + 1]; i++)
586 : : {
587 : 0 : regions.push_back (parent_regions[i]);
588 : : }
589 : 0 : }
590 : :
591 : : Variance
592 : 0 : TyVisitorCtx::contra (Variance variance)
593 : : {
594 : 0 : return Variance::transform (variance, Variance::contravariant ());
595 : : }
596 : :
597 : : Term
598 : 169 : Term::make_ref (SolutionIndex index)
599 : : {
600 : 169 : Term term;
601 : 169 : term.kind = REF;
602 : 169 : term.ref = index;
603 : 169 : return term;
604 : : }
605 : :
606 : : Term
607 : 178 : Term::make_transform (Term lhs, Term rhs)
608 : : {
609 : 178 : if (lhs.is_const () && rhs.is_const ())
610 : : {
611 : 9 : return Variance::transform (lhs.const_val, rhs.const_val);
612 : : }
613 : :
614 : 169 : Term term;
615 : 169 : term.kind = TRANSFORM;
616 : 169 : term.transform.lhs = new Term (lhs);
617 : 169 : term.transform.rhs = new Term (rhs);
618 : 169 : return term;
619 : : }
620 : :
621 : : } // namespace VarianceAnalysis
622 : : } // namespace TyTy
623 : : } // namespace Rust
|