Line data Source code
1 : // Copyright (C) 2025-2026 Free Software Foundation, Inc.
2 :
3 : // This file is part of GCC.
4 :
5 : // GCC is free software; you can redistribute it and/or modify it under
6 : // the terms of the GNU General Public License as published by the Free
7 : // Software Foundation; either version 3, or (at your option) any later
8 : // version.
9 :
10 : // GCC is distributed in the hope that it will be useful, but WITHOUT ANY
11 : // WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 : // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 : // for more details.
14 :
15 : // You should have received a copy of the GNU General Public License
16 : // along with GCC; see the file COPYING3. If not see
17 : // <http://www.gnu.org/licenses/>.
18 :
19 : #include "rust-derive-ord.h"
20 : #include "rust-ast.h"
21 : #include "rust-derive-cmp-common.h"
22 : #include "rust-derive.h"
23 : #include "rust-item.h"
24 : #include "rust-system.h"
25 :
26 : namespace Rust {
27 : namespace AST {
28 :
29 75 : DeriveOrd::DeriveOrd (Ordering ordering, location_t loc)
30 75 : : DeriveVisitor (loc), ordering (ordering)
31 75 : {}
32 :
33 : std::unique_ptr<Item>
34 75 : DeriveOrd::go (Item &item)
35 : {
36 75 : item.accept_vis (*this);
37 :
38 75 : return std::move (expanded);
39 : }
40 :
41 : std::unique_ptr<Expr>
42 191 : DeriveOrd::cmp_call (std::unique_ptr<Expr> &&self_expr,
43 : std::unique_ptr<Expr> &&other_expr)
44 : {
45 1146 : auto cmp_fn_path = builder.path_in_expression (
46 382 : {"core", "cmp", trait (ordering), fn (ordering)}, true);
47 :
48 382 : return builder.call (ptrify (cmp_fn_path),
49 382 : vec (builder.ref (std::move (self_expr)),
50 573 : builder.ref (std::move (other_expr))));
51 191 : }
52 :
53 : std::unique_ptr<Item>
54 75 : DeriveOrd::cmp_impl (
55 : std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
56 : const std::vector<std::unique_ptr<GenericParam>> &type_generics)
57 : {
58 75 : auto fn = cmp_fn (std::move (fn_block), type_name);
59 :
60 75 : auto trait = ordering == Ordering::Partial ? "PartialOrd" : "Ord";
61 150 : auto trait_path = [&, this] () {
62 300 : return builder.type_path ({"core", "cmp", trait}, true);
63 75 : };
64 :
65 75 : auto trait_bound
66 0 : = [&, this] () { return builder.trait_bound (trait_path ()); };
67 :
68 75 : auto trait_items = vec (std::move (fn));
69 :
70 75 : auto cmp_generics
71 75 : = setup_impl_generics (type_name.as_string (), type_generics, trait_bound);
72 :
73 150 : return builder.trait_impl (trait_path (), std::move (cmp_generics.self_type),
74 : std::move (trait_items),
75 150 : std::move (cmp_generics.impl));
76 75 : }
77 :
78 : std::unique_ptr<AssociatedItem>
79 75 : DeriveOrd::cmp_fn (std::unique_ptr<BlockExpr> &&block, Identifier type_name)
80 : {
81 : // Ordering
82 75 : auto return_type = builder.type_path ({"core", "cmp", "Ordering"}, true);
83 :
84 : // In the case of PartialOrd, we return an Option<Ordering>
85 75 : if (ordering == Ordering::Partial)
86 : {
87 38 : auto generic = GenericArg::create_type (ptrify (return_type));
88 :
89 38 : auto generic_seg = builder.type_path_segment_generic (
90 76 : "Option", GenericArgs ({}, {generic}, {}, loc));
91 38 : auto core = builder.type_path_segment ("core");
92 38 : auto option = builder.type_path_segment ("option");
93 :
94 38 : return_type
95 76 : = builder.type_path (vec (std::move (core), std::move (option),
96 : std::move (generic_seg)),
97 38 : true);
98 38 : }
99 :
100 : // &self, other: &Self
101 75 : auto params = vec (
102 150 : builder.self_ref_param (),
103 150 : builder.function_param (builder.identifier_pattern ("other"),
104 150 : builder.reference_type (ptrify (
105 300 : builder.type_path (type_name.as_string ())))));
106 :
107 75 : auto function_name = fn (ordering);
108 :
109 375 : return builder.function (function_name, std::move (params),
110 225 : ptrify (return_type), std::move (block));
111 75 : }
112 :
113 : std::unique_ptr<Pattern>
114 86 : DeriveOrd::make_equal ()
115 : {
116 86 : std::unique_ptr<Pattern> equal = ptrify (
117 172 : builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"}, true));
118 :
119 : // We need to wrap the pattern in Option::Some if we are doing partial
120 : // ordering
121 86 : if (ordering == Ordering::Partial)
122 : {
123 55 : auto pattern_items = std::unique_ptr<TupleStructItems> (
124 55 : new TupleStructItemsNoRest (vec (std::move (equal))));
125 :
126 55 : equal
127 110 : = std::make_unique<TupleStructPattern> (builder.path_in_expression (
128 : LangItem::Kind::OPTION_SOME),
129 55 : std::move (pattern_items));
130 55 : }
131 :
132 86 : return equal;
133 : }
134 :
135 : std::pair<MatchArm, MatchArm>
136 86 : DeriveOrd::make_cmp_arms ()
137 : {
138 : // All comparison results other than Ordering::Equal
139 86 : auto non_equal = builder.identifier_pattern (DeriveOrd::not_equal);
140 86 : auto equal = make_equal ();
141 :
142 86 : return {builder.match_arm (std::move (equal)),
143 86 : builder.match_arm (std::move (non_equal))};
144 86 : }
145 :
146 : std::unique_ptr<Expr>
147 83 : DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
148 : {
149 83 : if (members.empty ())
150 : {
151 0 : std::unique_ptr<Expr> value = ptrify (
152 0 : builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"},
153 0 : true));
154 :
155 0 : if (ordering == Ordering::Partial)
156 0 : value = builder.call (ptrify (builder.path_in_expression (
157 : LangItem::Kind::OPTION_SOME)),
158 0 : std::move (value));
159 :
160 : return value;
161 : }
162 :
163 83 : std::unique_ptr<Expr> final_expr = nullptr;
164 :
165 252 : for (auto it = members.rbegin (); it != members.rend (); it++)
166 : {
167 169 : auto &member = *it;
168 :
169 169 : auto call = cmp_call (std::move (member.self_expr),
170 169 : std::move (member.other_expr));
171 :
172 : // For the last member (so the first iterator), we just create a call
173 : // expression
174 169 : if (it == members.rbegin ())
175 : {
176 83 : final_expr = std::move (call);
177 83 : continue;
178 : }
179 :
180 : // If we aren't dealing with the last member, then we need to wrap all of
181 : // that in a big match expression and keep going
182 86 : auto match_arms = make_cmp_arms ();
183 :
184 86 : auto match_cases
185 : = {builder.match_case (std::move (match_arms.first),
186 : std::move (final_expr)),
187 : builder.match_case (std::move (match_arms.second),
188 430 : builder.identifier (DeriveOrd::not_equal))};
189 :
190 86 : final_expr = builder.match (std::move (call), std::move (match_cases));
191 427 : }
192 :
193 83 : return final_expr;
194 83 : }
195 :
196 : // we need to do a recursive match expression for all of the fields used in a
197 : // struct so for something like struct Foo { a: i32, b: i32, c: i32 } we must
198 : // first compare each `a` field, then `b`, then `c`, like this:
199 : //
200 : // match cmp_fn(self.<field>, other.<field>) {
201 : // Ordering::Equal => <recurse>,
202 : // cmp => cmp,
203 : // }
204 : //
205 : // and the recurse will be the exact same expression, on the next field. so that
206 : // our result looks like this:
207 : //
208 : // match cmp_fn(self.a, other.a) {
209 : // Ordering::Equal => match cmp_fn(self.b, other.b) {
210 : // Ordering::Equal =>cmp_fn(self.c, other.c),
211 : // cmp => cmp,
212 : // }
213 : // cmp => cmp,
214 : // }
215 : //
216 : // the last field comparison needs not to be a match but just the function call.
217 : // this is going to be annoying lol
218 : void
219 53 : DeriveOrd::visit_struct (StructStruct &item)
220 : {
221 53 : auto fields = SelfOther::fields (builder, item.get_fields ());
222 :
223 53 : auto match_expr = recursive_match (std::move (fields));
224 :
225 106 : expanded = cmp_impl (builder.block (std::move (match_expr)),
226 106 : item.get_identifier (), item.get_generic_params ());
227 53 : }
228 :
229 : // same as structs, but for each field index instead of each field name -
230 : // straightforward once we have `visit_struct` working
231 : void
232 0 : DeriveOrd::visit_tuple (TupleStruct &item)
233 : {
234 0 : auto fields = SelfOther::indexes (builder, item.get_fields ());
235 :
236 0 : auto match_expr = recursive_match (std::move (fields));
237 :
238 0 : expanded = cmp_impl (builder.block (std::move (match_expr)),
239 0 : item.get_identifier (), item.get_generic_params ());
240 0 : }
241 :
242 : // for enums, we need to generate a match for each of the enum's variant that
243 : // contains data and then do the same thing as visit_struct or visit_enum. if
244 : // the two aren't the same variant, then compare the two discriminant values for
245 : // all the dataless enum variants and in the general case.
246 : //
247 : // so for enum Foo { A(i32, i32), B, C } we need to do the following
248 : //
249 : // match (self, other) {
250 : // (A(self_0, self_1), A(other_0, other_1)) => {
251 : // match cmp_fn(self_0, other_0) {
252 : // Ordering::Equal => cmp_fn(self_1, other_1),
253 : // cmp => cmp,
254 : // },
255 : // _ => cmp_fn(discr_value(self), discr_value(other))
256 : // }
257 : void
258 22 : DeriveOrd::visit_enum (Enum &item)
259 : {
260 : // NOTE: We can factor this even further with DerivePartialEq, but this is
261 : // getting out of scope for this PR surely
262 :
263 22 : auto cases = std::vector<MatchCase> ();
264 44 : auto type_name = item.get_identifier ().as_string ();
265 :
266 22 : auto let_sd = builder.discriminant_value (DeriveOrd::self_discr, "self");
267 22 : auto let_od = builder.discriminant_value (DeriveOrd::other_discr, "other");
268 :
269 44 : auto discr_cmp = cmp_call (builder.identifier (DeriveOrd::self_discr),
270 66 : builder.identifier (DeriveOrd::other_discr));
271 :
272 52 : auto recursive_match_fn = [this] (std::vector<SelfOther> &&fields) {
273 30 : return recursive_match (std::move (fields));
274 22 : };
275 :
276 74 : for (auto &variant : item.get_variants ())
277 : {
278 52 : auto enum_builder
279 104 : = EnumMatchBuilder (type_name, variant->get_identifier ().as_string (),
280 52 : recursive_match_fn, builder);
281 :
282 52 : switch (variant->get_enum_item_kind ())
283 : {
284 8 : case EnumItem::Kind::Struct:
285 8 : cases.emplace_back (enum_builder.strukt (*variant));
286 8 : break;
287 22 : case EnumItem::Kind::Tuple:
288 22 : cases.emplace_back (enum_builder.tuple (*variant));
289 22 : break;
290 : case EnumItem::Kind::Identifier:
291 : case EnumItem::Kind::Discriminant:
292 : // We don't need to do anything for these, as they are handled by the
293 : // discriminant value comparison
294 : break;
295 : }
296 52 : }
297 :
298 : // Add the last case which compares the discriminant values in case `self` and
299 : // `other` are actually different variants of the enum
300 22 : cases.emplace_back (
301 44 : builder.match_case (builder.wildcard (), std::move (discr_cmp)));
302 :
303 22 : auto match
304 44 : = builder.match (builder.tuple (vec (builder.identifier ("self"),
305 44 : builder.identifier ("other"))),
306 22 : std::move (cases));
307 :
308 22 : expanded
309 44 : = cmp_impl (builder.block (vec (std::move (let_sd), std::move (let_od)),
310 : std::move (match)),
311 110 : type_name, item.get_generic_params ());
312 22 : }
313 :
314 : void
315 0 : DeriveOrd::visit_union (Union &item)
316 : {
317 0 : auto trait_name = trait (ordering);
318 :
319 0 : rust_error_at (item.get_locus (), "derive(%s) cannot be used on unions",
320 : trait_name.c_str ());
321 0 : }
322 :
323 : } // namespace AST
324 : } // namespace Rust
|