Branch data Line data Source code
1 : : #include "rust-tyty-variance-analysis-private.h"
2 : : #include "rust-hir-type-check.h"
3 : :
4 : : namespace Rust {
5 : : namespace TyTy {
6 : :
7 : : BaseType *
8 : 4473 : lookup_type (HirId ref)
9 : : {
10 : 4473 : BaseType *ty = nullptr;
11 : 4473 : bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty);
12 : 4473 : rust_assert (ok);
13 : 4473 : return ty;
14 : : }
15 : :
16 : : namespace VarianceAnalysis {
17 : :
18 : 3619 : CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {}
19 : :
20 : : // Must be here because of incomplete type.
21 : 0 : CrateCtx::~CrateCtx () = default;
22 : :
23 : : void
24 : 1930 : CrateCtx::add_type_constraints (ADTType &type)
25 : : {
26 : 1930 : private_ctx->process_type (type);
27 : 1930 : }
28 : :
29 : : void
30 : 3461 : CrateCtx::solve ()
31 : : {
32 : 3461 : private_ctx->solve ();
33 : 3461 : private_ctx->debug_print_solutions ();
34 : 3461 : }
35 : :
36 : : std::vector<Variance>
37 : 0 : CrateCtx::query_generic_variance (const ADTType &type)
38 : : {
39 : 0 : return private_ctx->query_generic_variance (type);
40 : : }
41 : :
42 : : std::vector<Variance>
43 : 0 : CrateCtx::query_type_variances (BaseType *type)
44 : : {
45 : 0 : TyVisitorCtx ctx (*private_ctx);
46 : 0 : return ctx.collect_variances (*type);
47 : 0 : }
48 : :
49 : : std::vector<Region>
50 : 0 : CrateCtx::query_type_regions (BaseType *type)
51 : : {
52 : 0 : return private_ctx->query_type_regions (type);
53 : : }
54 : :
55 : : std::vector<size_t>
56 : 0 : CrateCtx::query_field_regions (const ADTType *parent, size_t variant_index,
57 : : size_t field_index,
58 : : const FreeRegions &parent_regions)
59 : : {
60 : 0 : return private_ctx->query_field_regions (parent, variant_index, field_index,
61 : 0 : parent_regions);
62 : : }
63 : :
64 : : Variance
65 : 0 : Variance::reverse () const
66 : : {
67 : 0 : switch (kind)
68 : : {
69 : 0 : case BIVARIANT:
70 : 0 : return bivariant ();
71 : 0 : case COVARIANT:
72 : 0 : return contravariant ();
73 : 0 : case CONTRAVARIANT:
74 : 0 : return covariant ();
75 : 0 : case INVARIANT:
76 : 0 : return invariant ();
77 : : }
78 : :
79 : 0 : rust_unreachable ();
80 : : }
81 : :
82 : : Variance
83 : 1139 : Variance::join (Variance lhs, Variance rhs)
84 : : {
85 : 1139 : return {Kind (lhs.kind | rhs.kind)};
86 : : }
87 : :
88 : : void
89 : 1029 : Variance::join (Variance rhs)
90 : : {
91 : 1029 : *this = join (*this, rhs);
92 : 1029 : }
93 : :
94 : : Variance
95 : 117 : Variance::transform (Variance lhs, Variance rhs)
96 : : {
97 : 117 : switch (lhs.kind)
98 : : {
99 : 0 : case BIVARIANT:
100 : 0 : return bivariant ();
101 : 115 : case COVARIANT:
102 : 115 : return rhs;
103 : 0 : case CONTRAVARIANT:
104 : 0 : return rhs.reverse ();
105 : 2 : case INVARIANT:
106 : 2 : return invariant ();
107 : : }
108 : 0 : rust_unreachable ();
109 : : }
110 : :
111 : : std::string
112 : 5897 : Variance::as_string () const
113 : : {
114 : 5897 : switch (kind)
115 : : {
116 : 44 : case BIVARIANT:
117 : 44 : return "o";
118 : 5584 : case COVARIANT:
119 : 5584 : return "+";
120 : 7 : case CONTRAVARIANT:
121 : 7 : return "-";
122 : 262 : case INVARIANT:
123 : 262 : return "*";
124 : : }
125 : 0 : rust_unreachable ();
126 : : }
127 : :
128 : : void
129 : 1930 : GenericTyPerCrateCtx::process_type (ADTType &type)
130 : : {
131 : 1930 : GenericTyVisitorCtx (*this).process_type (type);
132 : 1930 : }
133 : :
134 : : void
135 : 3461 : GenericTyPerCrateCtx::solve ()
136 : : {
137 : 3461 : rust_debug ("Variance analysis solving started:");
138 : :
139 : : // Fix point iteration
140 : 3461 : bool changed = true;
141 : 10405 : while (changed)
142 : : {
143 : 3483 : changed = false;
144 : 3593 : for (auto constraint : constraints)
145 : : {
146 : 110 : rust_debug ("\tapplying constraint: %s <= %s",
147 : : to_string (constraint.target_index).c_str (),
148 : : to_string (*constraint.term).c_str ());
149 : :
150 : 110 : auto old_solution = solutions[constraint.target_index];
151 : 110 : auto new_solution
152 : 110 : = Variance::join (old_solution, evaluate (constraint.term));
153 : :
154 : 110 : if (old_solution != new_solution)
155 : : {
156 : 22 : rust_debug ("\t\tsolution changed: %s => %s",
157 : : old_solution.as_string ().c_str (),
158 : : new_solution.as_string ().c_str ());
159 : :
160 : 22 : changed = true;
161 : 22 : solutions[constraint.target_index] = new_solution;
162 : : }
163 : : }
164 : : }
165 : :
166 : 3461 : constraints.clear ();
167 : 3461 : constraints.shrink_to_fit ();
168 : 3461 : }
169 : :
170 : : void
171 : 3461 : GenericTyPerCrateCtx::debug_print_solutions ()
172 : : {
173 : 3461 : rust_debug ("Variance analysis results:");
174 : :
175 : 5390 : for (auto type : map_from_ty_orig_ref)
176 : : {
177 : 1929 : auto solution_index = type.second;
178 : 1929 : auto ref = type.first;
179 : :
180 : 1929 : BaseType *ty = lookup_type (ref);
181 : :
182 : 1929 : std::string result = "\t";
183 : :
184 : 1929 : if (auto adt = ty->try_as<ADTType> ())
185 : : {
186 : 1929 : result += adt->get_identifier ();
187 : 1929 : result += "<";
188 : :
189 : 1929 : size_t i = solution_index;
190 : 1941 : for (auto ®ion : adt->get_used_arguments ().get_regions ())
191 : : {
192 : 12 : (void) region;
193 : 12 : if (i > solution_index)
194 : 1 : result += ", ";
195 : 12 : result += solutions[i].as_string ();
196 : 12 : i++;
197 : : }
198 : 2664 : for (auto ¶m : adt->get_substs ())
199 : : {
200 : 735 : if (i > solution_index)
201 : 87 : result += ", ";
202 : 735 : result += param.get_generic_param ()
203 : 1470 : .get_type_representation ()
204 : 735 : .as_string ();
205 : 735 : result += "=";
206 : 735 : result += solutions[i].as_string ();
207 : 735 : i++;
208 : : }
209 : :
210 : 1929 : result += ">";
211 : : }
212 : : else
213 : : {
214 : 0 : rust_sorry_at (
215 : : ty->get_ref (),
216 : : "This is a compiler bug: Unhandled type in variance analysis");
217 : : }
218 : 1929 : rust_debug ("%s", result.c_str ());
219 : 1929 : }
220 : 3461 : }
221 : :
222 : : tl::optional<SolutionIndex>
223 : 2141 : GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref)
224 : : {
225 : 2141 : auto it = map_from_ty_orig_ref.find (orig_ref);
226 : 2141 : if (it != map_from_ty_orig_ref.end ())
227 : : {
228 : 147 : return it->second;
229 : : }
230 : 1994 : return tl::nullopt;
231 : : }
232 : :
233 : : void
234 : 1930 : GenericTyVisitorCtx::process_type (ADTType &ty)
235 : : {
236 : 1930 : rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ());
237 : :
238 : 1930 : first_lifetime = lookup_or_add_type (ty.get_orig_ref ());
239 : 1930 : first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size ();
240 : :
241 : 2665 : for (const auto ¶m : ty.get_substs ())
242 : 735 : param_names.push_back (
243 : 1470 : param.get_generic_param ().get_type_representation ().as_string ());
244 : :
245 : 4101 : for (const auto &variant : ty.get_variants ())
246 : : {
247 : 2171 : if (variant->get_variant_type () != VariantDef::NUM)
248 : : {
249 : 5059 : for (const auto &field : variant->get_fields ())
250 : 3082 : add_constraints_from_ty (field->get_field_type (),
251 : : Variance::covariant ());
252 : : }
253 : : }
254 : 1930 : }
255 : :
256 : : std::string
257 : 5692 : GenericTyPerCrateCtx::to_string (const Term &term) const
258 : : {
259 : 5692 : switch (term.kind)
260 : : {
261 : 5106 : case Term::CONST:
262 : 5106 : return term.const_val.as_string ();
263 : 293 : case Term::REF:
264 : 586 : return "v(" + to_string (term.ref) + ")";
265 : 293 : case Term::TRANSFORM:
266 : 586 : return "(" + to_string (*term.transform.lhs) + " x "
267 : 1172 : + to_string (*term.transform.rhs) + ")";
268 : : }
269 : 0 : rust_unreachable ();
270 : : }
271 : :
272 : : std::string
273 : 403 : GenericTyPerCrateCtx::to_string (SolutionIndex index) const
274 : : {
275 : : // Search all values in def_id_to_solution_index_start and find key for
276 : : // largest value smaller than index
277 : 403 : std::pair<HirId, SolutionIndex> best = {0, 0};
278 : :
279 : 1355 : for (const auto &ty_map : map_from_ty_orig_ref)
280 : : {
281 : 952 : if (ty_map.second <= index && ty_map.first > best.first)
282 : : best = ty_map;
283 : : }
284 : 403 : rust_assert (best.first != 0);
285 : :
286 : 403 : BaseType *ty = lookup_type (best.first);
287 : :
288 : 403 : std::string result = "";
289 : 403 : if (auto adt = ty->try_as<ADTType> ())
290 : : {
291 : 403 : result += (adt->get_identifier ());
292 : : }
293 : : else
294 : : {
295 : 0 : result += ty->as_string ();
296 : : }
297 : :
298 : 806 : result += "[" + std::to_string (index - best.first) + "]";
299 : 403 : return result;
300 : : }
301 : :
302 : : Variance
303 : 330 : GenericTyPerCrateCtx::evaluate (Term *term)
304 : : {
305 : 330 : switch (term->kind)
306 : : {
307 : 110 : case Term::CONST:
308 : 110 : return term->const_val;
309 : 110 : case Term::REF:
310 : 110 : return solutions[term->ref];
311 : 110 : case Term::TRANSFORM:
312 : 110 : return Variance::transform (evaluate (term->transform.lhs),
313 : 110 : evaluate (term->transform.rhs));
314 : : }
315 : 0 : rust_unreachable ();
316 : : }
317 : :
318 : : std::vector<Variance>
319 : 0 : GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
320 : : {
321 : 0 : auto solution_index = lookup_type_index (type.get_orig_ref ());
322 : 0 : rust_assert (solution_index.has_value ());
323 : 0 : auto num_lifetimes = type.get_num_lifetime_params ();
324 : 0 : auto num_types = type.get_num_type_params ();
325 : :
326 : 0 : std::vector<Variance> result;
327 : 0 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
328 : : {
329 : 0 : result.push_back (solutions[solution_index.value () + i]);
330 : : }
331 : :
332 : 0 : return result;
333 : : }
334 : :
335 : : std::vector<size_t>
336 : 0 : GenericTyPerCrateCtx::query_field_regions (const ADTType *parent,
337 : : size_t variant_index,
338 : : size_t field_index,
339 : : const FreeRegions &parent_regions)
340 : : {
341 : 0 : auto orig = lookup_type (parent->get_orig_ref ());
342 : 0 : FieldVisitorCtx ctx (*this, *parent->as<const SubstitutionRef> (),
343 : 0 : parent_regions);
344 : 0 : return ctx.collect_regions (*orig->as<const ADTType> ()
345 : 0 : ->get_variants ()
346 : 0 : .at (variant_index)
347 : 0 : ->get_fields ()
348 : 0 : .at (field_index)
349 : 0 : ->get_field_type ());
350 : 0 : }
351 : : std::vector<Region>
352 : 0 : GenericTyPerCrateCtx::query_type_regions (BaseType *type)
353 : : {
354 : 0 : TyVisitorCtx ctx (*this);
355 : 0 : return ctx.collect_regions (*type);
356 : 0 : }
357 : :
358 : : SolutionIndex
359 : 2141 : GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
360 : : {
361 : 2141 : BaseType *ty = lookup_type (hir_id);
362 : 2141 : auto index = ctx.lookup_type_index (hir_id);
363 : 2141 : if (index.has_value ())
364 : : {
365 : 147 : return index.value ();
366 : : }
367 : :
368 : 1994 : SubstitutionRef *subst = nullptr;
369 : 1994 : if (auto adt = ty->try_as<ADTType> ())
370 : : {
371 : 1994 : subst = adt;
372 : : }
373 : : else
374 : : {
375 : 0 : rust_sorry_at (
376 : : ty->get_locus (),
377 : : "This is a compiler bug: Unhandled type in variance analysis");
378 : : }
379 : 0 : rust_assert (subst != nullptr);
380 : :
381 : 1994 : auto solution_index = ctx.solutions.size ();
382 : 1994 : ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index);
383 : :
384 : 1994 : auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size ();
385 : 1994 : auto num_type_param = subst->get_num_substitutions ();
386 : :
387 : 2742 : for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i)
388 : 748 : ctx.solutions.emplace_back (Variance::bivariant ());
389 : :
390 : 1994 : return solution_index;
391 : : }
392 : :
393 : : void
394 : 3879 : GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance)
395 : : {
396 : 3879 : rust_debug ("\tadd_constraint_from_ty: %s with v=%s",
397 : : type->as_string ().c_str (), ctx.to_string (variance).c_str ());
398 : :
399 : 3879 : Visitor visitor (*this, variance);
400 : 3879 : type->accept_vis (visitor);
401 : 3879 : }
402 : :
403 : : void
404 : 1117 : GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term)
405 : : {
406 : 1117 : rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ());
407 : :
408 : 1117 : if (term.kind == Term::CONST)
409 : : {
410 : : // Constant terms do not depend on other solutions, so we can
411 : : // immediately apply them.
412 : 1029 : ctx.solutions[index].join (term.const_val);
413 : : }
414 : : else
415 : : {
416 : 88 : ctx.constraints.push_back ({index, new Term (term)});
417 : : }
418 : 1117 : }
419 : :
420 : : void
421 : 10 : GenericTyVisitorCtx::add_constraints_from_region (const Region ®ion,
422 : : Term term)
423 : : {
424 : 10 : if (region.is_early_bound ())
425 : : {
426 : 10 : add_constraint (first_lifetime + region.get_index (), term);
427 : : }
428 : 10 : }
429 : :
430 : : void
431 : 211 : GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref,
432 : : SubstitutionRef &subst,
433 : : Term variance,
434 : : bool invariant_args)
435 : : {
436 : 211 : SolutionIndex solution_index = lookup_or_add_type (ref);
437 : :
438 : 211 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
439 : 211 : size_t num_types = subst.get_substs ().size ();
440 : :
441 : 308 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
442 : : {
443 : : // TODO: What about variance from other crates?
444 : 97 : auto variance_i
445 : : = invariant_args
446 : 97 : ? Term::make_transform (variance, Variance::invariant ())
447 : 97 : : Term::make_transform (variance,
448 : : Term::make_ref (solution_index + i));
449 : :
450 : 97 : if (i < num_lifetimes)
451 : : {
452 : 1 : auto region_i = i;
453 : 1 : auto ®ion
454 : 1 : = subst.get_substitution_arguments ().get_mut_regions ()[region_i];
455 : 1 : add_constraints_from_region (region, variance_i);
456 : : }
457 : : else
458 : : {
459 : 96 : auto type_i = i - num_lifetimes;
460 : 96 : auto arg = subst.get_arg_at (type_i);
461 : 96 : if (arg.has_value ())
462 : : {
463 : 95 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
464 : : }
465 : : }
466 : : }
467 : 211 : }
468 : : void
469 : 1107 : GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance)
470 : : {
471 : 1107 : auto it
472 : 1107 : = std::find (param_names.begin (), param_names.end (), type.get_name ());
473 : 1107 : rust_assert (it != param_names.end ());
474 : :
475 : 1107 : auto index = first_type + std::distance (param_names.begin (), it);
476 : :
477 : 1107 : add_constraint (index, variance);
478 : 1107 : }
479 : :
480 : : Term
481 : 7 : GenericTyVisitorCtx::contra (Term variance)
482 : : {
483 : 7 : return Term::make_transform (variance, Variance::contravariant ());
484 : : }
485 : :
486 : : void
487 : 0 : TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
488 : : {
489 : 0 : Visitor visitor (*this, variance);
490 : 0 : ty->accept_vis (visitor);
491 : 0 : }
492 : :
493 : : void
494 : 0 : TyVisitorCtx::add_constraints_from_region (const Region ®ion,
495 : : Variance variance)
496 : : {
497 : 0 : variances.push_back (variance);
498 : 0 : regions.push_back (region);
499 : 0 : }
500 : :
501 : : void
502 : 0 : TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
503 : : SubstitutionRef &subst,
504 : : Variance variance,
505 : : bool invariant_args)
506 : : {
507 : : // Handle function
508 : 0 : auto variances
509 : 0 : = ctx.query_generic_variance (*lookup_type (ref)->as<ADTType> ());
510 : :
511 : 0 : size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
512 : 0 : size_t num_types = subst.get_substs ().size ();
513 : :
514 : 0 : for (size_t i = 0; i < num_lifetimes + num_types; ++i)
515 : : {
516 : : // TODO: What about variance from other crates?
517 : 0 : auto variance_i
518 : : = invariant_args
519 : 0 : ? Variance::transform (variance, Variance::invariant ())
520 : 0 : : Variance::transform (variance, variances[i]);
521 : :
522 : 0 : if (i < num_lifetimes)
523 : : {
524 : 0 : auto region_i = i;
525 : 0 : auto ®ion = subst.get_used_arguments ().get_regions ()[region_i];
526 : 0 : add_constraints_from_region (region, variance_i);
527 : : }
528 : : else
529 : : {
530 : 0 : auto type_i = i - num_lifetimes;
531 : 0 : auto arg = subst.get_arg_at (type_i);
532 : 0 : if (arg.has_value ())
533 : : {
534 : 0 : add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
535 : : }
536 : : }
537 : : }
538 : 0 : }
539 : :
540 : : std::vector<size_t>
541 : 0 : FieldVisitorCtx::collect_regions (BaseType &ty)
542 : : {
543 : : // Segment the regions into ranges for each type parameter. Type parameter
544 : : // at index i contains regions from type_param_ranges[i] to
545 : : // type_param_ranges[i+1] (exclusive).;
546 : 0 : type_param_ranges.push_back (subst.get_num_lifetime_params ());
547 : :
548 : 0 : for (size_t i = 0; i < subst.get_num_type_params (); i++)
549 : : {
550 : 0 : auto arg = subst.get_arg_at (i);
551 : 0 : rust_assert (arg.has_value ());
552 : 0 : type_param_ranges.push_back (
553 : 0 : ctx.query_type_regions (arg.value ().get_tyty ()).size ());
554 : : }
555 : :
556 : 0 : add_constraints_from_ty (&ty, Variance::covariant ());
557 : 0 : return regions;
558 : : }
559 : :
560 : : void
561 : 0 : FieldVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
562 : : {
563 : 0 : Visitor visitor (*this, variance);
564 : 0 : ty->accept_vis (visitor);
565 : 0 : }
566 : :
567 : : void
568 : 0 : FieldVisitorCtx::add_constraints_from_region (const Region ®ion,
569 : : Variance variance)
570 : : {
571 : 0 : if (region.is_early_bound ())
572 : : {
573 : 0 : regions.push_back (parent_regions[region.get_index ()]);
574 : : }
575 : 0 : else if (region.is_late_bound ())
576 : : {
577 : 0 : rust_debug ("Ignoring late bound region");
578 : : }
579 : 0 : }
580 : :
581 : : void
582 : 0 : FieldVisitorCtx::add_constrints_from_param (ParamType ¶m, Variance variance)
583 : : {
584 : 0 : size_t param_i = subst.get_used_arguments ().find_symbol (param).value ();
585 : 0 : for (size_t i = type_param_ranges[param_i];
586 : 0 : i < type_param_ranges[param_i + 1]; i++)
587 : : {
588 : 0 : regions.push_back (parent_regions[i]);
589 : : }
590 : 0 : }
591 : :
592 : : Variance
593 : 0 : TyVisitorCtx::contra (Variance variance)
594 : : {
595 : 0 : return Variance::transform (variance, Variance::contravariant ());
596 : : }
597 : :
598 : : Term
599 : 97 : Term::make_ref (SolutionIndex index)
600 : : {
601 : 97 : Term term;
602 : 97 : term.kind = REF;
603 : 97 : term.ref = index;
604 : 97 : return term;
605 : : }
606 : :
607 : : Term
608 : 104 : Term::make_transform (Term lhs, Term rhs)
609 : : {
610 : 104 : if (lhs.is_const () && rhs.is_const ())
611 : : {
612 : 7 : return Variance::transform (lhs.const_val, rhs.const_val);
613 : : }
614 : :
615 : 97 : Term term;
616 : 97 : term.kind = TRANSFORM;
617 : 97 : term.transform.lhs = new Term (lhs);
618 : 97 : term.transform.rhs = new Term (rhs);
619 : 97 : return term;
620 : : }
621 : :
622 : : } // namespace VarianceAnalysis
623 : : } // namespace TyTy
624 : : } // namespace Rust
|