LCOV - code coverage report
Current view: top level - gcc/rust/typecheck - rust-tyty-variance-analysis.cc (source / functions) Coverage Total Hit
Test: gcc.info Lines: 85.9 % 311 267
Test Date: 2026-02-28 14:20:25 Functions: 88.1 % 42 37
Legend: Lines:     hit not hit

            Line data    Source code
       1              : #include "rust-tyty-variance-analysis-private.h"
       2              : #include "rust-hir-type-check.h"
       3              : 
       4              : namespace Rust {
       5              : namespace TyTy {
       6              : 
       7              : BaseType *
       8         7195 : lookup_type (HirId ref)
       9              : {
      10         7195 :   BaseType *ty = nullptr;
      11         7195 :   bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty);
      12         7195 :   rust_assert (ok);
      13         7195 :   return ty;
      14              : }
      15              : 
      16              : namespace VarianceAnalysis {
      17              : 
      18         4510 : CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {}
      19              : 
      20              : // Must be here because of incomplete type.
      21            0 : CrateCtx::~CrateCtx () = default;
      22              : 
      23              : void
      24         3068 : CrateCtx::add_type_constraints (ADTType &type)
      25              : {
      26         3068 :   private_ctx->process_type (type);
      27         3068 : }
      28              : 
      29              : void
      30         4300 : CrateCtx::solve ()
      31              : {
      32         4300 :   private_ctx->solve ();
      33         4300 :   private_ctx->debug_print_solutions ();
      34         4300 : }
      35              : 
      36              : std::vector<Variance>
      37            0 : CrateCtx::query_generic_variance (const ADTType &type)
      38              : {
      39            0 :   return private_ctx->query_generic_variance (type);
      40              : }
      41              : 
      42              : std::vector<Variance>
      43          516 : CrateCtx::query_type_variances (BaseType *type)
      44              : {
      45          516 :   TyVisitorCtx ctx (*private_ctx);
      46          516 :   return ctx.collect_variances (*type);
      47          516 : }
      48              : 
      49              : std::vector<Region>
      50           76 : CrateCtx::query_type_regions (BaseType *type)
      51              : {
      52           76 :   return private_ctx->query_type_regions (type);
      53              : }
      54              : 
      55              : FreeRegions
      56            5 : CrateCtx::query_field_regions (const ADTType *parent, size_t variant_index,
      57              :                                size_t field_index,
      58              :                                const FreeRegions &parent_regions)
      59              : {
      60            5 :   return private_ctx->query_field_regions (parent, variant_index, field_index,
      61            5 :                                            parent_regions);
      62              : }
      63              : 
      64              : Variance
      65            0 : Variance::reverse () const
      66              : {
      67            0 :   switch (kind)
      68              :     {
      69            0 :     case BIVARIANT:
      70            0 :       return bivariant ();
      71            0 :     case COVARIANT:
      72            0 :       return contravariant ();
      73            0 :     case CONTRAVARIANT:
      74            0 :       return covariant ();
      75            0 :     case INVARIANT:
      76            0 :       return invariant ();
      77              :     }
      78              : 
      79            0 :   rust_unreachable ();
      80              : }
      81              : 
      82              : Variance
      83         1483 : Variance::join (Variance lhs, Variance rhs)
      84              : {
      85         1483 :   return {Kind (lhs.kind | rhs.kind)};
      86              : }
      87              : 
      88              : void
      89         1320 : Variance::join (Variance rhs)
      90              : {
      91         1320 :   *this = join (*this, rhs);
      92         1320 : }
      93              : 
      94              : Variance
      95          216 : Variance::transform (Variance lhs, Variance rhs)
      96              : {
      97          216 :   switch (lhs.kind)
      98              :     {
      99            0 :     case BIVARIANT:
     100            0 :       return bivariant ();
     101          214 :     case COVARIANT:
     102          214 :       return rhs;
     103            0 :     case CONTRAVARIANT:
     104            0 :       return rhs.reverse ();
     105            2 :     case INVARIANT:
     106            2 :       return invariant ();
     107              :     }
     108            0 :   rust_unreachable ();
     109              : }
     110              : 
     111              : std::string
     112         8139 : Variance::as_string () const
     113              : {
     114         8139 :   switch (kind)
     115              :     {
     116          254 :     case BIVARIANT:
     117          254 :       return "o";
     118         7583 :     case COVARIANT:
     119         7583 :       return "+";
     120            7 :     case CONTRAVARIANT:
     121            7 :       return "-";
     122          295 :     case INVARIANT:
     123          295 :       return "*";
     124              :     }
     125            0 :   rust_unreachable ();
     126              : }
     127              : 
     128              : void
     129         3068 : GenericTyPerCrateCtx::process_type (ADTType &type)
     130              : {
     131         3068 :   GenericTyVisitorCtx (*this).process_type (type);
     132         3068 : }
     133              : 
     134              : void
     135         4300 : GenericTyPerCrateCtx::solve ()
     136              : {
     137         4300 :   rust_debug ("Variance analysis solving started:");
     138              : 
     139              :   // Fix point iteration
     140         4300 :   bool changed = true;
     141        12924 :   while (changed)
     142              :     {
     143         4324 :       changed = false;
     144         4487 :       for (auto constraint : constraints)
     145              :         {
     146          163 :           rust_debug ("\tapplying constraint: %s <= %s",
     147              :                       to_string (constraint.target_index).c_str (),
     148              :                       to_string (*constraint.term).c_str ());
     149              : 
     150          163 :           auto old_solution = solutions[constraint.target_index];
     151          163 :           auto new_solution
     152          163 :             = Variance::join (old_solution, evaluate (constraint.term));
     153              : 
     154          163 :           if (old_solution != new_solution)
     155              :             {
     156           24 :               rust_debug ("\t\tsolution changed: %s => %s",
     157              :                           old_solution.as_string ().c_str (),
     158              :                           new_solution.as_string ().c_str ());
     159              : 
     160           24 :               changed = true;
     161           24 :               solutions[constraint.target_index] = new_solution;
     162              :             }
     163              :         }
     164              :     }
     165              : 
     166         4300 :   constraints.clear ();
     167         4300 :   constraints.shrink_to_fit ();
     168         4300 : }
     169              : 
     170              : void
     171         4300 : GenericTyPerCrateCtx::debug_print_solutions ()
     172              : {
     173         4300 :   rust_debug ("Variance analysis results:");
     174              : 
     175         7368 :   for (auto type : map_from_ty_orig_ref)
     176              :     {
     177         3068 :       auto solution_index = type.second;
     178         3068 :       auto ref = type.first;
     179              : 
     180         3068 :       BaseType *ty = lookup_type (ref);
     181              : 
     182         3068 :       std::string result = "\t";
     183              : 
     184         3068 :       if (auto adt = ty->try_as<ADTType> ())
     185              :         {
     186         9204 :           result += adt->get_identifier ();
     187         3068 :           result += "<";
     188              : 
     189         3068 :           size_t i = solution_index;
     190         3089 :           for (auto &region : adt->get_used_arguments ().get_regions ())
     191              :             {
     192           21 :               (void) region;
     193           21 :               if (i > solution_index)
     194            1 :                 result += ", ";
     195           42 :               result += solutions[i].as_string ();
     196           21 :               i++;
     197              :             }
     198         4251 :           for (auto &param : adt->get_substs ())
     199              :             {
     200         1183 :               if (i > solution_index)
     201          129 :                 result += ", ";
     202         2366 :               result += param.get_type_representation ().as_string ();
     203         1183 :               result += "=";
     204         2366 :               result += solutions[i].as_string ();
     205         1183 :               i++;
     206              :             }
     207              : 
     208         3068 :           result += ">";
     209              :         }
     210              :       else
     211              :         {
     212            0 :           rust_sorry_at (
     213              :             ty->get_ref (),
     214              :             "This is a compiler bug: Unhandled type in variance analysis");
     215              :         }
     216         3068 :       rust_debug ("%s", result.c_str ());
     217         3068 :     }
     218         4300 : }
     219              : 
     220              : tl::optional<SolutionIndex>
     221         3495 : GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref)
     222              : {
     223         3495 :   auto it = map_from_ty_orig_ref.find (orig_ref);
     224         3495 :   if (it != map_from_ty_orig_ref.end ())
     225              :     {
     226          363 :       return it->second;
     227              :     }
     228         3132 :   return tl::nullopt;
     229              : }
     230              : 
     231              : void
     232         3068 : GenericTyVisitorCtx::process_type (ADTType &ty)
     233              : {
     234         3068 :   rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ());
     235              : 
     236         3068 :   first_lifetime = lookup_or_add_type (ty.get_orig_ref ());
     237         3068 :   first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size ();
     238              : 
     239         4251 :   for (auto &param : ty.get_substs ())
     240         1183 :     param_names.push_back (param.get_type_representation ().as_string ());
     241              : 
     242         6826 :   for (const auto &variant : ty.get_variants ())
     243              :     {
     244         3758 :       if (variant->get_variant_type () != VariantDef::NUM)
     245              :         {
     246         7343 :           for (const auto &field : variant->get_fields ())
     247         4288 :             add_constraints_from_ty (field->get_field_type (),
     248         4288 :                                      Variance::covariant ());
     249              :         }
     250              :     }
     251         3068 : }
     252              : 
     253              : std::string
     254         7815 : GenericTyPerCrateCtx::to_string (const Term &term) const
     255              : {
     256         7815 :   switch (term.kind)
     257              :     {
     258         6887 :     case Term::CONST:
     259         6887 :       return term.const_val.as_string ();
     260          464 :     case Term::REF:
     261          928 :       return "v(" + to_string (term.ref) + ")";
     262          464 :     case Term::TRANSFORM:
     263          928 :       return "(" + to_string (*term.transform.lhs) + " x "
     264         1856 :              + to_string (*term.transform.rhs) + ")";
     265              :     }
     266            0 :   rust_unreachable ();
     267              : }
     268              : 
     269              : std::string
     270          627 : GenericTyPerCrateCtx::to_string (SolutionIndex index) const
     271              : {
     272              :   // Search all values in def_id_to_solution_index_start and find key for
     273              :   // largest value smaller than index
     274          627 :   std::pair<HirId, SolutionIndex> best = {0, 0};
     275              : 
     276         2830 :   for (const auto &ty_map : map_from_ty_orig_ref)
     277              :     {
     278         2203 :       if (ty_map.second <= index && ty_map.first > best.first)
     279              :         best = ty_map;
     280              :     }
     281          627 :   rust_assert (best.first != 0);
     282              : 
     283          627 :   BaseType *ty = lookup_type (best.first);
     284              : 
     285          627 :   std::string result = "";
     286          627 :   if (auto adt = ty->try_as<ADTType> ())
     287              :     {
     288         1881 :       result += (adt->get_identifier ());
     289              :     }
     290              :   else
     291              :     {
     292            0 :       result += ty->as_string ();
     293              :     }
     294              : 
     295         1881 :   result += "[" + std::to_string (index - best.first) + "]";
     296          627 :   return result;
     297              : }
     298              : 
     299              : Variance
     300          489 : GenericTyPerCrateCtx::evaluate (Term *term)
     301              : {
     302          489 :   switch (term->kind)
     303              :     {
     304          163 :     case Term::CONST:
     305          163 :       return term->const_val;
     306          163 :     case Term::REF:
     307          163 :       return solutions[term->ref];
     308          163 :     case Term::TRANSFORM:
     309          163 :       return Variance::transform (evaluate (term->transform.lhs),
     310          163 :                                   evaluate (term->transform.rhs));
     311              :     }
     312            0 :   rust_unreachable ();
     313              : }
     314              : 
     315              : std::vector<Variance>
     316           78 : GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
     317              : {
     318           78 :   auto solution_index = lookup_type_index (type.get_orig_ref ());
     319           78 :   rust_assert (solution_index.has_value ());
     320           78 :   auto num_lifetimes = type.get_num_lifetime_params ();
     321           78 :   auto num_types = type.get_num_type_params ();
     322              : 
     323           78 :   std::vector<Variance> result;
     324           78 :   result.reserve (num_lifetimes + num_types);
     325              : 
     326          122 :   for (size_t i = 0; i < num_lifetimes + num_types; ++i)
     327              :     {
     328           44 :       result.push_back (solutions[solution_index.value () + i]);
     329              :     }
     330              : 
     331           78 :   return result;
     332              : }
     333              : 
     334              : FreeRegions
     335            5 : GenericTyPerCrateCtx::query_field_regions (const ADTType *parent,
     336              :                                            size_t variant_index,
     337              :                                            size_t field_index,
     338              :                                            const FreeRegions &parent_regions)
     339              : {
     340            5 :   auto orig = lookup_type (parent->get_orig_ref ());
     341            5 :   FieldVisitorCtx ctx (*this, *parent->as<const SubstitutionRef> (),
     342            5 :                        parent_regions);
     343            5 :   return ctx.collect_regions (*orig->as<const ADTType> ()
     344            5 :                                  ->get_variants ()
     345            5 :                                  .at (variant_index)
     346            5 :                                  ->get_fields ()
     347            5 :                                  .at (field_index)
     348            5 :                                  ->get_field_type ());
     349            5 : }
     350              : std::vector<Region>
     351           76 : GenericTyPerCrateCtx::query_type_regions (BaseType *type)
     352              : {
     353           76 :   TyVisitorCtx ctx (*this);
     354           76 :   return ctx.collect_regions (*type);
     355           76 : }
     356              : 
     357              : SolutionIndex
     358         3417 : GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
     359              : {
     360         3417 :   BaseType *ty = lookup_type (hir_id);
     361         3417 :   auto index = ctx.lookup_type_index (hir_id);
     362         3417 :   if (index.has_value ())
     363              :     {
     364          285 :       return index.value ();
     365              :     }
     366              : 
     367         3132 :   SubstitutionRef *subst = nullptr;
     368         3132 :   if (auto adt = ty->try_as<ADTType> ())
     369              :     {
     370         3132 :       subst = adt;
     371              :     }
     372              :   else
     373              :     {
     374            0 :       rust_sorry_at (
     375              :         ty->get_locus (),
     376              :         "This is a compiler bug: Unhandled type in variance analysis");
     377              :     }
     378            0 :   rust_assert (subst != nullptr);
     379              : 
     380         3132 :   auto solution_index = ctx.solutions.size ();
     381         3132 :   ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index);
     382              : 
     383         3132 :   auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size ();
     384         3132 :   auto num_type_param = subst->get_num_substitutions ();
     385              : 
     386         4337 :   for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i)
     387         1205 :     ctx.solutions.emplace_back (Variance::bivariant ());
     388              : 
     389         3132 :   return solution_index;
     390              : }
     391              : 
     392              : void
     393         5265 : GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance)
     394              : {
     395         5265 :   rust_debug ("\tadd_constraint_from_ty: %s with v=%s",
     396              :               type->as_string ().c_str (), ctx.to_string (variance).c_str ());
     397              : 
     398         5265 :   Visitor visitor (*this, variance);
     399         5265 :   type->accept_vis (visitor);
     400         5265 : }
     401              : 
     402              : void
     403         1459 : GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term)
     404              : {
     405         1459 :   rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ());
     406              : 
     407         1459 :   if (term.kind == Term::CONST)
     408              :     {
     409              :       // Constant terms do not depend on other solutions, so we can
     410              :       // immediately apply them.
     411         1320 :       ctx.solutions[index].join (term.const_val);
     412              :     }
     413              :   else
     414              :     {
     415          139 :       ctx.constraints.emplace_back (index, new Term (term));
     416              :     }
     417         1459 : }
     418              : 
     419              : void
     420           23 : GenericTyVisitorCtx::add_constraints_from_region (const Region &region,
     421              :                                                   Term term)
     422              : {
     423           23 :   if (region.is_early_bound ())
     424              :     {
     425           14 :       add_constraint (first_lifetime + region.get_index (), term);
     426              :     }
     427           23 : }
     428              : 
     429              : void
     430          349 : GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref,
     431              :                                                         SubstitutionRef &subst,
     432              :                                                         Term variance,
     433              :                                                         bool invariant_args)
     434              : {
     435          349 :   SolutionIndex solution_index = lookup_or_add_type (ref);
     436              : 
     437          349 :   size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
     438          349 :   size_t num_types = subst.get_substs ().size ();
     439              : 
     440          518 :   for (size_t i = 0; i < num_lifetimes + num_types; ++i)
     441              :     {
     442              :       // TODO: What about variance from other crates?
     443          169 :       auto variance_i
     444              :         = invariant_args
     445          169 :             ? Term::make_transform (variance, Variance::invariant ())
     446          169 :             : Term::make_transform (variance,
     447              :                                     Term::make_ref (solution_index + i));
     448              : 
     449          169 :       if (i < num_lifetimes)
     450              :         {
     451            1 :           auto region_i = i;
     452            1 :           auto &region
     453            1 :             = subst.get_substitution_arguments ().get_mut_regions ()[region_i];
     454            1 :           add_constraints_from_region (region, variance_i);
     455              :         }
     456              :       else
     457              :         {
     458          168 :           auto type_i = i - num_lifetimes;
     459          168 :           auto arg = subst.get_arg_at (type_i);
     460          168 :           if (arg.has_value ())
     461              :             {
     462          160 :               add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
     463              :             }
     464              :         }
     465              :     }
     466          349 : }
     467              : void
     468         1445 : GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance)
     469              : {
     470         1445 :   auto it
     471         1445 :     = std::find (param_names.begin (), param_names.end (), type.get_name ());
     472         1445 :   rust_assert (it != param_names.end ());
     473              : 
     474         1445 :   auto index = first_type + std::distance (param_names.begin (), it);
     475              : 
     476         1445 :   add_constraint (index, variance);
     477         1445 : }
     478              : 
     479              : Term
     480            9 : GenericTyVisitorCtx::contra (Term variance)
     481              : {
     482            9 :   return Term::make_transform (variance, Variance::contravariant ());
     483              : }
     484              : 
     485              : void
     486          820 : TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
     487              : {
     488          820 :   Visitor visitor (*this, variance);
     489          820 :   ty->accept_vis (visitor);
     490          820 : }
     491              : 
     492              : void
     493          284 : TyVisitorCtx::add_constraints_from_region (const Region &region,
     494              :                                            Variance variance)
     495              : {
     496          284 :   variances.push_back (variance);
     497          284 :   regions.push_back (region);
     498          284 : }
     499              : 
     500              : void
     501           78 : TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
     502              :                                                  SubstitutionRef &subst,
     503              :                                                  Variance variance,
     504              :                                                  bool invariant_args)
     505              : {
     506              :   // Handle function
     507           78 :   auto variances
     508           78 :     = ctx.query_generic_variance (*lookup_type (ref)->as<ADTType> ());
     509              : 
     510           78 :   size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
     511           78 :   size_t num_types = subst.get_substs ().size ();
     512              : 
     513          122 :   for (size_t i = 0; i < num_lifetimes + num_types; ++i)
     514              :     {
     515              :       // TODO: What about variance from other crates?
     516           44 :       auto variance_i
     517              :         = invariant_args
     518           44 :             ? Variance::transform (variance, Variance::invariant ())
     519           44 :             : Variance::transform (variance, variances[i]);
     520              : 
     521           44 :       if (i < num_lifetimes)
     522              :         {
     523           44 :           auto region_i = i;
     524           44 :           auto &region = subst.get_used_arguments ().get_regions ()[region_i];
     525           44 :           add_constraints_from_region (region, variance_i);
     526              :         }
     527              :       else
     528              :         {
     529            0 :           auto type_i = i - num_lifetimes;
     530            0 :           auto arg = subst.get_arg_at (type_i);
     531            0 :           if (arg.has_value ())
     532              :             {
     533            0 :               add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
     534              :             }
     535              :         }
     536              :     }
     537           78 : }
     538              : 
     539              : FreeRegions
     540            5 : FieldVisitorCtx::collect_regions (BaseType &ty)
     541              : {
     542              :   // Segment the regions into ranges for each type parameter. Type parameter
     543              :   // at index i contains regions from type_param_ranges[i] to
     544              :   // type_param_ranges[i+1] (exclusive).;
     545            5 :   type_param_ranges.push_back (subst.get_num_lifetime_params ());
     546              : 
     547            5 :   for (size_t i = 0; i < subst.get_num_type_params (); i++)
     548              :     {
     549            0 :       auto arg = subst.get_arg_at (i);
     550            0 :       rust_assert (arg.has_value ());
     551            0 :       type_param_ranges.push_back (
     552            0 :         ctx.query_type_regions (arg.value ().get_tyty ()).size ());
     553              :     }
     554              : 
     555            5 :   add_constraints_from_ty (&ty, Variance::covariant ());
     556            5 :   return regions;
     557              : }
     558              : 
     559              : void
     560            8 : FieldVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
     561              : {
     562            8 :   Visitor visitor (*this, variance);
     563            8 :   ty->accept_vis (visitor);
     564            8 : }
     565              : 
     566              : void
     567            3 : FieldVisitorCtx::add_constraints_from_region (const Region &region,
     568              :                                               Variance variance)
     569              : {
     570            3 :   if (region.is_early_bound ())
     571              :     {
     572            3 :       regions.push_back (parent_regions[region.get_index ()]);
     573              :     }
     574            0 :   else if (region.is_late_bound ())
     575              :     {
     576            0 :       rust_debug ("Ignoring late bound region");
     577              :     }
     578            3 : }
     579              : 
     580              : void
     581            0 : FieldVisitorCtx::add_constrints_from_param (ParamType &param, Variance variance)
     582              : {
     583            0 :   size_t param_i = subst.get_used_arguments ().find_symbol (param).value ();
     584            0 :   for (size_t i = type_param_ranges[param_i];
     585            0 :        i < type_param_ranges[param_i + 1]; i++)
     586              :     {
     587            0 :       regions.push_back (parent_regions[i]);
     588              :     }
     589            0 : }
     590              : 
     591              : Variance
     592            0 : TyVisitorCtx::contra (Variance variance)
     593              : {
     594            0 :   return Variance::transform (variance, Variance::contravariant ());
     595              : }
     596              : 
     597              : Term
     598          169 : Term::make_ref (SolutionIndex index)
     599              : {
     600          169 :   Term term;
     601          169 :   term.kind = REF;
     602          169 :   term.ref = index;
     603          169 :   return term;
     604              : }
     605              : 
     606              : Term
     607          178 : Term::make_transform (Term lhs, Term rhs)
     608              : {
     609          178 :   if (lhs.is_const () && rhs.is_const ())
     610              :     {
     611            9 :       return Variance::transform (lhs.const_val, rhs.const_val);
     612              :     }
     613              : 
     614          169 :   Term term;
     615          169 :   term.kind = TRANSFORM;
     616          169 :   term.transform.lhs = new Term (lhs);
     617          169 :   term.transform.rhs = new Term (rhs);
     618          169 :   return term;
     619              : }
     620              : 
     621              : } // namespace VarianceAnalysis
     622              : } // namespace TyTy
     623              : } // namespace Rust
        

Generated by: LCOV version 2.4-beta

LCOV profile is generated on x86_64 machine using following configure options: configure --disable-bootstrap --enable-coverage=opt --enable-languages=c,c++,fortran,go,jit,lto,rust,m2 --enable-host-shared. GCC test suite is run with the built compiler.