From 025b7904113f1c38a19f3e38b0335d391b8890d7 Mon Sep 17 00:00:00 2001 From: agra Date: Sat, 14 Feb 2026 13:17:22 +0200 Subject: [PATCH] enum, union --- examples/03-structs.sx | 2 +- examples/10-generic-struct.sx | 2 +- examples/16-union.sx | 12 +- examples/24-list.sx | 2 +- examples/28-sdl-graphics.sx | 3 +- examples/30-union.sx | 28 +++ examples/modules/std.sx | 9 +- specs.md | 126 ++++++---- src/ast.zig | 23 +- src/codegen.zig | 425 ++++++++++++++++++++++++---------- src/comptime.zig | 3 +- src/lsp/server.zig | 27 ++- src/parser.zig | 89 ++++--- src/sema.zig | 31 +-- 14 files changed, 537 insertions(+), 245 deletions(-) create mode 100644 examples/30-union.sx diff --git a/examples/03-structs.sx b/examples/03-structs.sx index c6b9e75..c0aa83a 100644 --- a/examples/03-structs.sx +++ b/examples/03-structs.sx @@ -4,7 +4,7 @@ Vec4 :: struct { } Complex :: struct { - foo : union { + foo : enum { S: s32; B: struct { val: string; diff --git a/examples/10-generic-struct.sx b/examples/10-generic-struct.sx index fc22a80..a51a6cf 100644 --- a/examples/10-generic-struct.sx +++ b/examples/10-generic-struct.sx @@ -71,7 +71,7 @@ main :: () { // inline generic type Sx :: (user: $T) -> Type { - return union { + return enum { counter: s32; user: T; }; diff --git a/examples/16-union.sx b/examples/16-union.sx index bb14a86..855bd6c 100644 --- a/examples/16-union.sx +++ b/examples/16-union.sx @@ -1,8 +1,8 @@ #import "modules/std.sx"; -Shape :: union { +Shape :: enum { circle: f32; - rect: s32; + rect: struct { w, h: f32;}; none; } @@ -20,19 +20,19 @@ main :: () { print("none: {}\n", s); // Reassign with payload - s = .rect(42); + s = .rect(.{4, 2}); print("rect: {}\n", s); // Explicit prefix construction sh :Shape = Shape.circle(2.71); print("sh: {}\n", sh); - // Field access on second union variable - sh2 :Shape = .rect(10); + // Field access on second variable + sh2 :Shape = .rect(.{2,4}); val := sh2.rect; print("rect val: {}\n", val); - // Match on union + // Match on enum if sh2 
== { case .circle: print("matched circle\n"); case .rect: print("matched rect\n"); diff --git a/examples/24-list.sx b/examples/24-list.sx index 350f4be..cbaae2b 100644 --- a/examples/24-list.sx +++ b/examples/24-list.sx @@ -1,7 +1,7 @@ #import "modules/std.sx"; main :: () { - list := List(s32).{}; + list : List(s32) = .{}; append(list, 1); diff --git a/examples/28-sdl-graphics.sx b/examples/28-sdl-graphics.sx index fa0c34b..e1cc2e8 100644 --- a/examples/28-sdl-graphics.sx +++ b/examples/28-sdl-graphics.sx @@ -153,9 +153,10 @@ main :: () { #version 330 core layout (location = 0) in vec3 aPos; layout (location = 1) in vec3 aNormal; -uniform mat4 uMVP; +uniform mat4 uMVP; out vec3 vNormal; out vec3 vPos; + void main() { gl_Position = uMVP * vec4(aPos, 1.0); vNormal = aNormal; diff --git a/examples/30-union.sx b/examples/30-union.sx new file mode 100644 index 0000000..6e510a8 --- /dev/null +++ b/examples/30-union.sx @@ -0,0 +1,28 @@ +#import "modules/std.sx"; + +Overlay :: union { + f: f32; + i: s32; +} + +Vec2 :: union { + data: [2]f32; + struct { x, y: f32; }; +} + +main :: () { + // Basic union: type punning + o :Overlay = ---; + o.f = 3.14; + print("f={}\n", o.f); + print("i={}\n", o.i); + + // Union with anonymous struct: member promotion + v :Vec2 = ---; + v.x = 1.0; + v.y = 2.0; + print("x={}\n", v.x); + print("y={}\n", v.y); + print("data[0]={}\n", v.data[0]); + print("data[1]={}\n", v.data[1]); +} diff --git a/examples/modules/std.sx b/examples/modules/std.sx index 7b64f2d..ba53b3a 100644 --- a/examples/modules/std.sx +++ b/examples/modules/std.sx @@ -161,10 +161,6 @@ struct_to_string :: (s: $T) -> string { concat(result, "}"); } -enum_to_string :: (e: $T) -> string { - concat(".", field_name(T, cast(s64) e)); -} - vector_to_string :: (v: $T) -> string { result := "["; i := 0; @@ -205,7 +201,7 @@ pointer_to_string :: (p: $T) -> string { } } -union_to_string :: (u: $T) -> string { +enum_to_string :: (u: $T) -> string { tag := cast(s64) u; result := concat(".", 
field_name(T, tag)); payload := field_value(u, tag); @@ -230,7 +226,6 @@ any_to_string :: (val: Any) -> string { case vector: result = vector_to_string(cast(type) val); case array: result = array_to_string(cast(type) val); case slice: result = slice_to_string(cast(type) val); - case union: result = union_to_string(cast(type) val); case pointer: result = pointer_to_string(cast(type) val); case type: { s : string = xx val; result = s; } } @@ -314,7 +309,7 @@ List :: struct ($T: Type) { cap: s64 = 0; } -append :: (list: *List($T), item: T) { +append ::(list: *List($T), item: T) { if list.len >= list.cap { new_cap := if list.cap == 0 then 4 else list.cap * 2; new_items : [*]T = xx malloc(new_cap * size_of(T)); diff --git a/specs.md b/specs.md index d9efe47..2f86b07 100644 --- a/specs.md +++ b/specs.md @@ -48,6 +48,8 @@ GLSL; ### Keywords `if`, `else`, `then`, `while`, `break`, `continue`, `true`, `false`, `enum`, `struct`, `union`, `case`, `return`, `defer`, `xx`, `and`, `or` +> Note: `enum` is used for both payload-less and payload-bearing sum types (tagged unions). `union` is reserved for C-style untagged unions (memory overlays). + ### Operators | Operator | Meaning | @@ -102,15 +104,88 @@ GLSL; - `Type` — compile-time type value. At runtime, represented as an `i64` type tag (same tag space as `Any`). ### Enum Types -User-defined sum types with named variants. +User-defined sum types with named variants. Variants may optionally carry typed data (tagged unions). Internally, payload-less enums are represented as `i64` (variant index). Enums with payloads are represented as `{ i64, [max_payload_size x i8] }` (tag + data). 
+ +#### Declaration ```sx -Foo :: enum { - variant1; - variant2; +// Payload-less enum +Color :: enum { + red; + green; + blue; +} + +// Enum with payloads (tagged union) +Shape :: enum { + circle: f32; // typed variant + rect: s32; // typed variant + none; // void variant } ``` Variants are referenced with dot-prefix syntax: `.variant1` +#### Construction +```sx +c := Color.red; // payload-less +s :Shape = .circle(3.14); // inferred from context +s = .none; // void variant +s = Shape.rect(42); // explicit prefix +``` + +#### Payload Access +```sx +r := s.circle; // load payload as f32 (undefined behavior if wrong variant active) +``` + +#### Pattern Matching +```sx +if s == { + case .circle: print("circle\n"); + case .rect: print("rect\n"); + case .none: print("none\n"); +} +``` + +#### Enum Interpolation +Payload-less enums print as `.variant`. Enums with payloads print as `.variant(value)`: +```sx +print("{}", s); // .circle(3.140000) +``` + +### Union Types (Untagged) +C-style untagged unions for zero-cost memory overlays (type punning). All fields share the same memory — no tag, no runtime overhead. The LLVM representation is `[max_field_size x i8]`. + +#### Declaration +```sx +Overlay :: union { + f: f32; + i: s32; +} +``` +All fields must have types (unlike enums, which may have void variants). + +#### Anonymous Struct Fields (Member Promotion) +Anonymous `struct` fields inside a union have their members promoted to the union namespace: +```sx +Vec2 :: union { + data: [2]f32; + struct { x, y: f32; }; +} +``` +Access promoted members directly: `v.x`, `v.y` — these are zero-cost GEPs into the same underlying memory as `v.data[0]`, `v.data[1]`. + +#### Initialization +Unions must be initialized with `---` (undefined) and then assigned per-field: +```sx +o :Overlay = ---; +o.f = 3.14; +print("{}\n", o.i); // reinterpret bits as s32 +``` + +#### Restrictions +- Pattern matching (`if x == { case ... }`) is not supported on unions. 
+- Unions cannot be printed directly via `print("{}", union_val)` — access individual fields instead. + ### Struct Types User-defined product types with named fields. ```sx @@ -159,45 +234,6 @@ Struct values in string interpolation print as `TypeName{field:value, ...}`: print("{}", v1); // Vec4{x:1.0, y:2.0, z:3.0, w:0.0} ``` -### Union Types (Tagged Unions) -Sum types where each variant can carry typed data or be void. Internally represented as `{ i64, [max_payload_size x i8] }`. - -#### Declaration -```sx -Shape :: union { - circle: f32; // typed variant - rect: s32; // typed variant - none; // void variant -} -``` - -#### Construction -```sx -s :Shape = .circle(3.14); // inferred from context -s = .none; // void variant (enum literal syntax) -s = Shape.rect(42); // explicit prefix -``` - -#### Payload Access -```sx -r := s.circle; // load payload as f32 (undefined behavior if wrong variant active) -``` - -#### Pattern Matching -```sx -if s == { - case .circle: print("circle\n"); - case .rect: print("rect\n"); - case .none: print("none\n"); -} -``` - -#### Union Interpolation -Union values in string interpolation print as ``: -```sx -print("{}", s); // -``` - ### Array Types Fixed-size arrays with element type and length. ```sx @@ -623,7 +659,9 @@ if type == { case enum: result = enum_to_string(cast(type) val); } ``` -Available categories: `int`, `float`, `bool`, `string`, `struct`, `enum`, `union`. +Available categories: `int`, `float`, `bool`, `string`, `struct`, `enum`, `vector`, `array`, `slice`, `pointer`, `type`. + +> Note: `case enum:` matches both payload-less enums and tagged enums (enums with payloads). C-style untagged unions are not registered with the Any type system and cannot be matched by category. Inside a category arm, `cast(type) val` performs **runtime generic dispatch**: the compiler generates a switch over all types in the category, monomorphizing the callee for each concrete type. 
diff --git a/src/ast.zig b/src/ast.zig index e05c5b8..e045423 100644 --- a/src/ast.zig +++ b/src/ast.zig @@ -34,7 +34,6 @@ pub const Node = struct { struct_decl: StructDecl, struct_literal: StructLiteral, union_decl: UnionDecl, - union_literal: UnionLiteral, lambda: Lambda, type_expr: TypeExpr, param: Param, @@ -127,6 +126,7 @@ pub const Identifier = struct { pub const EnumLiteral = struct { name: []const u8, // without the leading dot + payload: ?*Node = null, // non-null for enum variants with payloads (tagged unions) }; pub const BinaryOp = struct { @@ -225,7 +225,14 @@ pub const Assignment = struct { pub const EnumDecl = struct { name: []const u8, - variants: []const []const u8, + variant_names: []const []const u8, + variant_types: []const ?*Node = &.{}, // null entries = no payload; empty = payload-less enum +}; + +pub const UnionDecl = struct { + name: []const u8, + field_names: []const []const u8, + field_types: []const *Node, }; pub const StructTypeParam = struct { @@ -341,18 +348,6 @@ pub const SpreadExpr = struct { operand: *Node, }; -pub const UnionDecl = struct { - name: []const u8, - variant_names: []const []const u8, - variant_types: []const ?*Node, // null for void variants -}; - -pub const UnionLiteral = struct { - union_name: ?[]const u8, // null for anonymous `.variant(expr)` - variant_name: []const u8, - payload: ?*Node, // null for void variants -}; - pub const NamespaceDecl = struct { name: []const u8, decls: []const *Node, diff --git a/src/codegen.zig b/src/codegen.zig index 194d479..6742687 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -108,7 +108,9 @@ pub const CodeGen = struct { type_aliases: std.StringHashMap([]const u8), // Struct type registry: maps struct name to field info + LLVM type struct_types: std.StringHashMap(StructInfo), - // Union type registry: maps union name to variant info + LLVM type + // Tagged enum registry: maps name to variant info + LLVM type (enums with payloads) + tagged_enum_types: 
std.StringHashMap(TaggedEnumInfo), + // Union registry: maps name to field info + LLVM type (untagged, C-style) union_types: std.StringHashMap(UnionInfo), // Built-in functions (printf, etc.) builtins: ?Builtins, @@ -193,7 +195,6 @@ pub const CodeGen = struct { const TypeCategory = enum { struct_cat, enum_cat, - union_cat, vector_cat, array_cat, slice_cat, @@ -237,13 +238,27 @@ pub const CodeGen = struct { template_name: ?[]const u8 = null, // original template name (e.g. "List") }; - const UnionInfo = struct { + const TaggedEnumInfo = struct { variant_names: []const []const u8, variant_types: []const Type, // void_type for void variants llvm_type: c.LLVMTypeRef, // { i32, [max_payload_size x i8] } max_payload_size: u64, }; + const PromotedField = struct { + struct_name: []const u8, // the anonymous struct type name + field_index: usize, // field index within that struct + field_type: Type, // type of the promoted field + }; + + const UnionInfo = struct { + field_names: []const []const u8, + field_types: []const Type, + llvm_type: c.LLVMTypeRef, // [max_size x i8] + total_size: u64, + promoted_fields: std.StringHashMap(PromotedField), + }; + // Scope stack entry: records what a name mapped to before being shadowed const ScopeEntry = struct { name: []const u8, @@ -277,6 +292,7 @@ pub const CodeGen = struct { .enum_types = std.StringHashMap([]const []const u8).init(allocator), .type_aliases = std.StringHashMap([]const u8).init(allocator), .struct_types = std.StringHashMap(StructInfo).init(allocator), + .tagged_enum_types = std.StringHashMap(TaggedEnumInfo).init(allocator), .union_types = std.StringHashMap(UnionInfo).init(allocator), .builtins = null, .current_function = null, @@ -308,6 +324,7 @@ pub const CodeGen = struct { self.enum_types.deinit(); self.type_aliases.deinit(); self.struct_types.deinit(); + self.tagged_enum_types.deinit(); self.union_types.deinit(); self.comptime_globals.deinit(); self.generic_templates.deinit(); @@ -363,7 +380,7 @@ pub const CodeGen 
= struct { .string_type, .slice_type => self.getStringStructType(), // slices use same {ptr, i32} layout .enum_type => c.LLVMInt64TypeInContext(self.context), .struct_type => |name| if (self.struct_types.get(name)) |info| info.llvm_type else unreachable, - .union_type => |name| if (self.union_types.get(name)) |info| info.llvm_type else unreachable, + .union_type => |name| if (self.tagged_enum_types.get(name)) |info| info.llvm_type else if (self.union_types.get(name)) |info| info.llvm_type else unreachable, .array_type => |info| { const elem_ty = Type.fromName(info.element_name) orelse unreachable; return c.LLVMArrayType2(self.typeToLLVM(elem_ty), info.length); @@ -413,7 +430,7 @@ pub const CodeGen = struct { const category: TypeCategory = switch (sx_type) { .struct_type => .struct_cat, .enum_type => .enum_cat, - .union_type => .union_cat, + .union_type => .enum_cat, .vector_type => .vector_cat, .array_type => .array_cat, .slice_type => .slice_cat, @@ -541,7 +558,7 @@ pub const CodeGen = struct { }, .union_type => |uname| blk: { // Union — store to alloca, pass pointer as i64 - const info = self.union_types.get(uname) orelse + const info = self.tagged_enum_types.get(uname) orelse return c.LLVMGetUndef(any_ty); const alloca = self.buildEntryBlockAlloca(info.llvm_type, "any_union_tmp"); _ = c.LLVMBuildStore(self.builder, val, alloca); @@ -685,8 +702,14 @@ pub const CodeGen = struct { try self.foreign_libraries.append(self.allocator, ld.lib_name); }, .enum_decl => |ed| { - try self.enum_types.put(ed.name, ed.variants); - _ = try self.getAnyTypeId(ed.name, .{ .enum_type = ed.name }); + if (ed.variant_types.len > 0) { + // Tagged enum with payloads + try self.registerTaggedEnum(ed); + } else { + // Payload-less enum + try self.enum_types.put(ed.name, ed.variant_names); + _ = try self.getAnyTypeId(ed.name, .{ .enum_type = ed.name }); + } }, .struct_decl => |sd| try self.registerStructType(sd), .union_decl => |ud| try self.registerUnionType(ud), @@ -768,7 +791,7 @@ pub 
const CodeGen = struct { } } { - var it = self.union_types.iterator(); + var it = self.tagged_enum_types.iterator(); while (it.next()) |entry| { _ = try self.getAnyTypeId(entry.key_ptr.*, .{ .union_type = entry.key_ptr.* }); } @@ -1015,13 +1038,15 @@ pub const CodeGen = struct { if (self.type_aliases.get(name)) |target| { if (Type.fromName(target)) |t| return t; if (self.struct_types.contains(target)) return .{ .struct_type = target }; + if (self.tagged_enum_types.contains(target)) return .{ .union_type = target }; if (self.union_types.contains(target)) return .{ .union_type = target }; } // Check enum types if (self.enum_types.contains(name)) return .{ .enum_type = name }; // Check struct types if (self.struct_types.contains(name)) return .{ .struct_type = name }; - // Check union types + // Check union types (tagged enums and C-style unions) + if (self.tagged_enum_types.contains(name)) return .{ .union_type = name }; if (self.union_types.contains(name)) return .{ .union_type = name }; } // Safety net: inline declarations that should have been hoisted @@ -1029,12 +1054,9 @@ pub const CodeGen = struct { const sn = tn.data.struct_decl.name; if (self.struct_types.contains(sn)) return .{ .struct_type = sn }; } - if (tn.data == .union_decl) { - const un = tn.data.union_decl.name; - if (self.union_types.contains(un)) return .{ .union_type = un }; - } if (tn.data == .enum_decl) { const en = tn.data.enum_decl.name; + if (self.tagged_enum_types.contains(en)) return .{ .union_type = en }; if (self.enum_types.contains(en)) return .{ .enum_type = en }; } return .void_type; @@ -1178,13 +1200,13 @@ pub const CodeGen = struct { // Try union if (self.findUnionInBody(fd.body)) |union_decl| { - if (self.union_types.contains(mangled_name)) { + if (self.tagged_enum_types.contains(mangled_name)) { return .{ .union_type = mangled_name }; } - return self.registerInstantiatedUnion(mangled_name, union_decl); + return self.registerInstantiatedTaggedEnum(mangled_name, union_decl); } - 
return self.emitErrorFmt("type function '{s}' does not return a struct or union", .{template_name}); + return self.emitErrorFmt("type function '{s}' does not return a struct or enum", .{template_name}); } fn registerInstantiatedStruct(self: *CodeGen, mangled_name: []const u8, alias_name: []const u8, struct_decl: ast.StructDecl) !Type { @@ -1205,10 +1227,10 @@ pub const CodeGen = struct { return .{ .struct_type = mangled_name }; } - fn registerInstantiatedUnion(self: *CodeGen, mangled_name: []const u8, union_decl: ast.UnionDecl) !Type { + fn registerInstantiatedTaggedEnum(self: *CodeGen, mangled_name: []const u8, union_decl: ast.EnumDecl) !Type { const build = try self.buildUnionFields(mangled_name, union_decl.variant_types); - try self.union_types.put(mangled_name, .{ + try self.tagged_enum_types.put(mangled_name, .{ .variant_names = union_decl.variant_names, .variant_types = build.variant_sx_types, .llvm_type = build.llvm_type, @@ -1248,8 +1270,27 @@ pub const CodeGen = struct { return findDeclInBody(ast.StructDecl, .struct_decl, body); } - fn findUnionInBody(_: *CodeGen, body: *Node) ?ast.UnionDecl { - return findDeclInBody(ast.UnionDecl, .union_decl, body); + fn findUnionInBody(_: *CodeGen, body: *Node) ?ast.EnumDecl { + // Tagged enums with payloads are now stored as .enum_decl with variant_types populated + const isTaggedEnum = struct { + fn check(node: *Node) ?ast.EnumDecl { + if (node.data == .enum_decl and node.data.enum_decl.variant_types.len > 0) { + return node.data.enum_decl; + } + return null; + } + }; + if (isTaggedEnum.check(body)) |ed| return ed; + const stmts = if (body.data == .block) body.data.block.stmts else return null; + for (stmts) |stmt| { + if (stmt.data == .return_stmt) { + if (stmt.data.return_stmt.value) |val| { + if (isTaggedEnum.check(val)) |ed| return ed; + } + } + if (isTaggedEnum.check(stmt)) |ed| return ed; + } + return null; } fn buildFnType(self: *CodeGen, params: []const ast.Param, return_type: ?*Node, name: []const u8) 
!c.LLVMTypeRef { @@ -1555,8 +1596,15 @@ pub const CodeGen = struct { } }, .enum_decl => |ed| { - const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, ed.name }); - try self.enum_types.put(qualified, ed.variants); + if (ed.variant_types.len > 0) { + // Tagged enum with payloads + try self.registerTaggedEnum(ed); + const qualified_u = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, ed.name }); + try self.type_aliases.put(qualified_u, ed.name); + } else { + const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, ed.name }); + try self.enum_types.put(qualified, ed.variant_names); + } }, .struct_decl => |sd| { try self.registerStructType(sd); @@ -1663,7 +1711,7 @@ pub const CodeGen = struct { if (name_ptr != null) { const name = std.mem.span(name_ptr); if (self.struct_types.contains(name)) return .{ .struct_type = name }; - if (self.union_types.contains(name)) return .{ .union_type = name }; + if (self.tagged_enum_types.contains(name)) return .{ .union_type = name }; } } // Check for array types @@ -1970,8 +2018,8 @@ pub const CodeGen = struct { try self.registerStructType(sd); return null; }, - .union_decl => |ud| { - try self.registerUnionType(ud); + .union_decl => { + // C-style union — registration handled in type pass return null; }, .assignment => |asgn| { @@ -2111,11 +2159,31 @@ pub const CodeGen = struct { return null; } - // Union-typed variable + // Union-typed variable (tagged enum or C-style union) if (sx_ty.isUnion()) { const uname = self.type_aliases.get(sx_ty.union_type) orelse sx_ty.union_type; sx_ty = .{ .union_type = uname }; - const info = self.union_types.get(uname) orelse return self.emitErrorFmt("unknown union type '{s}'", .{uname}); + + // C-style (untagged) union + if (self.union_types.get(uname)) |info| { + const name_z = try self.allocator.dupeZ(u8, vd.name); + const alloca = self.buildEntryBlockAlloca(info.llvm_type, name_z.ptr); + + if (vd.value == null) { + _ = 
c.LLVMBuildStore(self.builder, c.LLVMConstNull(info.llvm_type), alloca); + } else if (vd.value.?.data == .undef_literal) { + _ = c.LLVMBuildStore(self.builder, c.LLVMGetUndef(info.llvm_type), alloca); + } else { + return self.emitErrorFmt("union '{s}' must be initialized with '---' or field assignment", .{uname}); + } + + try self.saveShadowed(vd.name); + try self.named_values.put(vd.name, .{ .ptr = alloca, .ty = sx_ty }); + return null; + } + + // Tagged enum + const info = self.tagged_enum_types.get(uname) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{uname}); const name_z = try self.allocator.dupeZ(u8, vd.name); const alloca = self.buildEntryBlockAlloca(info.llvm_type, name_z.ptr); @@ -2124,19 +2192,9 @@ pub const CodeGen = struct { _ = c.LLVMBuildStore(self.builder, c.LLVMConstNull(info.llvm_type), alloca); } else if (vd.value.?.data == .undef_literal) { _ = c.LLVMBuildStore(self.builder, c.LLVMGetUndef(info.llvm_type), alloca); - } else if (vd.value.?.data == .union_literal) { - const lit_alloca = try self.genUnionLiteral(vd.value.?.data.union_literal, uname); - try self.saveShadowed(vd.name); - try self.named_values.put(vd.name, .{ .ptr = lit_alloca, .ty = sx_ty }); - return null; } else if (vd.value.?.data == .enum_literal) { - // Void variant: .none assigned to union variable - const ul = ast.UnionLiteral{ - .union_name = uname, - .variant_name = vd.value.?.data.enum_literal.name, - .payload = null, - }; - const lit_alloca = try self.genUnionLiteral(ul, uname); + const el = vd.value.?.data.enum_literal; + const lit_alloca = try self.genTaggedEnumLiteral(el, uname); try self.saveShadowed(vd.name); try self.named_values.put(vd.name, .{ .ptr = lit_alloca, .ty = sx_ty }); return null; @@ -2330,7 +2388,7 @@ pub const CodeGen = struct { sx_ty = self.inferType(cd.value); } - // Union-typed constant: delegate to genExprAsType which handles enum_literal + union_literal + // Enum-typed constant: delegate to genExprAsType which handles enum_literal if 
(sx_ty.isUnion()) { const val = try self.genExprAsType(cd.value, sx_ty); try self.saveShadowed(cd.name); @@ -2452,14 +2510,16 @@ pub const CodeGen = struct { return null; } - // Union reassignment: s = .circle(3.14) or s = .none + // Tagged enum reassignment: s = .circle(3.14) or s = .none if (entry.ty.isUnion() and asgn.op == .assign) { - const new_alloca = try self.genExprAsType(asgn.value, entry.ty); - // Copy from new alloca to existing alloca - const info = self.union_types.get(entry.ty.union_type).?; - const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, new_alloca, "union_load"); - _ = c.LLVMBuildStore(self.builder, loaded, entry.ptr); - return null; + if (self.tagged_enum_types.get(entry.ty.union_type)) |info| { + const new_alloca = try self.genExprAsType(asgn.value, entry.ty); + // Copy from new alloca to existing alloca + const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, new_alloca, "union_load"); + _ = c.LLVMBuildStore(self.builder, loaded, entry.ptr); + return null; + } + // C-style union: full assignment not supported, use field assignment } const new_val = try self.genExpr(asgn.value); @@ -2507,6 +2567,42 @@ pub const CodeGen = struct { return self.emitError("field assignment through pointer requires a struct pointee"); } + // C-style union field assignment + if (entry.ty.isUnion()) { + const uname = entry.ty.union_type; + if (self.union_types.get(uname)) |info| { + if (self.findUnionFieldIndex(info, fa.field)) |fidx| { + const field_ty = info.field_types[fidx]; + const rhs = try self.genExprAsType(asgn.value, field_ty); + if (asgn.op == .assign) { + _ = c.LLVMBuildStore(self.builder, rhs, entry.ptr); + } else { + const field_llvm_ty = self.typeToLLVM(field_ty); + const cur = c.LLVMBuildLoad2(self.builder, field_llvm_ty, entry.ptr, "ucur"); + _ = c.LLVMBuildStore(self.builder, self.genCompoundOp(asgn.op, cur, rhs, field_ty), entry.ptr); + } + return null; + } + // Check promoted fields from anonymous structs + if 
(info.promoted_fields.get(fa.field)) |pf| { + const sinfo = self.struct_types.get(pf.struct_name) orelse + return self.emitErrorFmt("unknown promoted struct '{s}'", .{pf.struct_name}); + const gep = c.LLVMBuildStructGEP2(self.builder, sinfo.llvm_type, entry.ptr, @intCast(pf.field_index), "promoted_ptr"); + const rhs = try self.genExprAsType(asgn.value, pf.field_type); + if (asgn.op == .assign) { + _ = c.LLVMBuildStore(self.builder, rhs, gep); + } else { + const field_llvm_ty = self.typeToLLVM(pf.field_type); + const cur = c.LLVMBuildLoad2(self.builder, field_llvm_ty, gep, "ucur"); + _ = c.LLVMBuildStore(self.builder, self.genCompoundOp(asgn.op, cur, rhs, pf.field_type), gep); + } + return null; + } + return self.emitErrorFmt("no field '{s}' in union '{s}'", .{ fa.field, uname }); + } + return self.emitErrorFmt("field assignment not supported on tagged enum '{s}'", .{uname}); + } + if (!entry.ty.isStruct()) return self.emitErrorFmt("field access on non-struct variable '{s}'", .{obj_name}); const sname = entry.ty.struct_type; @@ -2749,9 +2845,6 @@ pub const CodeGen = struct { const ctx_name: ?[]const u8 = if (self.current_return_type.isStruct()) self.current_return_type.struct_type else null; return self.genStructLiteral(sl, ctx_name); }, - .union_literal => |ul| { - return self.genUnionLiteral(ul, null); - }, .array_literal => |al| { // Typed array/vector/slice literal: Type.[elems] if (al.type_expr) |te| { @@ -2905,6 +2998,19 @@ pub const CodeGen = struct { const idx = self.findFieldIndex(info, fa.field) orelse return self.emitErrorFmt("no field '{s}' in struct '{s}'", .{ fa.field, sname }); return c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, @intCast(idx), "addr_field"); } + // &u.field where u is a C-style union — all fields at offset 0 + if (entry.ty.isUnion()) { + if (self.union_types.get(entry.ty.union_type)) |info| { + if (self.findUnionFieldIndex(info, fa.field) != null) { + return entry.ptr; + } + if (info.promoted_fields.get(fa.field)) 
|pf| { + const sinfo = self.struct_types.get(pf.struct_name) orelse + return self.emitErrorFmt("unknown promoted struct '{s}'", .{pf.struct_name}); + return c.LLVMBuildStructGEP2(self.builder, sinfo.llvm_type, entry.ptr, @intCast(pf.field_index), "addr_promoted"); + } + } + } // &p.field where p is *Struct — auto-deref through pointer if (entry.ty.isPointer()) { const pointee_name = entry.ty.pointer_type.pointee_name; @@ -2995,14 +3101,21 @@ pub const CodeGen = struct { type_node.data = .{ .type_expr = .{ .name = synthetic_name } }; }, .union_decl => |inline_ud| { - var hoisted = inline_ud; - hoisted.name = synthetic_name; - try self.registerUnionType(hoisted); + var hoisted_ud = inline_ud; + hoisted_ud.name = synthetic_name; + try self.registerUnionType(hoisted_ud); type_node.data = .{ .type_expr = .{ .name = synthetic_name } }; }, .enum_decl => |inline_ed| { - try self.enum_types.put(synthetic_name, inline_ed.variants); - _ = try self.getAnyTypeId(synthetic_name, .{ .enum_type = synthetic_name }); + if (inline_ed.variant_types.len > 0) { + // Tagged enum with payloads + var hoisted = inline_ed; + hoisted.name = synthetic_name; + try self.registerTaggedEnum(hoisted); + } else { + try self.enum_types.put(synthetic_name, inline_ed.variant_names); + _ = try self.getAnyTypeId(synthetic_name, .{ .enum_type = synthetic_name }); + } type_node.data = .{ .type_expr = .{ .name = synthetic_name } }; }, else => {}, @@ -3047,7 +3160,7 @@ pub const CodeGen = struct { _ = try self.getAnyTypeId(sd.name, .{ .struct_type = sd.name }); } - fn registerUnionType(self: *CodeGen, ud: ast.UnionDecl) !void { + fn registerTaggedEnum(self: *CodeGen, ud: ast.EnumDecl) !void { // Pre-pass: hoist inline type declarations from variant types for (ud.variant_types, 0..) 
|vt_opt, i| { if (vt_opt) |vt| { @@ -3057,7 +3170,7 @@ pub const CodeGen = struct { const build = try self.buildUnionFields(ud.name, ud.variant_types); - try self.union_types.put(ud.name, .{ + try self.tagged_enum_types.put(ud.name, .{ .variant_names = ud.variant_names, .variant_types = build.variant_sx_types, .llvm_type = build.llvm_type, @@ -3066,22 +3179,78 @@ pub const CodeGen = struct { _ = try self.getAnyTypeId(ud.name, .{ .union_type = ud.name }); } - fn genUnionLiteral(self: *CodeGen, ul: ast.UnionLiteral, expected_union_name: ?[]const u8) !c.LLVMValueRef { - const uname = ul.union_name orelse expected_union_name orelse + fn registerUnionType(self: *CodeGen, ud: ast.UnionDecl) !void { + // Hoist inline type declarations from field types + for (ud.field_types, 0..) |ft, i| { + try self.hoistInlineTypeDecl(ud.name, ud.field_names[i], ft); + } + + // Compute max field size and resolve field types + const data_layout = c.LLVMGetModuleDataLayout(self.module); + var field_sx_types = std.ArrayList(Type).empty; + var max_size: u64 = 0; + for (ud.field_types) |ft| { + const resolved = self.resolveType(ft); + try field_sx_types.append(self.allocator, resolved); + const llvm_ty = self.typeToLLVM(resolved); + const size = c.LLVMABISizeOfType(data_layout, llvm_ty); + if (size > max_size) max_size = size; + } + + // LLVM type: byte array sized to the largest field + const byte_ty = c.LLVMInt8TypeInContext(self.context); + const llvm_type = c.LLVMArrayType(byte_ty, @intCast(max_size)); + + const resolved_field_types = try field_sx_types.toOwnedSlice(self.allocator); + + // Build promoted fields map from anonymous struct members + var promoted = std.StringHashMap(PromotedField).init(self.allocator); + for (ud.field_names, 0..) 
|_, i| { + const fty = resolved_field_types[i]; + if (fty.isStruct()) { + // Check if this is an anonymous struct (name contains __anon_) + const sname = fty.struct_type; + if (std.mem.indexOf(u8, sname, ".__anon_") != null) { + if (self.struct_types.get(sname)) |sinfo| { + for (sinfo.field_names, 0..) |sf_name, sf_idx| { + try promoted.put(sf_name, .{ + .struct_name = sname, + .field_index = sf_idx, + .field_type = sinfo.field_types[sf_idx], + }); + } + } + } + } + } + + try self.union_types.put(ud.name, .{ + .field_names = ud.field_names, + .field_types = resolved_field_types, + .llvm_type = llvm_type, + .total_size = max_size, + .promoted_fields = promoted, + }); + // Note: C-style unions are not registered with the Any type system. + // They can't be meaningfully printed as a whole — access individual fields instead. + } + + fn genTaggedEnumLiteral(self: *CodeGen, el: ast.EnumLiteral, expected_union_name: ?[]const u8) !c.LLVMValueRef { + const uname = expected_union_name orelse (if (self.current_return_type.isUnion()) self.current_return_type.union_type else null) orelse - return self.emitError("cannot infer union type for literal"); + return self.emitError("cannot infer enum type for literal"); const resolved_name = self.type_aliases.get(uname) orelse uname; - const info = self.union_types.get(resolved_name) orelse return self.emitErrorFmt("unknown union type '{s}'", .{resolved_name}); + const info = self.tagged_enum_types.get(resolved_name) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved_name}); // Find variant index var variant_idx: ?u32 = null; for (info.variant_names, 0..) 
|vn, i| { - if (std.mem.eql(u8, vn, ul.variant_name)) { + if (std.mem.eql(u8, vn, el.name)) { variant_idx = @intCast(i); break; } } - const idx = variant_idx orelse return self.emitErrorFmt("no variant '{s}' in union '{s}'", .{ ul.variant_name, resolved_name }); + const idx = variant_idx orelse return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ el.name, resolved_name }); // Alloca union const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_tmp"); @@ -3092,20 +3261,13 @@ pub const CodeGen = struct { _ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(i64_ty, idx, 0), tag_gep); // Store payload (field 1) if not void - if (ul.payload) |payload_node| { + if (el.payload) |payload_node| { const variant_ty = info.variant_types[idx]; if (variant_ty != .void_type) { const payload_val = try self.genExprAsType(payload_node, variant_ty); const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 1, "payload"); - const payload_llvm_ty = self.typeToLLVM(variant_ty); - // Bitcast payload area to the variant's type pointer and store - if (variant_ty.isStruct()) { - // Struct payload: load from alloca, store to payload area - const struct_val = c.LLVMBuildLoad2(self.builder, payload_llvm_ty, payload_val, "struct_load"); - _ = c.LLVMBuildStore(self.builder, struct_val, payload_gep); - } else { - _ = c.LLVMBuildStore(self.builder, payload_val, payload_gep); - } + // genExprAsType returns a loaded value for all types (including structs) + _ = c.LLVMBuildStore(self.builder, payload_val, payload_gep); } } @@ -3284,19 +3446,10 @@ pub const CodeGen = struct { return c.LLVMBuildGlobalStringPtr(self.builder, str_z.ptr, "str"); } - // Enum literal assigned to union type: construct tag-only (void variant) union + // Enum/union literal assigned to union type: construct tagged enum if (node.data == .enum_literal and target_ty.isUnion()) { - const ul = ast.UnionLiteral{ - .union_name = null, - .variant_name = node.data.enum_literal.name, - .payload = 
null, - }; - return self.genUnionLiteral(ul, target_ty.union_type); - } - - // Union literal with target union type: pass context - if (node.data == .union_literal and target_ty.isUnion()) { - return self.genUnionLiteral(node.data.union_literal, target_ty.union_type); + const el = node.data.enum_literal; + return self.genTaggedEnumLiteral(el, target_ty.union_type); } // Struct literal targeting union type: .Variant.{fields} pattern @@ -3308,8 +3461,8 @@ pub const CodeGen = struct { if (te.data == .enum_literal) { const variant_name = te.data.enum_literal.name; const uname = self.type_aliases.get(target_ty.union_type) orelse target_ty.union_type; - const info = self.union_types.get(uname) orelse - return self.emitErrorFmt("unknown union type '{s}'", .{uname}); + const info = self.tagged_enum_types.get(uname) orelse + return self.emitErrorFmt("unknown enum type '{s}'", .{uname}); // Find variant index var variant_idx: ?u32 = null; @@ -3320,7 +3473,7 @@ pub const CodeGen = struct { } } const idx = variant_idx orelse - return self.emitErrorFmt("no variant '{s}' in union '{s}'", .{ variant_name, uname }); + return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ variant_name, uname }); const variant_ty = info.variant_types[idx]; @@ -3554,7 +3707,7 @@ pub const CodeGen = struct { } if (target_ty.isUnion()) { const uname = target_ty.union_type; - if (self.union_types.get(uname)) |info| { + if (self.tagged_enum_types.get(uname)) |info| { const ptr = c.LLVMBuildIntToPtr(self.builder, i64_val, c.LLVMPointerTypeInContext(self.context, 0), "any_union_ptr"); return c.LLVMBuildLoad2(self.builder, info.llvm_type, ptr, "any_to_union"); } @@ -3604,7 +3757,7 @@ pub const CodeGen = struct { // Union → int: extract the tag field (index 0) if (src_ty.isUnion() and target_ty.isInt()) { const uname = src_ty.union_type; - if (self.union_types.get(uname)) |info| { + if (self.tagged_enum_types.get(uname)) |info| { const tmp = self.buildEntryBlockAlloca(info.llvm_type, "union_cast"); _ 
= c.LLVMBuildStore(self.builder, val, tmp); const tag_ptr = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, tmp, 0, "tag_ptr"); @@ -3667,6 +3820,13 @@ pub const CodeGen = struct { return null; } + fn findUnionFieldIndex(_: *CodeGen, info: UnionInfo, name: []const u8) ?usize { + for (info.field_names, 0..) |fn_name, i| { + if (std.mem.eql(u8, fn_name, name)) return i; + } + return null; + } + fn componentToIndex(ch: u8) ?u32 { return switch (ch) { 'x', 'r', 'u' => 0, @@ -3819,14 +3979,14 @@ pub const CodeGen = struct { return c.LLVMConstInt(i64_ty, ty.vector_type.length, 0); } if (ty.isUnion()) { - const info = self.union_types.get(ty.union_type) orelse - return self.emitErrorFmt("unknown union type '{s}'", .{ty.union_type}); + const info = self.tagged_enum_types.get(ty.union_type) orelse + return self.emitErrorFmt("unknown enum type '{s}'", .{ty.union_type}); return c.LLVMConstInt(i64_ty, info.variant_names.len, 0); } if (ty.isArray()) { return c.LLVMConstInt(i64_ty, ty.array_type.length, 0); } - return self.emitError("field_count requires a struct, enum, vector, union, or array type"); + return self.emitError("field_count requires a struct, enum, vector, or array type"); } fn genFieldName(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef { @@ -3843,10 +4003,10 @@ pub const CodeGen = struct { return self.emitErrorFmt("unknown enum type '{s}'", .{ty.enum_type}); break :blk .{ variants, ty.enum_type }; } else if (ty.isUnion()) blk: { - const info = self.union_types.get(ty.union_type) orelse - return self.emitErrorFmt("unknown union type '{s}'", .{ty.union_type}); + const info = self.tagged_enum_types.get(ty.union_type) orelse + return self.emitErrorFmt("unknown enum type '{s}'", .{ty.union_type}); break :blk .{ info.variant_names, ty.union_type }; - } else return self.emitError("field_name requires a struct, enum, or union type"); + } else return self.emitError("field_name requires a struct or enum type"); // Build a global array of string slices const n = 
names.len; @@ -3891,10 +4051,15 @@ pub const CodeGen = struct { return self.buildAnyValue(elem, elem_ty); } - // Union: switch over tag, extract payload with correct type + // Payload-less enum: return void Any (no payload to extract) + if (val_ty.isEnum() and !val_ty.isUnion()) { + return self.buildAnyValue(c.LLVMConstInt(c.LLVMInt64TypeInContext(self.context), 0, 0), .void_type); + } + + // Tagged enum (with payloads): switch over tag, extract payload with correct type if (val_ty.isUnion()) { - const uinfo = self.union_types.get(val_ty.union_type) orelse - return self.emitErrorFmt("unknown union type '{s}'", .{val_ty.union_type}); + const uinfo = self.tagged_enum_types.get(val_ty.union_type) orelse + return self.emitErrorFmt("unknown enum type '{s}'", .{val_ty.union_type}); const union_alloca = self.buildEntryBlockAlloca(uinfo.llvm_type, "fv_union"); _ = c.LLVMBuildStore(self.builder, val, union_alloca); @@ -3986,7 +4151,7 @@ pub const CodeGen = struct { // Struct: switch over field indices const struct_val = val; const struct_ty = val_ty; - if (!struct_ty.isStruct()) return self.emitError("field_value requires a struct, vector, union, or array value"); + if (!struct_ty.isStruct()) return self.emitError("field_value requires a struct, vector, enum, or array value"); const info = self.struct_types.get(struct_ty.struct_type) orelse return self.emitErrorFmt("unknown struct type '{s}'", .{struct_ty.struct_type}); @@ -4143,7 +4308,24 @@ pub const CodeGen = struct { } if (entry.ty.isUnion()) { const uname = entry.ty.union_type; - const info = self.union_types.get(uname) orelse return self.emitErrorFmt("unknown union type '{s}'", .{uname}); + // C-style (untagged) union: bitcast pointer and load + if (self.union_types.get(uname)) |info| { + if (self.findUnionFieldIndex(info, fa.field)) |fidx| { + const field_ty = info.field_types[fidx]; + return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(field_ty), entry.ptr, "union_field"); + } + // Check promoted fields from 
anonymous structs + if (info.promoted_fields.get(fa.field)) |pf| { + const sinfo = self.struct_types.get(pf.struct_name) orelse + return self.emitErrorFmt("unknown promoted struct '{s}'", .{pf.struct_name}); + // GEP through union pointer as struct type, then access field + const gep = c.LLVMBuildStructGEP2(self.builder, sinfo.llvm_type, entry.ptr, @intCast(pf.field_index), "promoted_field"); + return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(pf.field_type), gep, "promoted_val"); + } + return self.emitErrorFmt("no field '{s}' in union '{s}'", .{ fa.field, uname }); + } + // Tagged enum: GEP to payload area + const info = self.tagged_enum_types.get(uname) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{uname}); // Find variant by name to determine payload type var vidx: ?usize = null; for (info.variant_names, 0..) |vn, i| { @@ -4152,7 +4334,7 @@ pub const CodeGen = struct { break; } } - const idx = vidx orelse return self.emitErrorFmt("no variant '{s}' in union '{s}'", .{ fa.field, uname }); + const idx = vidx orelse return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ fa.field, uname }); const variant_ty = info.variant_types[idx]; if (variant_ty == .void_type) return self.emitErrorFmt("cannot access payload of void variant '{s}'", .{fa.field}); // GEP to field 1 (payload area), load as variant type @@ -4515,7 +4697,7 @@ pub const CodeGen = struct { const resolved_type: ?Type = blk: { if (fa.object.data == .identifier) { const name = self.type_aliases.get(fa.object.data.identifier.name) orelse fa.object.data.identifier.name; - if (self.union_types.contains(name)) break :blk .{ .union_type = name }; + if (self.tagged_enum_types.contains(name)) break :blk .{ .union_type = name }; if (self.struct_types.contains(name)) break :blk .{ .struct_type = name }; } else { const ty = self.resolveType(fa.object); @@ -4527,9 +4709,8 @@ pub const CodeGen = struct { if (rty.isUnion()) { const type_name = rty.union_type; const payload_node: ?*Node = if 
(call_node.args.len > 0) call_node.args[0] else null; - return self.genUnionLiteral(.{ - .union_name = type_name, - .variant_name = fa.field, + return self.genTaggedEnumLiteral(.{ + .name = fa.field, .payload = payload_node, }, type_name); } @@ -4882,7 +5063,7 @@ pub const CodeGen = struct { Type.fromName(name) == null and !self.struct_types.contains(name) and !self.enum_types.contains(name) and - !self.union_types.contains(name) and + !self.tagged_enum_types.contains(name) and !self.type_aliases.contains(name)) { return self.genGenericCallWithRuntimeDispatch(template, call_node, match_tags); @@ -5323,7 +5504,7 @@ pub const CodeGen = struct { }, .enum_type => any_i64, .union_type => |uname| blk: { - const info = self.union_types.get(uname) orelse return self.emitErrorFmt("unknown union '{s}'", .{uname}); + const info = self.tagged_enum_types.get(uname) orelse return self.emitErrorFmt("unknown enum '{s}'", .{uname}); const ptr = c.LLVMBuildIntToPtr(self.builder, any_i64, c.LLVMPointerTypeInContext(self.context, 0), "any_to_union_ptr"); break :blk c.LLVMBuildLoad2(self.builder, info.llvm_type, ptr, "any_to_union"); }, @@ -5764,13 +5945,13 @@ pub const CodeGen = struct { const subject_val: c.LLVMValueRef = if (union_name != null) blk: { // Union: load tag from field 0 of the alloca const entry = self.named_values.get(match.subject.data.identifier.name).?; - const info = self.union_types.get(union_name.?).?; + const info = self.tagged_enum_types.get(union_name.?).?; const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, 0, "tag"); break :blk c.LLVMBuildLoad2(self.builder, c.LLVMInt64TypeInContext(self.context), tag_gep, "tag_val"); } else try self.genExpr(match.subject); const variants: ?[]const []const u8 = if (union_name) |un| - (if (self.union_types.get(un)) |info| info.variant_names else null) + (if (self.tagged_enum_types.get(un)) |info| info.variant_names else null) else if (enum_name) |en| self.enum_types.get(en) else @@ -5925,7 +6106,7 
@@ pub const CodeGen = struct { else if (std.mem.eql(u8, name, "enum")) .enum_cat else if (std.mem.eql(u8, name, "union")) - .union_cat + .enum_cat else if (std.mem.eql(u8, name, "vector")) .vector_cat else if (std.mem.eql(u8, name, "array")) @@ -5979,7 +6160,7 @@ pub const CodeGen = struct { .{ .struct_type = name } else if (self.enum_types.contains(name)) .{ .enum_type = name } - else if (self.union_types.contains(name)) + else if (self.tagged_enum_types.contains(name)) .{ .union_type = name } else .{ .struct_type = name }; // fallback @@ -6183,7 +6364,8 @@ pub const CodeGen = struct { if (Type.fromName(name)) |t| return t; // Structs if (self.struct_types.contains(name)) return .{ .struct_type = name }; - // Unions + // Unions (tagged enums and C-style) + if (self.tagged_enum_types.contains(name)) return .{ .union_type = name }; if (self.union_types.contains(name)) return .{ .union_type = name }; // Enums if (self.enum_types.contains(name)) return .{ .enum_type = name }; @@ -6236,11 +6418,6 @@ pub const CodeGen = struct { } return .void_type; }, - .union_literal => |ul| { - if (ul.union_name) |uname| return .{ .union_type = uname }; - if (self.current_return_type.isUnion()) return self.current_return_type; - return .void_type; - }, .enum_literal => { if (self.current_return_type.isEnum()) return self.current_return_type; if (self.current_return_type.isUnion()) return self.current_return_type; @@ -6259,7 +6436,7 @@ pub const CodeGen = struct { const obj_ty = blk: { if (fa.object.data == .identifier) { const name = self.type_aliases.get(fa.object.data.identifier.name) orelse fa.object.data.identifier.name; - if (self.union_types.contains(name)) break :blk Type{ .union_type = name }; + if (self.tagged_enum_types.contains(name)) break :blk Type{ .union_type = name }; } const ty = self.resolveType(fa.object); if (ty.isUnion()) break :blk ty; @@ -6422,6 +6599,14 @@ pub const CodeGen = struct { } if (obj_ty.isUnion()) { if (self.union_types.get(obj_ty.union_type)) 
|info| { + if (self.findUnionFieldIndex(info, fa.field)) |idx| { + return info.field_types[idx]; + } + if (info.promoted_fields.get(fa.field)) |pf| { + return pf.field_type; + } + } + if (self.tagged_enum_types.get(obj_ty.union_type)) |info| { for (info.variant_names, 0..) |vn, i| { if (std.mem.eql(u8, vn, fa.field)) { return info.variant_types[i]; diff --git a/src/comptime.zig b/src/comptime.zig index 609b82a..e41b0fb 100644 --- a/src/comptime.zig +++ b/src/comptime.zig @@ -1616,7 +1616,8 @@ pub const VM = struct { }, .enum_decl => |ed| { if (std.mem.eql(u8, ed.name, name)) { - const val = Value{ .type_val = .{ .enum_type = name } }; + const ty: Type = if (ed.variant_types.len > 0) .{ .union_type = name } else .{ .enum_type = name }; + const val = Value{ .type_val = ty }; try self.globals.put(name, val); return val; } diff --git a/src/lsp/server.zig b/src/lsp/server.zig index efa2ee6..c652618 100644 --- a/src/lsp/server.zig +++ b/src/lsp/server.zig @@ -553,7 +553,7 @@ pub const Server = struct { .union_decl => |ud| { try items.append(allocator, .{ .label = ud.name, - .kind = @intFromEnum(lsp.CompletionItemKind.Enum), + .kind = @intFromEnum(lsp.CompletionItemKind.Struct), }); }, .var_decl => |vd| { @@ -602,7 +602,7 @@ pub const Server = struct { if (sx.sema.findNodeAtOffset(lookup_root, sym.def_span.start)) |node| { if (node.data == .enum_decl) { const ed = node.data.enum_decl; - for (ed.variants) |variant| { + for (ed.variant_names) |variant| { try items.append(self.allocator, .{ .label = variant, .kind = @intFromEnum(lsp.CompletionItemKind.EnumMember), @@ -1212,7 +1212,7 @@ pub const Server = struct { if (sx.sema.findNodeAtOffset(lookup_root, sym.def_span.start)) |node| { if (node.data == .enum_decl) { const ed = node.data.enum_decl; - for (ed.variants) |v| { + for (ed.variant_names) |v| { if (!std.mem.eql(u8, v, variant_name)) continue; var buf = std.ArrayList(u8).empty; @@ -1633,7 +1633,7 @@ pub const Server = struct { .enum_decl => |ed| { try 
buf.appendSlice(allocator, ed.name); try buf.appendSlice(allocator, " :: enum { "); - for (ed.variants, 0..) |v, i| { + for (ed.variant_names, 0..) |v, i| { if (i > 0) try buf.appendSlice(allocator, ", "); try buf.append(allocator, '.'); try buf.appendSlice(allocator, v); @@ -1658,9 +1658,24 @@ pub const Server = struct { .union_decl => |ud| { try buf.appendSlice(allocator, ud.name); try buf.appendSlice(allocator, " :: union { "); - for (ud.variant_names, 0..) |vn, i| { + for (ud.field_names, 0..) |fn_name, i| { if (i > 0) try buf.appendSlice(allocator, ", "); - try buf.appendSlice(allocator, vn); + // Anonymous struct fields: show as "struct { ... }" + if (std.mem.startsWith(u8, fn_name, "__anon_")) { + if (i < ud.field_types.len and ud.field_types[i].data == .type_expr) { + try buf.appendSlice(allocator, ud.field_types[i].data.type_expr.name); + } else { + try buf.appendSlice(allocator, "struct { ... }"); + } + } else { + try buf.appendSlice(allocator, fn_name); + if (i < ud.field_types.len) { + try buf.appendSlice(allocator, ": "); + if (ud.field_types[i].data == .type_expr) { + try buf.appendSlice(allocator, ud.field_types[i].data.type_expr.name); + } + } + } } try buf.appendSlice(allocator, " }"); }, diff --git a/src/parser.zig b/src/parser.zig index 80c7122..25c0a0d 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -160,7 +160,7 @@ pub const Parser = struct { return self.parseStructDecl(name, start_pos); } - // Union declaration + // C-style union declaration if (self.current.tag == .kw_union) { return self.parseUnionDecl(name, start_pos); } @@ -368,7 +368,7 @@ pub const Parser = struct { if (self.current.tag == .kw_struct) { return try self.parseStructDecl("__anon", start); } - // Inline union type in type position: union { ... } + // Inline C-style union in type position: union { ... 
} if (self.current.tag == .kw_union) { return try self.parseUnionDecl("__anon", start); } @@ -382,30 +382,13 @@ pub const Parser = struct { fn parseEnumDecl(self: *Parser, name: []const u8, start_pos: u32) anyerror!*Node { self.advance(); // skip 'enum' try self.expect(.l_brace); - var variants = std.ArrayList([]const u8).empty; + var variant_names = std.ArrayList([]const u8).empty; + var variant_types = std.ArrayList(?*Node).empty; + var has_any_type = false; while (self.current.tag != .r_brace and self.current.tag != .eof) { if (self.current.tag != .identifier) { return self.fail("expected variant name"); } - try variants.append(self.allocator, self.tokenSlice(self.current)); - self.advance(); - if (self.current.tag == .semicolon) { - self.advance(); - } - } - try self.expect(.r_brace); - return try self.createNode(start_pos, .{ .enum_decl = .{ .name = name, .variants = try variants.toOwnedSlice(self.allocator) } }); - } - - fn parseUnionDecl(self: *Parser, name: []const u8, start_pos: u32) anyerror!*Node { - self.advance(); // skip 'union' - try self.expect(.l_brace); - var variant_names = std.ArrayList([]const u8).empty; - var variant_types = std.ArrayList(?*Node).empty; - while (self.current.tag != .r_brace and self.current.tag != .eof) { - if (self.current.tag != .identifier) { - return self.fail("expected variant name in union"); - } try variant_names.append(self.allocator, self.tokenSlice(self.current)); self.advance(); if (self.current.tag == .colon) { @@ -413,6 +396,7 @@ pub const Parser = struct { self.advance(); const vtype = try self.parseTypeExpr(); try variant_types.append(self.allocator, vtype); + has_any_type = true; } else { // Void variant: name; try variant_types.append(self.allocator, null); @@ -422,10 +406,54 @@ pub const Parser = struct { } } try self.expect(.r_brace); - return try self.createNode(start_pos, .{ .union_decl = .{ + // Always produce enum_decl; variant_types distinguishes payload-less from tagged + return try 
self.createNode(start_pos, .{ .enum_decl = .{ .name = name, .variant_names = try variant_names.toOwnedSlice(self.allocator), - .variant_types = try variant_types.toOwnedSlice(self.allocator), + .variant_types = if (has_any_type) try variant_types.toOwnedSlice(self.allocator) else &.{}, + } }); + } + + fn parseUnionDecl(self: *Parser, name: []const u8, start_pos: u32) anyerror!*Node { + self.advance(); // skip 'union' + try self.expect(.l_brace); + var field_names = std.ArrayList([]const u8).empty; + var field_types = std.ArrayList(*Node).empty; + var anon_idx: u32 = 0; + while (self.current.tag != .r_brace and self.current.tag != .eof) { + // Anonymous struct field: struct { x, y: f32; }; + if (self.current.tag == .kw_struct) { + const anon_field = try std.fmt.allocPrint(self.allocator, "__anon_{d}", .{anon_idx}); + anon_idx += 1; + const anon_struct_name = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ name, anon_field }); + const struct_node = try self.parseStructDecl(anon_struct_name, self.current.loc.start); + try field_names.append(self.allocator, anon_field); + try field_types.append(self.allocator, struct_node); + if (self.current.tag == .semicolon) { + self.advance(); + } + continue; + } + if (self.current.tag != .identifier) { + return self.fail("expected field name or 'struct'"); + } + try field_names.append(self.allocator, self.tokenSlice(self.current)); + self.advance(); + if (self.current.tag != .colon) { + return self.fail("union fields must have a type"); + } + self.advance(); + const ftype = try self.parseTypeExpr(); + try field_types.append(self.allocator, ftype); + if (self.current.tag == .semicolon) { + self.advance(); + } + } + try self.expect(.r_brace); + return try self.createNode(start_pos, .{ .union_decl = .{ + .name = name, + .field_names = try field_names.toOwnedSlice(self.allocator), + .field_types = try field_types.toOwnedSlice(self.allocator), } }); } @@ -1141,14 +1169,13 @@ pub const Parser = struct { } const name = 
self.tokenSlice(self.current); self.advance(); - // Union literal: .variant(payload) + // Enum literal with payload: .variant(payload) — tagged enum (formerly union literal) if (self.current.tag == .l_paren) { self.advance(); // skip '(' const payload = try self.parseExpr(); try self.expect(.r_paren); - return try self.createNode(start, .{ .union_literal = .{ - .union_name = null, - .variant_name = name, + return try self.createNode(start, .{ .enum_literal = .{ + .name = name, .payload = payload, } }); } @@ -1175,8 +1202,12 @@ pub const Parser = struct { // Anonymous struct expression: struct { value: T; count: u32; } return try self.parseStructDecl("__anon", start); }, + .kw_enum => { + // Anonymous enum expression: enum { variant: T; other: u32; } + return try self.parseEnumDecl("__anon", start); + }, .kw_union => { - // Anonymous union expression: union { variant: T; other: u32; } + // Anonymous C-style union expression: union { f: f32; i: s32; } return try self.parseUnionDecl("__anon", start); }, .kw_if => { diff --git a/src/sema.zig b/src/sema.zig index a26fc73..b0c142b 100644 --- a/src/sema.zig +++ b/src/sema.zig @@ -179,8 +179,14 @@ pub const Analyzer = struct { try self.addSymbol(vd.name, .variable, ty, node.span); }, .enum_decl => |ed| { - try self.addSymbol(ed.name, .enum_type, .{ .enum_type = ed.name }, node.span); - try self.enum_types.put(ed.name, ed.variants); + if (ed.variant_types.len > 0) { + // Tagged enum with payloads + try self.addSymbol(ed.name, .enum_type, .{ .union_type = ed.name }, node.span); + } else { + // Payload-less enum + try self.addSymbol(ed.name, .enum_type, .{ .enum_type = ed.name }, node.span); + try self.enum_types.put(ed.name, ed.variant_names); + } }, .struct_decl => |sd| { try self.addSymbol(sd.name, .struct_type, .{ .struct_type = sd.name }, node.span); @@ -392,10 +398,6 @@ pub const Analyzer = struct { .break_expr => .void_type, .continue_expr => .void_type, .enum_literal => .{ .enum_type = "" }, - .union_literal => |ul| { 
- if (ul.union_name) |name| return .{ .union_type = name }; - return .void_type; - }, .struct_literal => |sl| { if (sl.struct_name) |name| { if (self.struct_types.contains(name)) return .{ .struct_type = name }; @@ -592,7 +594,11 @@ pub const Analyzer = struct { try self.addSymbol(vd.name, .variable, ty, node.span); }, .enum_decl => |ed| { - try self.addSymbol(ed.name, .enum_type, .{ .enum_type = ed.name }, node.span); + if (ed.variant_types.len > 0) { + try self.addSymbol(ed.name, .enum_type, .{ .union_type = ed.name }, node.span); + } else { + try self.addSymbol(ed.name, .enum_type, .{ .enum_type = ed.name }, node.span); + } }, .struct_decl => |sd| { try self.addSymbol(sd.name, .struct_type, .{ .struct_type = sd.name }, node.span); @@ -680,8 +686,8 @@ pub const Analyzer = struct { .union_decl => |ud| { try self.addSymbol(ud.name, .enum_type, .{ .union_type = ud.name }, node.span); }, - .union_literal => |ul| { - if (ul.payload) |p| { + .enum_literal => |el| { + if (el.payload) |p| { try self.analyzeNode(p); } }, @@ -690,7 +696,6 @@ pub const Analyzer = struct { .float_literal, .bool_literal, .string_literal, - .enum_literal, .type_expr, .param, .match_arm, @@ -741,7 +746,6 @@ pub const Analyzer = struct { .comptime_expr, .enum_literal, .struct_literal, - .union_literal, .array_literal, .index_expr, .slice_expr, @@ -936,8 +940,8 @@ pub fn findNodeAtOffset(node: *Node, offset: u32) ?*Node { if (findNodeAtOffset(fi.value, offset)) |found| return found; } }, - .union_literal => |ul| { - if (ul.payload) |p| { + .enum_literal => |el| { + if (el.payload) |p| { if (findNodeAtOffset(p, offset)) |found| return found; } }, @@ -947,7 +951,6 @@ pub fn findNodeAtOffset(node: *Node, offset: u32) ?*Node { .float_literal, .bool_literal, .string_literal, - .enum_literal, .type_expr, .param, .match_arm,