Line data Source code
1 : // Copyright (C) 2025-2026 Free Software Foundation, Inc.
2 :
3 : // This file is part of GCC.
4 :
5 : // GCC is free software; you can redistribute it and/or modify it under
6 : // the terms of the GNU General Public License as published by the Free
7 : // Software Foundation; either version 3, or (at your option) any later
8 : // version.
9 :
10 : // GCC is distributed in the hope that it will be useful, but WITHOUT ANY
11 : // WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 : // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 : // for more details.
14 :
15 : // You should have received a copy of the GNU General Public License
16 : // along with GCC; see the file COPYING3. If not see
17 : // <http://www.gnu.org/licenses/>.
18 :
19 : #include "rust-derive-eq.h"
20 : #include "rust-ast.h"
21 : #include "rust-expr.h"
22 : #include "rust-item.h"
23 : #include "rust-path.h"
24 : #include "rust-pattern.h"
25 : #include "rust-system.h"
26 :
27 : namespace Rust {
28 : namespace AST {
29 :
30 : static TypePath
31 94 : get_eq_trait_path (Builder &builder)
32 : {
33 94 : return builder.type_path ({"core", "cmp", "Eq"}, true);
34 : }
35 :
36 47 : DeriveEq::DeriveEq (location_t loc) : DeriveVisitor (loc) {}
37 :
38 : std::vector<std::unique_ptr<AST::Item>>
39 47 : DeriveEq::go (Item &item)
40 : {
41 47 : item.accept_vis (*this);
42 :
43 47 : return std::move (expanded);
44 : }
45 :
46 : std::unique_ptr<AssociatedItem>
47 47 : DeriveEq::assert_receiver_is_total_eq_fn (
48 : std::vector<std::unique_ptr<Type>> &&types)
49 : {
50 47 : auto stmts = std::vector<std::unique_ptr<Stmt>> ();
51 :
52 47 : stmts.emplace_back (assert_param_is_eq ());
53 :
54 125 : for (auto &&type : types)
55 78 : stmts.emplace_back (assert_type_is_eq (std::move (type)));
56 :
57 47 : auto block = std::unique_ptr<BlockExpr> (
58 47 : new BlockExpr (std::move (stmts), nullptr, {}, {}, tl::nullopt, loc, loc));
59 :
60 47 : auto self = builder.self_ref_param ();
61 :
62 141 : return builder.function ("assert_receiver_is_total_eq",
63 141 : vec (std::move (self)), {}, std::move (block));
64 47 : }
65 :
66 : std::unique_ptr<Stmt>
67 47 : DeriveEq::assert_param_is_eq ()
68 : {
69 47 : auto eq_bound = std::unique_ptr<TypeParamBound> (
70 47 : new TraitBound (get_eq_trait_path (builder), loc));
71 :
72 47 : auto sized_bound = std::unique_ptr<TypeParamBound> (
73 47 : new TraitBound (builder.type_path (LangItem::Kind::SIZED), loc, false,
74 47 : true /* opening_question_mark */));
75 :
76 47 : auto bounds = vec (std::move (eq_bound), std::move (sized_bound));
77 :
78 47 : auto assert_param_is_eq = "AssertParamIsEq";
79 :
80 47 : auto t = std::unique_ptr<GenericParam> (
81 94 : new TypeParam (Identifier ("T"), loc, std::move (bounds)));
82 :
83 141 : return builder.struct_struct (
84 94 : assert_param_is_eq, vec (std::move (t)),
85 : {StructField (
86 141 : Identifier ("_t"),
87 94 : builder.single_generic_type_path (
88 : LangItem::Kind::PHANTOM_DATA,
89 47 : GenericArgs (
90 141 : {}, {GenericArg::create_type (builder.single_type_path ("T"))}, {})),
91 141 : Visibility::create_private (), loc)});
92 47 : }
93 :
94 : std::unique_ptr<Stmt>
95 78 : DeriveEq::assert_type_is_eq (std::unique_ptr<Type> &&type)
96 : {
97 78 : auto assert_param_is_eq = "AssertParamIsEq";
98 :
99 : // AssertParamIsCopy::<Self>
100 78 : auto assert_param_is_eq_ty
101 : = std::unique_ptr<TypePathSegment> (new TypePathSegmentGeneric (
102 156 : PathIdentSegment (assert_param_is_eq, loc), false,
103 234 : GenericArgs ({}, {GenericArg::create_type (std::move (type))}, {}, loc),
104 156 : loc));
105 :
106 : // TODO: Improve this, it's really ugly
107 78 : auto type_paths = std::vector<std::unique_ptr<TypePathSegment>> ();
108 78 : type_paths.emplace_back (std::move (assert_param_is_eq_ty));
109 :
110 78 : auto full_path
111 78 : = std::unique_ptr<Type> (new TypePath ({std::move (type_paths)}, loc));
112 :
113 78 : return builder.let (builder.wildcard (), std::move (full_path));
114 78 : }
115 :
116 : std::vector<std::unique_ptr<Item>>
117 47 : DeriveEq::eq_impls (
118 : std::unique_ptr<AssociatedItem> &&fn, std::string name,
119 : const std::vector<std::unique_ptr<GenericParam>> &type_generics)
120 : {
121 47 : auto eq = [this] () { return get_eq_trait_path (builder); };
122 0 : auto eq_bound = [&, this] () { return builder.trait_bound (eq ()); };
123 :
124 47 : auto steq = builder.type_path (LangItem::Kind::STRUCTURAL_TEQ);
125 :
126 47 : auto trait_items = vec (std::move (fn));
127 :
128 47 : auto eq_generics = setup_impl_generics (name, type_generics, eq_bound);
129 47 : auto steq_generics = setup_impl_generics (name, type_generics);
130 :
131 94 : auto eq_impl = builder.trait_impl (eq (), std::move (eq_generics.self_type),
132 : std::move (trait_items),
133 47 : std::move (eq_generics.impl));
134 :
135 : // StructuralEq is a marker trait
136 47 : decltype (trait_items) steq_trait_items = {};
137 :
138 47 : auto steq_impl
139 94 : = builder.trait_impl (steq, std::move (steq_generics.self_type),
140 : std::move (steq_trait_items),
141 47 : std::move (steq_generics.impl));
142 :
143 47 : return vec (std::move (eq_impl), std::move (steq_impl));
144 141 : }
145 :
146 : void
147 1 : DeriveEq::visit_tuple (TupleStruct &item)
148 : {
149 1 : auto types = std::vector<std::unique_ptr<Type>> ();
150 :
151 2 : for (auto &field : item.get_fields ())
152 1 : types.emplace_back (field.get_field_type ().reconstruct ());
153 :
154 2 : expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
155 1 : item.get_identifier ().as_string (),
156 2 : item.get_generic_params ());
157 1 : }
158 :
159 : void
160 39 : DeriveEq::visit_struct (StructStruct &item)
161 : {
162 39 : auto types = std::vector<std::unique_ptr<Type>> ();
163 :
164 109 : for (auto &field : item.get_fields ())
165 70 : types.emplace_back (field.get_field_type ().reconstruct ());
166 :
167 78 : expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
168 39 : item.get_identifier ().as_string (),
169 78 : item.get_generic_params ());
170 39 : }
171 :
172 : void
173 7 : DeriveEq::visit_enum (Enum &item)
174 : {
175 7 : auto types = std::vector<std::unique_ptr<Type>> ();
176 :
177 21 : for (auto &variant : item.get_variants ())
178 : {
179 14 : switch (variant->get_enum_item_kind ())
180 : {
181 7 : case EnumItem::Kind::Identifier:
182 7 : case EnumItem::Kind::Discriminant:
183 : // nothing to do as they contain no inner types
184 7 : continue;
185 7 : case EnumItem::Kind::Tuple:
186 7 : {
187 7 : auto &tuple = static_cast<EnumItemTuple &> (*variant);
188 :
189 14 : for (auto &field : tuple.get_tuple_fields ())
190 7 : types.emplace_back (field.get_field_type ().reconstruct ());
191 : break;
192 : }
193 0 : case EnumItem::Kind::Struct:
194 0 : {
195 0 : auto &tuple = static_cast<EnumItemStruct &> (*variant);
196 :
197 0 : for (auto &field : tuple.get_struct_fields ())
198 0 : types.emplace_back (field.get_field_type ().reconstruct ());
199 :
200 : break;
201 : }
202 7 : }
203 : }
204 :
205 21 : expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
206 7 : item.get_identifier ().as_string (),
207 14 : item.get_generic_params ());
208 7 : }
209 :
210 : void
211 0 : DeriveEq::visit_union (Union &item)
212 : {
213 0 : auto types = std::vector<std::unique_ptr<Type>> ();
214 :
215 0 : for (auto &field : item.get_variants ())
216 0 : types.emplace_back (field.get_field_type ().reconstruct ());
217 :
218 0 : expanded = eq_impls (assert_receiver_is_total_eq_fn (std::move (types)),
219 0 : item.get_identifier ().as_string (),
220 0 : item.get_generic_params ());
221 0 : }
222 :
223 : } // namespace AST
224 : } // namespace Rust
|