LCOV - code coverage report
Current view: top level - gcc/rust/typecheck - rust-tyty-variance-analysis.cc (source / functions) Coverage Total Hit
Test: gcc.info Lines: 85.9 % 312 268
Test Date: 2026-04-20 14:57:17 Functions: 88.1 % 42 37
Legend: Lines:     hit not hit

            Line data    Source code
       1              : #include "rust-tyty-variance-analysis-private.h"
       2              : #include "rust-hir-type-check.h"
       3              : 
       4              : namespace Rust {
       5              : namespace TyTy {
       6              : 
       7              : BaseType *
       8         7229 : lookup_type (HirId ref)
       9              : {
      10         7229 :   BaseType *ty = nullptr;
      11         7229 :   bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty);
      12         7229 :   rust_assert (ok);
      13         7229 :   return ty;
      14              : }
      15              : 
      16              : namespace VarianceAnalysis {
      17              : 
      18         4680 : CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {}
      19              : 
      20              : // Must be here because of incomplete type.
      21            0 : CrateCtx::~CrateCtx () = default;
      22              : 
      23              : void
      24         3085 : CrateCtx::add_type_constraints (ADTType &type)
      25              : {
      26         3085 :   private_ctx->process_type (type);
      27         3085 : }
      28              : 
      29              : void
      30         4463 : CrateCtx::solve ()
      31              : {
      32         4463 :   private_ctx->solve ();
      33         4463 :   private_ctx->debug_print_solutions ();
      34         4463 : }
      35              : 
      36              : std::vector<Variance>
      37            0 : CrateCtx::query_generic_variance (const ADTType &type)
      38              : {
      39            0 :   return private_ctx->query_generic_variance (type);
      40              : }
      41              : 
      42              : std::vector<Variance>
      43          516 : CrateCtx::query_type_variances (BaseType *type)
      44              : {
      45          516 :   TyVisitorCtx ctx (*private_ctx);
      46          516 :   return ctx.collect_variances (*type);
      47          516 : }
      48              : 
      49              : std::vector<Region>
      50           76 : CrateCtx::query_type_regions (BaseType *type)
      51              : {
      52           76 :   return private_ctx->query_type_regions (type);
      53              : }
      54              : 
      55              : FreeRegions
      56            5 : CrateCtx::query_field_regions (const ADTType *parent, size_t variant_index,
      57              :                                size_t field_index,
      58              :                                const FreeRegions &parent_regions)
      59              : {
      60            5 :   return private_ctx->query_field_regions (parent, variant_index, field_index,
      61            5 :                                            parent_regions);
      62              : }
      63              : 
      64              : Variance
      65            0 : Variance::reverse () const
      66              : {
      67            0 :   switch (kind)
      68              :     {
      69            0 :     case BIVARIANT:
      70            0 :       return bivariant ();
      71            0 :     case COVARIANT:
      72            0 :       return contravariant ();
      73            0 :     case CONTRAVARIANT:
      74            0 :       return covariant ();
      75            0 :     case INVARIANT:
      76            0 :       return invariant ();
      77              :     }
      78              : 
      79            0 :   rust_unreachable ();
      80              : }
      81              : 
      82              : Variance
      83         1483 : Variance::join (Variance lhs, Variance rhs)
      84              : {
      85         1483 :   return {Kind (lhs.kind | rhs.kind)};
      86              : }
      87              : 
      88              : void
      89         1320 : Variance::join (Variance rhs)
      90              : {
      91         1320 :   *this = join (*this, rhs);
      92         1320 : }
      93              : 
      94              : Variance
      95          216 : Variance::transform (Variance lhs, Variance rhs)
      96              : {
      97          216 :   switch (lhs.kind)
      98              :     {
      99            0 :     case BIVARIANT:
     100            0 :       return bivariant ();
     101          214 :     case COVARIANT:
     102          214 :       return rhs;
     103            0 :     case CONTRAVARIANT:
     104            0 :       return rhs.reverse ();
     105            2 :     case INVARIANT:
     106            2 :       return invariant ();
     107              :     }
     108            0 :   rust_unreachable ();
     109              : }
     110              : 
     111              : std::string
     112         8156 : Variance::as_string () const
     113              : {
     114         8156 :   switch (kind)
     115              :     {
     116          255 :     case BIVARIANT:
     117          255 :       return "o";
     118         7599 :     case COVARIANT:
     119         7599 :       return "+";
     120            7 :     case CONTRAVARIANT:
     121            7 :       return "-";
     122          295 :     case INVARIANT:
     123          295 :       return "*";
     124              :     }
     125            0 :   rust_unreachable ();
     126              : }
     127              : 
     128              : void
     129         3085 : GenericTyPerCrateCtx::process_type (ADTType &type)
     130              : {
     131         3085 :   GenericTyVisitorCtx (*this).process_type (type);
     132         3085 : }
     133              : 
     134              : void
     135         4463 : GenericTyPerCrateCtx::solve ()
     136              : {
     137         4463 :   rust_debug ("Variance analysis solving started:");
     138              : 
     139              :   // Fix point iteration
     140         4463 :   bool changed = true;
     141        13413 :   while (changed)
     142              :     {
     143         4487 :       changed = false;
     144         4650 :       for (auto constraint : constraints)
     145              :         {
     146          163 :           rust_debug ("\tapplying constraint: %s <= %s",
     147              :                       to_string (constraint.target_index).c_str (),
     148              :                       to_string (*constraint.term).c_str ());
     149              : 
     150          163 :           auto old_solution = solutions[constraint.target_index];
     151          163 :           auto new_solution
     152          163 :             = Variance::join (old_solution, evaluate (constraint.term));
     153              : 
     154          163 :           if (old_solution != new_solution)
     155              :             {
     156           24 :               rust_debug ("\t\tsolution changed: %s => %s",
     157              :                           old_solution.as_string ().c_str (),
     158              :                           new_solution.as_string ().c_str ());
     159              : 
     160           24 :               changed = true;
     161           24 :               solutions[constraint.target_index] = new_solution;
     162              :             }
     163              :         }
     164              :     }
     165              : 
     166         4463 :   constraints.clear ();
     167         4463 :   constraints.shrink_to_fit ();
     168         4463 : }
     169              : 
     170              : void
     171         4463 : GenericTyPerCrateCtx::debug_print_solutions ()
     172              : {
     173         4463 :   rust_debug ("Variance analysis results:");
     174              : 
     175         7548 :   for (auto type : map_from_ty_orig_ref)
     176              :     {
     177         3085 :       auto solution_index = type.second;
     178         3085 :       auto ref = type.first;
     179              : 
     180         3085 :       BaseType *ty = lookup_type (ref);
     181              : 
     182         3085 :       std::string result = "\t";
     183              : 
     184         3085 :       if (auto adt = ty->try_as<ADTType> ())
     185              :         {
     186         9255 :           result += adt->get_identifier ();
     187         3085 :           result += "<";
     188              : 
     189         3085 :           size_t i = solution_index;
     190         3106 :           for (auto &region : adt->get_used_arguments ().get_regions ())
     191              :             {
     192           21 :               (void) region;
     193           21 :               if (i > solution_index)
     194            1 :                 result += ", ";
     195           42 :               result += solutions[i].as_string ();
     196           21 :               i++;
     197              :             }
     198         4269 :           for (auto &param : adt->get_substs ())
     199              :             {
     200         1184 :               if (i > solution_index)
     201          129 :                 result += ", ";
     202         2368 :               result += param.get_type_representation ().as_string ();
     203         1184 :               result += "=";
     204         2368 :               result += solutions[i].as_string ();
     205         1184 :               i++;
     206              :             }
     207              : 
     208         3085 :           result += ">";
     209              :         }
     210              :       else
     211              :         {
     212            0 :           rust_sorry_at (
     213              :             ty->get_ref (),
     214              :             "This is a compiler bug: Unhandled type in variance analysis");
     215              :         }
     216         3085 :       rust_debug ("%s", result.c_str ());
     217         3085 :     }
     218         4463 : }
     219              : 
     220              : tl::optional<SolutionIndex>
     221         3512 : GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref)
     222              : {
     223         3512 :   auto it = map_from_ty_orig_ref.find (orig_ref);
     224         3512 :   if (it != map_from_ty_orig_ref.end ())
     225              :     {
     226          363 :       return it->second;
     227              :     }
     228         3149 :   return tl::nullopt;
     229              : }
     230              : 
     231              : void
     232         3085 : GenericTyVisitorCtx::process_type (ADTType &ty)
     233              : {
     234         3085 :   rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ());
     235              : 
     236         3085 :   first_lifetime = lookup_or_add_type (ty.get_orig_ref ());
     237         3085 :   first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size ();
     238              : 
     239         4269 :   for (auto &param : ty.get_substs ())
     240         1184 :     param_names.push_back (param.get_type_representation ().as_string ());
     241              : 
     242         6864 :   for (const auto &variant : ty.get_variants ())
     243              :     {
     244         3779 :       if (variant->get_variant_type () != VariantDef::NUM
     245         3779 :           && variant->get_variant_type () != VariantDef::UNIT)
     246              :         {
     247         6828 :           for (const auto &field : variant->get_fields ())
     248         4300 :             add_constraints_from_ty (field->get_field_type (),
     249         4300 :                                      Variance::covariant ());
     250              :         }
     251              :     }
     252         3085 : }
     253              : 
     254              : std::string
     255         7831 : GenericTyPerCrateCtx::to_string (const Term &term) const
     256              : {
     257         7831 :   switch (term.kind)
     258              :     {
     259         6903 :     case Term::CONST:
     260         6903 :       return term.const_val.as_string ();
     261          464 :     case Term::REF:
     262          928 :       return "v(" + to_string (term.ref) + ")";
     263          464 :     case Term::TRANSFORM:
     264          928 :       return "(" + to_string (*term.transform.lhs) + " x "
     265         1856 :              + to_string (*term.transform.rhs) + ")";
     266              :     }
     267            0 :   rust_unreachable ();
     268              : }
     269              : 
     270              : std::string
     271          627 : GenericTyPerCrateCtx::to_string (SolutionIndex index) const
     272              : {
     273              :   // Search all values in def_id_to_solution_index_start and find key for
     274              :   // largest value smaller than index
     275          627 :   std::pair<HirId, SolutionIndex> best = {0, 0};
     276              : 
     277         2830 :   for (const auto &ty_map : map_from_ty_orig_ref)
     278              :     {
     279         2203 :       if (ty_map.second <= index && ty_map.first > best.first)
     280              :         best = ty_map;
     281              :     }
     282          627 :   rust_assert (best.first != 0);
     283              : 
     284          627 :   BaseType *ty = lookup_type (best.first);
     285              : 
     286          627 :   std::string result = "";
     287          627 :   if (auto adt = ty->try_as<ADTType> ())
     288              :     {
     289         1881 :       result += (adt->get_identifier ());
     290              :     }
     291              :   else
     292              :     {
     293            0 :       result += ty->as_string ();
     294              :     }
     295              : 
     296         1881 :   result += "[" + std::to_string (index - best.first) + "]";
     297          627 :   return result;
     298              : }
     299              : 
     300              : Variance
     301          489 : GenericTyPerCrateCtx::evaluate (Term *term)
     302              : {
     303          489 :   switch (term->kind)
     304              :     {
     305          163 :     case Term::CONST:
     306          163 :       return term->const_val;
     307          163 :     case Term::REF:
     308          163 :       return solutions[term->ref];
     309          163 :     case Term::TRANSFORM:
     310          163 :       return Variance::transform (evaluate (term->transform.lhs),
     311          163 :                                   evaluate (term->transform.rhs));
     312              :     }
     313            0 :   rust_unreachable ();
     314              : }
     315              : 
     316              : std::vector<Variance>
     317           78 : GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
     318              : {
     319           78 :   auto solution_index = lookup_type_index (type.get_orig_ref ());
     320           78 :   rust_assert (solution_index.has_value ());
     321           78 :   auto num_lifetimes = type.get_num_lifetime_params ();
     322           78 :   auto num_types = type.get_num_type_params ();
     323              : 
     324           78 :   std::vector<Variance> result;
     325           78 :   result.reserve (num_lifetimes + num_types);
     326              : 
     327          122 :   for (size_t i = 0; i < num_lifetimes + num_types; ++i)
     328              :     {
     329           44 :       result.push_back (solutions[solution_index.value () + i]);
     330              :     }
     331              : 
     332           78 :   return result;
     333              : }
     334              : 
     335              : FreeRegions
     336            5 : GenericTyPerCrateCtx::query_field_regions (const ADTType *parent,
     337              :                                            size_t variant_index,
     338              :                                            size_t field_index,
     339              :                                            const FreeRegions &parent_regions)
     340              : {
     341            5 :   auto orig = lookup_type (parent->get_orig_ref ());
     342            5 :   FieldVisitorCtx ctx (*this, *parent->as<const SubstitutionRef> (),
     343            5 :                        parent_regions);
     344            5 :   return ctx.collect_regions (*orig->as<const ADTType> ()
     345            5 :                                  ->get_variants ()
     346            5 :                                  .at (variant_index)
     347            5 :                                  ->get_fields ()
     348            5 :                                  .at (field_index)
     349            5 :                                  ->get_field_type ());
     350            5 : }
     351              : std::vector<Region>
     352           76 : GenericTyPerCrateCtx::query_type_regions (BaseType *type)
     353              : {
     354           76 :   TyVisitorCtx ctx (*this);
     355           76 :   return ctx.collect_regions (*type);
     356           76 : }
     357              : 
     358              : SolutionIndex
     359         3434 : GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
     360              : {
     361         3434 :   BaseType *ty = lookup_type (hir_id);
     362         3434 :   auto index = ctx.lookup_type_index (hir_id);
     363         3434 :   if (index.has_value ())
     364              :     {
     365          285 :       return index.value ();
     366              :     }
     367              : 
     368         3149 :   SubstitutionRef *subst = nullptr;
     369         3149 :   if (auto adt = ty->try_as<ADTType> ())
     370              :     {
     371         3149 :       subst = adt;
     372              :     }
     373              :   else
     374              :     {
     375            0 :       rust_sorry_at (
     376              :         ty->get_locus (),
     377              :         "This is a compiler bug: Unhandled type in variance analysis");
     378              :     }
     379            0 :   rust_assert (subst != nullptr);
     380              : 
     381         3149 :   auto solution_index = ctx.solutions.size ();
     382         3149 :   ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index);
     383              : 
     384         3149 :   auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size ();
     385         3149 :   auto num_type_param = subst->get_num_substitutions ();
     386              : 
     387         4355 :   for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i)
     388         1206 :     ctx.solutions.emplace_back (Variance::bivariant ());
     389              : 
     390         3149 :   return solution_index;
     391              : }
     392              : 
     393              : void
     394         5281 : GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance)
     395              : {
     396         5281 :   rust_debug ("\tadd_constraint_from_ty: %s with v=%s",
     397              :               type->as_string ().c_str (), ctx.to_string (variance).c_str ());
     398              : 
     399         5281 :   Visitor visitor (*this, variance);
     400         5281 :   type->accept_vis (visitor);
     401         5281 : }
     402              : 
     403              : void
     404         1459 : GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term)
     405              : {
     406         1459 :   rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ());
     407              : 
     408         1459 :   if (term.kind == Term::CONST)
     409              :     {
     410              :       // Constant terms do not depend on other solutions, so we can
     411              :       // immediately apply them.
     412         1320 :       ctx.solutions[index].join (term.const_val);
     413              :     }
     414              :   else
     415              :     {
     416          139 :       ctx.constraints.emplace_back (index, new Term (term));
     417              :     }
     418         1459 : }
     419              : 
     420              : void
     421           23 : GenericTyVisitorCtx::add_constraints_from_region (const Region &region,
     422              :                                                   Term term)
     423              : {
     424           23 :   if (region.is_early_bound ())
     425              :     {
     426           14 :       add_constraint (first_lifetime + region.get_index (), term);
     427              :     }
     428           23 : }
     429              : 
     430              : void
     431          349 : GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref,
     432              :                                                         SubstitutionRef &subst,
     433              :                                                         Term variance,
     434              :                                                         bool invariant_args)
     435              : {
     436          349 :   SolutionIndex solution_index = lookup_or_add_type (ref);
     437              : 
     438          349 :   size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
     439          349 :   size_t num_types = subst.get_substs ().size ();
     440              : 
     441          518 :   for (size_t i = 0; i < num_lifetimes + num_types; ++i)
     442              :     {
     443              :       // TODO: What about variance from other crates?
     444          169 :       auto variance_i
     445              :         = invariant_args
     446          169 :             ? Term::make_transform (variance, Variance::invariant ())
     447          169 :             : Term::make_transform (variance,
     448              :                                     Term::make_ref (solution_index + i));
     449              : 
     450          169 :       if (i < num_lifetimes)
     451              :         {
     452            1 :           auto region_i = i;
     453            1 :           auto &region
     454            1 :             = subst.get_substitution_arguments ().get_mut_regions ()[region_i];
     455            1 :           add_constraints_from_region (region, variance_i);
     456              :         }
     457              :       else
     458              :         {
     459          168 :           auto type_i = i - num_lifetimes;
     460          168 :           auto arg = subst.get_arg_at (type_i);
     461          168 :           if (arg.has_value ())
     462              :             {
     463          160 :               add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
     464              :             }
     465              :         }
     466              :     }
     467          349 : }
     468              : void
     469         1445 : GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance)
     470              : {
     471         1445 :   auto it
     472         1445 :     = std::find (param_names.begin (), param_names.end (), type.get_name ());
     473         1445 :   rust_assert (it != param_names.end ());
     474              : 
     475         1445 :   auto index = first_type + std::distance (param_names.begin (), it);
     476              : 
     477         1445 :   add_constraint (index, variance);
     478         1445 : }
     479              : 
     480              : Term
     481            9 : GenericTyVisitorCtx::contra (Term variance)
     482              : {
     483            9 :   return Term::make_transform (variance, Variance::contravariant ());
     484              : }
     485              : 
     486              : void
     487          820 : TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
     488              : {
     489          820 :   Visitor visitor (*this, variance);
     490          820 :   ty->accept_vis (visitor);
     491          820 : }
     492              : 
     493              : void
     494          284 : TyVisitorCtx::add_constraints_from_region (const Region &region,
     495              :                                            Variance variance)
     496              : {
     497          284 :   variances.push_back (variance);
     498          284 :   regions.push_back (region);
     499          284 : }
     500              : 
     501              : void
     502           78 : TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
     503              :                                                  SubstitutionRef &subst,
     504              :                                                  Variance variance,
     505              :                                                  bool invariant_args)
     506              : {
     507              :   // Handle function
     508           78 :   auto variances
     509           78 :     = ctx.query_generic_variance (*lookup_type (ref)->as<ADTType> ());
     510              : 
     511           78 :   size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
     512           78 :   size_t num_types = subst.get_substs ().size ();
     513              : 
     514          122 :   for (size_t i = 0; i < num_lifetimes + num_types; ++i)
     515              :     {
     516              :       // TODO: What about variance from other crates?
     517           44 :       auto variance_i
     518              :         = invariant_args
     519           44 :             ? Variance::transform (variance, Variance::invariant ())
     520           44 :             : Variance::transform (variance, variances[i]);
     521              : 
     522           44 :       if (i < num_lifetimes)
     523              :         {
     524           44 :           auto region_i = i;
     525           44 :           auto &region = subst.get_used_arguments ().get_regions ()[region_i];
     526           44 :           add_constraints_from_region (region, variance_i);
     527              :         }
     528              :       else
     529              :         {
     530            0 :           auto type_i = i - num_lifetimes;
     531            0 :           auto arg = subst.get_arg_at (type_i);
     532            0 :           if (arg.has_value ())
     533              :             {
     534            0 :               add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
     535              :             }
     536              :         }
     537              :     }
     538           78 : }
     539              : 
     540              : FreeRegions
     541            5 : FieldVisitorCtx::collect_regions (BaseType &ty)
     542              : {
     543              :   // Segment the regions into ranges for each type parameter. Type parameter
     544              :   // at index i contains regions from type_param_ranges[i] to
     545              :   // type_param_ranges[i+1] (exclusive).;
     546            5 :   type_param_ranges.push_back (subst.get_num_lifetime_params ());
     547              : 
     548            5 :   for (size_t i = 0; i < subst.get_num_type_params (); i++)
     549              :     {
     550            0 :       auto arg = subst.get_arg_at (i);
     551            0 :       rust_assert (arg.has_value ());
     552            0 :       type_param_ranges.push_back (
     553            0 :         ctx.query_type_regions (arg.value ().get_tyty ()).size ());
     554              :     }
     555              : 
     556            5 :   add_constraints_from_ty (&ty, Variance::covariant ());
     557            5 :   return regions;
     558              : }
     559              : 
     560              : void
     561            8 : FieldVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
     562              : {
     563            8 :   Visitor visitor (*this, variance);
     564            8 :   ty->accept_vis (visitor);
     565            8 : }
     566              : 
     567              : void
     568            3 : FieldVisitorCtx::add_constraints_from_region (const Region &region,
     569              :                                               Variance variance)
     570              : {
     571            3 :   if (region.is_early_bound ())
     572              :     {
     573            3 :       regions.push_back (parent_regions[region.get_index ()]);
     574              :     }
     575            0 :   else if (region.is_late_bound ())
     576              :     {
     577            0 :       rust_debug ("Ignoring late bound region");
     578              :     }
     579            3 : }
     580              : 
     581              : void
     582            0 : FieldVisitorCtx::add_constrints_from_param (ParamType &param, Variance variance)
     583              : {
     584            0 :   size_t param_i = subst.get_used_arguments ().find_symbol (param).value ();
     585            0 :   for (size_t i = type_param_ranges[param_i];
     586            0 :        i < type_param_ranges[param_i + 1]; i++)
     587              :     {
     588            0 :       regions.push_back (parent_regions[i]);
     589              :     }
     590            0 : }
     591              : 
     592              : Variance
     593            0 : TyVisitorCtx::contra (Variance variance)
     594              : {
     595            0 :   return Variance::transform (variance, Variance::contravariant ());
     596              : }
     597              : 
     598              : Term
     599          169 : Term::make_ref (SolutionIndex index)
     600              : {
     601          169 :   Term term;
     602          169 :   term.kind = REF;
     603          169 :   term.ref = index;
     604          169 :   return term;
     605              : }
     606              : 
     607              : Term
     608          178 : Term::make_transform (Term lhs, Term rhs)
     609              : {
     610          178 :   if (lhs.is_const () && rhs.is_const ())
     611              :     {
     612            9 :       return Variance::transform (lhs.const_val, rhs.const_val);
     613              :     }
     614              : 
     615          169 :   Term term;
     616          169 :   term.kind = TRANSFORM;
     617          169 :   term.transform.lhs = new Term (lhs);
     618          169 :   term.transform.rhs = new Term (rhs);
     619          169 :   return term;
     620              : }
     621              : 
     622              : } // namespace VarianceAnalysis
     623              : } // namespace TyTy
     624              : } // namespace Rust
        

Generated by: LCOV version 2.4-beta

LCOV profile is generated on x86_64 machine using following configure options: configure --disable-bootstrap --enable-coverage=opt --enable-languages=c,c++,fortran,go,jit,lto,rust,m2 --enable-host-shared. GCC test suite is run with the built compiler.