Branch data Line data Source code
1 : : // Copyright (C) 2025 Free Software Foundation, Inc.
2 : :
3 : : // This file is part of GCC.
4 : :
5 : : // GCC is free software; you can redistribute it and/or modify it under
6 : : // the terms of the GNU General Public License as published by the Free
7 : : // Software Foundation; either version 3, or (at your option) any later
8 : : // version.
9 : :
10 : : // GCC is distributed in the hope that it will be useful, but WITHOUT ANY
11 : : // WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 : : // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 : : // for more details.
14 : :
15 : : // You should have received a copy of the GNU General Public License
16 : : // along with GCC; see the file COPYING3. If not see
17 : : // <http://www.gnu.org/licenses/>.
18 : :
19 : : #include "rust-derive-ord.h"
20 : : #include "rust-ast.h"
21 : : #include "rust-derive-cmp-common.h"
22 : : #include "rust-derive.h"
23 : : #include "rust-item.h"
24 : : #include "rust-system.h"
25 : :
26 : : namespace Rust {
27 : : namespace AST {
28 : :
29 : 75 : DeriveOrd::DeriveOrd (Ordering ordering, location_t loc)
30 : 75 : : DeriveVisitor (loc), ordering (ordering)
31 : 75 : {}
32 : :
33 : : std::unique_ptr<Item>
34 : 75 : DeriveOrd::go (Item &item)
35 : : {
36 : 75 : item.accept_vis (*this);
37 : :
38 : 75 : return std::move (expanded);
39 : : }
40 : :
41 : : std::unique_ptr<Expr>
42 : 191 : DeriveOrd::cmp_call (std::unique_ptr<Expr> &&self_expr,
43 : : std::unique_ptr<Expr> &&other_expr)
44 : : {
45 : 1146 : auto cmp_fn_path = builder.path_in_expression (
46 : 382 : {"core", "cmp", trait (ordering), fn (ordering)}, true);
47 : :
48 : 382 : return builder.call (ptrify (cmp_fn_path),
49 : 382 : vec (builder.ref (std::move (self_expr)),
50 : 573 : builder.ref (std::move (other_expr))));
51 : 191 : }
52 : :
53 : : std::unique_ptr<Item>
54 : 75 : DeriveOrd::cmp_impl (
55 : : std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
56 : : const std::vector<std::unique_ptr<GenericParam>> &type_generics)
57 : : {
58 : 75 : auto fn = cmp_fn (std::move (fn_block), type_name);
59 : :
60 : 75 : auto trait = ordering == Ordering::Partial ? "PartialOrd" : "Ord";
61 : 300 : auto trait_path = builder.type_path ({"core", "cmp", trait}, true);
62 : :
63 : 75 : auto trait_bound
64 : 300 : = builder.trait_bound (builder.type_path ({"core", "cmp", trait}, true));
65 : :
66 : 75 : auto trait_items = vec (std::move (fn));
67 : :
68 : 75 : auto cmp_generics
69 : : = setup_impl_generics (type_name.as_string (), type_generics,
70 : 75 : std::move (trait_bound));
71 : :
72 : 150 : return builder.trait_impl (trait_path, std::move (cmp_generics.self_type),
73 : : std::move (trait_items),
74 : 75 : std::move (cmp_generics.impl));
75 : 75 : }
76 : :
77 : : std::unique_ptr<AssociatedItem>
78 : 75 : DeriveOrd::cmp_fn (std::unique_ptr<BlockExpr> &&block, Identifier type_name)
79 : : {
80 : : // Ordering
81 : 75 : auto return_type = builder.type_path ({"core", "cmp", "Ordering"}, true);
82 : :
83 : : // In the case of PartialOrd, we return an Option<Ordering>
84 : 75 : if (ordering == Ordering::Partial)
85 : : {
86 : 38 : auto generic = GenericArg::create_type (ptrify (return_type));
87 : :
88 : 38 : auto generic_seg = builder.type_path_segment_generic (
89 : 76 : "Option", GenericArgs ({}, {generic}, {}, loc));
90 : 38 : auto core = builder.type_path_segment ("core");
91 : 38 : auto option = builder.type_path_segment ("option");
92 : :
93 : 38 : return_type
94 : 76 : = builder.type_path (vec (std::move (core), std::move (option),
95 : : std::move (generic_seg)),
96 : 38 : true);
97 : 38 : }
98 : :
99 : : // &self, other: &Self
100 : 75 : auto params = vec (
101 : 150 : builder.self_ref_param (),
102 : 150 : builder.function_param (builder.identifier_pattern ("other"),
103 : 150 : builder.reference_type (ptrify (
104 : 300 : builder.type_path (type_name.as_string ())))));
105 : :
106 : 75 : auto function_name = fn (ordering);
107 : :
108 : 375 : return builder.function (function_name, std::move (params),
109 : 225 : ptrify (return_type), std::move (block));
110 : 75 : }
111 : :
112 : : std::unique_ptr<Pattern>
113 : 86 : DeriveOrd::make_equal ()
114 : : {
115 : 86 : std::unique_ptr<Pattern> equal = ptrify (
116 : 172 : builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"}, true));
117 : :
118 : : // We need to wrap the pattern in Option::Some if we are doing partial
119 : : // ordering
120 : 86 : if (ordering == Ordering::Partial)
121 : : {
122 : 55 : auto pattern_items = std::unique_ptr<TupleStructItems> (
123 : 55 : new TupleStructItemsNoRange (vec (std::move (equal))));
124 : :
125 : 55 : equal
126 : 110 : = std::make_unique<TupleStructPattern> (builder.path_in_expression (
127 : : LangItem::Kind::OPTION_SOME),
128 : 55 : std::move (pattern_items));
129 : 55 : }
130 : :
131 : 86 : return equal;
132 : : }
133 : :
134 : : std::pair<MatchArm, MatchArm>
135 : 86 : DeriveOrd::make_cmp_arms ()
136 : : {
137 : : // All comparison results other than Ordering::Equal
138 : 86 : auto non_equal = builder.identifier_pattern (DeriveOrd::not_equal);
139 : 86 : auto equal = make_equal ();
140 : :
141 : 86 : return {builder.match_arm (std::move (equal)),
142 : 172 : builder.match_arm (std::move (non_equal))};
143 : 86 : }
144 : :
145 : : std::unique_ptr<Expr>
146 : 83 : DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
147 : : {
148 : 83 : if (members.empty ())
149 : : {
150 : 0 : std::unique_ptr<Expr> value = ptrify (
151 : 0 : builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"},
152 : 0 : true));
153 : :
154 : 0 : if (ordering == Ordering::Partial)
155 : 0 : value = builder.call (ptrify (builder.path_in_expression (
156 : : LangItem::Kind::OPTION_SOME)),
157 : 0 : std::move (value));
158 : :
159 : : return value;
160 : : }
161 : :
162 : 83 : std::unique_ptr<Expr> final_expr = nullptr;
163 : :
164 : 252 : for (auto it = members.rbegin (); it != members.rend (); it++)
165 : : {
166 : 169 : auto &member = *it;
167 : :
168 : 169 : auto call = cmp_call (std::move (member.self_expr),
169 : 169 : std::move (member.other_expr));
170 : :
171 : : // For the last member (so the first iterator), we just create a call
172 : : // expression
173 : 169 : if (it == members.rbegin ())
174 : : {
175 : 83 : final_expr = std::move (call);
176 : 83 : continue;
177 : : }
178 : :
179 : : // If we aren't dealing with the last member, then we need to wrap all of
180 : : // that in a big match expression and keep going
181 : 86 : auto match_arms = make_cmp_arms ();
182 : :
183 : 86 : auto match_cases
184 : : = {builder.match_case (std::move (match_arms.first),
185 : : std::move (final_expr)),
186 : : builder.match_case (std::move (match_arms.second),
187 : 430 : builder.identifier (DeriveOrd::not_equal))};
188 : :
189 : 86 : final_expr = builder.match (std::move (call), std::move (match_cases));
190 : 427 : }
191 : :
192 : 83 : return final_expr;
193 : 83 : }
194 : :
195 : : // we need to do a recursive match expression for all of the fields used in a
196 : : // struct so for something like struct Foo { a: i32, b: i32, c: i32 } we must
197 : : // first compare each `a` field, then `b`, then `c`, like this:
198 : : //
199 : : // match cmp_fn(self.<field>, other.<field>) {
200 : : // Ordering::Equal => <recurse>,
201 : : // cmp => cmp,
202 : : // }
203 : : //
204 : : // and the recurse will be the exact same expression, on the next field. so that
205 : : // our result looks like this:
206 : : //
207 : : // match cmp_fn(self.a, other.a) {
208 : : // Ordering::Equal => match cmp_fn(self.b, other.b) {
209 : : // Ordering::Equal =>cmp_fn(self.c, other.c),
210 : : // cmp => cmp,
211 : : // }
212 : : // cmp => cmp,
213 : : // }
214 : : //
215 : : // the last field comparison needs not to be a match but just the function call.
216 : : // this is going to be annoying lol
217 : : void
218 : 53 : DeriveOrd::visit_struct (StructStruct &item)
219 : : {
220 : 53 : auto fields = SelfOther::fields (builder, item.get_fields ());
221 : :
222 : 53 : auto match_expr = recursive_match (std::move (fields));
223 : :
224 : 106 : expanded = cmp_impl (builder.block (std::move (match_expr)),
225 : 106 : item.get_identifier (), item.get_generic_params ());
226 : 53 : }
227 : :
228 : : // same as structs, but for each field index instead of each field name -
229 : : // straightforward once we have `visit_struct` working
230 : : void
231 : 0 : DeriveOrd::visit_tuple (TupleStruct &item)
232 : : {
233 : 0 : auto fields = SelfOther::indexes (builder, item.get_fields ());
234 : :
235 : 0 : auto match_expr = recursive_match (std::move (fields));
236 : :
237 : 0 : expanded = cmp_impl (builder.block (std::move (match_expr)),
238 : 0 : item.get_identifier (), item.get_generic_params ());
239 : 0 : }
240 : :
241 : : // for enums, we need to generate a match for each of the enum's variant that
242 : : // contains data and then do the same thing as visit_struct or visit_enum. if
243 : : // the two aren't the same variant, then compare the two discriminant values for
244 : : // all the dataless enum variants and in the general case.
245 : : //
246 : : // so for enum Foo { A(i32, i32), B, C } we need to do the following
247 : : //
248 : : // match (self, other) {
249 : : // (A(self_0, self_1), A(other_0, other_1)) => {
250 : : // match cmp_fn(self_0, other_0) {
251 : : // Ordering::Equal => cmp_fn(self_1, other_1),
252 : : // cmp => cmp,
253 : : // },
254 : : // _ => cmp_fn(discr_value(self), discr_value(other))
255 : : // }
256 : : void
257 : 22 : DeriveOrd::visit_enum (Enum &item)
258 : : {
259 : : // NOTE: We can factor this even further with DerivePartialEq, but this is
260 : : // getting out of scope for this PR surely
261 : :
262 : 22 : auto cases = std::vector<MatchCase> ();
263 : 44 : auto type_name = item.get_identifier ().as_string ();
264 : :
265 : 22 : auto let_sd = builder.discriminant_value (DeriveOrd::self_discr, "self");
266 : 22 : auto let_od = builder.discriminant_value (DeriveOrd::other_discr, "other");
267 : :
268 : 44 : auto discr_cmp = cmp_call (builder.identifier (DeriveOrd::self_discr),
269 : 66 : builder.identifier (DeriveOrd::other_discr));
270 : :
271 : 52 : auto recursive_match_fn = [this] (std::vector<SelfOther> &&fields) {
272 : 30 : return recursive_match (std::move (fields));
273 : 22 : };
274 : :
275 : 74 : for (auto &variant : item.get_variants ())
276 : : {
277 : 52 : auto enum_builder
278 : 104 : = EnumMatchBuilder (type_name, variant->get_identifier ().as_string (),
279 : 52 : recursive_match_fn, builder);
280 : :
281 : 52 : switch (variant->get_enum_item_kind ())
282 : : {
283 : 8 : case EnumItem::Kind::Struct:
284 : 8 : cases.emplace_back (enum_builder.strukt (*variant));
285 : 8 : break;
286 : 22 : case EnumItem::Kind::Tuple:
287 : 22 : cases.emplace_back (enum_builder.tuple (*variant));
288 : 22 : break;
289 : : case EnumItem::Kind::Identifier:
290 : : case EnumItem::Kind::Discriminant:
291 : : // We don't need to do anything for these, as they are handled by the
292 : : // discriminant value comparison
293 : : break;
294 : : }
295 : 52 : }
296 : :
297 : : // Add the last case which compares the discriminant values in case `self` and
298 : : // `other` are actually different variants of the enum
299 : 22 : cases.emplace_back (
300 : 44 : builder.match_case (builder.wildcard (), std::move (discr_cmp)));
301 : :
302 : 22 : auto match
303 : 44 : = builder.match (builder.tuple (vec (builder.identifier ("self"),
304 : 44 : builder.identifier ("other"))),
305 : 22 : std::move (cases));
306 : :
307 : 22 : expanded
308 : 44 : = cmp_impl (builder.block (vec (std::move (let_sd), std::move (let_od)),
309 : : std::move (match)),
310 : 110 : type_name, item.get_generic_params ());
311 : 22 : }
312 : :
313 : : void
314 : 0 : DeriveOrd::visit_union (Union &item)
315 : : {
316 : 0 : auto trait_name = trait (ordering);
317 : :
318 : 0 : rust_error_at (item.get_locus (), "derive(%s) cannot be used on unions",
319 : : trait_name.c_str ());
320 : 0 : }
321 : :
322 : : } // namespace AST
323 : : } // namespace Rust
|