enum, union

This commit is contained in:
agra
2026-02-14 13:17:22 +02:00
parent 4ff828fd1a
commit 025b790411
14 changed files with 537 additions and 245 deletions

View File

@@ -108,7 +108,9 @@ pub const CodeGen = struct {
type_aliases: std.StringHashMap([]const u8),
// Struct type registry: maps struct name to field info + LLVM type
struct_types: std.StringHashMap(StructInfo),
// Union type registry: maps union name to variant info + LLVM type
// Tagged enum registry: maps name to variant info + LLVM type (enums with payloads)
tagged_enum_types: std.StringHashMap(TaggedEnumInfo),
// Union registry: maps name to field info + LLVM type (untagged, C-style)
union_types: std.StringHashMap(UnionInfo),
// Built-in functions (printf, etc.)
builtins: ?Builtins,
@@ -193,7 +195,6 @@ pub const CodeGen = struct {
const TypeCategory = enum {
struct_cat,
enum_cat,
union_cat,
vector_cat,
array_cat,
slice_cat,
@@ -237,13 +238,27 @@ pub const CodeGen = struct {
template_name: ?[]const u8 = null, // original template name (e.g. "List")
};
const UnionInfo = struct {
const TaggedEnumInfo = struct {
variant_names: []const []const u8,
variant_types: []const Type, // void_type for void variants
llvm_type: c.LLVMTypeRef, // { i32, [max_payload_size x i8] }
max_payload_size: u64,
};
const PromotedField = struct {
struct_name: []const u8, // the anonymous struct type name
field_index: usize, // field index within that struct
field_type: Type, // type of the promoted field
};
const UnionInfo = struct {
field_names: []const []const u8,
field_types: []const Type,
llvm_type: c.LLVMTypeRef, // [max_size x i8]
total_size: u64,
promoted_fields: std.StringHashMap(PromotedField),
};
// Scope stack entry: records what a name mapped to before being shadowed
const ScopeEntry = struct {
name: []const u8,
@@ -277,6 +292,7 @@ pub const CodeGen = struct {
.enum_types = std.StringHashMap([]const []const u8).init(allocator),
.type_aliases = std.StringHashMap([]const u8).init(allocator),
.struct_types = std.StringHashMap(StructInfo).init(allocator),
.tagged_enum_types = std.StringHashMap(TaggedEnumInfo).init(allocator),
.union_types = std.StringHashMap(UnionInfo).init(allocator),
.builtins = null,
.current_function = null,
@@ -308,6 +324,7 @@ pub const CodeGen = struct {
self.enum_types.deinit();
self.type_aliases.deinit();
self.struct_types.deinit();
self.tagged_enum_types.deinit();
self.union_types.deinit();
self.comptime_globals.deinit();
self.generic_templates.deinit();
@@ -363,7 +380,7 @@ pub const CodeGen = struct {
.string_type, .slice_type => self.getStringStructType(), // slices use same {ptr, i32} layout
.enum_type => c.LLVMInt64TypeInContext(self.context),
.struct_type => |name| if (self.struct_types.get(name)) |info| info.llvm_type else unreachable,
.union_type => |name| if (self.union_types.get(name)) |info| info.llvm_type else unreachable,
.union_type => |name| if (self.tagged_enum_types.get(name)) |info| info.llvm_type else if (self.union_types.get(name)) |info| info.llvm_type else unreachable,
.array_type => |info| {
const elem_ty = Type.fromName(info.element_name) orelse unreachable;
return c.LLVMArrayType2(self.typeToLLVM(elem_ty), info.length);
@@ -413,7 +430,7 @@ pub const CodeGen = struct {
const category: TypeCategory = switch (sx_type) {
.struct_type => .struct_cat,
.enum_type => .enum_cat,
.union_type => .union_cat,
.union_type => .enum_cat,
.vector_type => .vector_cat,
.array_type => .array_cat,
.slice_type => .slice_cat,
@@ -541,7 +558,7 @@ pub const CodeGen = struct {
},
.union_type => |uname| blk: {
// Union — store to alloca, pass pointer as i64
const info = self.union_types.get(uname) orelse
const info = self.tagged_enum_types.get(uname) orelse
return c.LLVMGetUndef(any_ty);
const alloca = self.buildEntryBlockAlloca(info.llvm_type, "any_union_tmp");
_ = c.LLVMBuildStore(self.builder, val, alloca);
@@ -685,8 +702,14 @@ pub const CodeGen = struct {
try self.foreign_libraries.append(self.allocator, ld.lib_name);
},
.enum_decl => |ed| {
try self.enum_types.put(ed.name, ed.variants);
_ = try self.getAnyTypeId(ed.name, .{ .enum_type = ed.name });
if (ed.variant_types.len > 0) {
// Tagged enum with payloads
try self.registerTaggedEnum(ed);
} else {
// Payload-less enum
try self.enum_types.put(ed.name, ed.variant_names);
_ = try self.getAnyTypeId(ed.name, .{ .enum_type = ed.name });
}
},
.struct_decl => |sd| try self.registerStructType(sd),
.union_decl => |ud| try self.registerUnionType(ud),
@@ -768,7 +791,7 @@ pub const CodeGen = struct {
}
}
{
var it = self.union_types.iterator();
var it = self.tagged_enum_types.iterator();
while (it.next()) |entry| {
_ = try self.getAnyTypeId(entry.key_ptr.*, .{ .union_type = entry.key_ptr.* });
}
@@ -1015,13 +1038,15 @@ pub const CodeGen = struct {
if (self.type_aliases.get(name)) |target| {
if (Type.fromName(target)) |t| return t;
if (self.struct_types.contains(target)) return .{ .struct_type = target };
if (self.tagged_enum_types.contains(target)) return .{ .union_type = target };
if (self.union_types.contains(target)) return .{ .union_type = target };
}
// Check enum types
if (self.enum_types.contains(name)) return .{ .enum_type = name };
// Check struct types
if (self.struct_types.contains(name)) return .{ .struct_type = name };
// Check union types
// Check union types (tagged enums and C-style unions)
if (self.tagged_enum_types.contains(name)) return .{ .union_type = name };
if (self.union_types.contains(name)) return .{ .union_type = name };
}
// Safety net: inline declarations that should have been hoisted
@@ -1029,12 +1054,9 @@ pub const CodeGen = struct {
const sn = tn.data.struct_decl.name;
if (self.struct_types.contains(sn)) return .{ .struct_type = sn };
}
if (tn.data == .union_decl) {
const un = tn.data.union_decl.name;
if (self.union_types.contains(un)) return .{ .union_type = un };
}
if (tn.data == .enum_decl) {
const en = tn.data.enum_decl.name;
if (self.tagged_enum_types.contains(en)) return .{ .union_type = en };
if (self.enum_types.contains(en)) return .{ .enum_type = en };
}
return .void_type;
@@ -1178,13 +1200,13 @@ pub const CodeGen = struct {
// Try union
if (self.findUnionInBody(fd.body)) |union_decl| {
if (self.union_types.contains(mangled_name)) {
if (self.tagged_enum_types.contains(mangled_name)) {
return .{ .union_type = mangled_name };
}
return self.registerInstantiatedUnion(mangled_name, union_decl);
return self.registerInstantiatedTaggedEnum(mangled_name, union_decl);
}
return self.emitErrorFmt("type function '{s}' does not return a struct or union", .{template_name});
return self.emitErrorFmt("type function '{s}' does not return a struct or enum", .{template_name});
}
fn registerInstantiatedStruct(self: *CodeGen, mangled_name: []const u8, alias_name: []const u8, struct_decl: ast.StructDecl) !Type {
@@ -1205,10 +1227,10 @@ pub const CodeGen = struct {
return .{ .struct_type = mangled_name };
}
fn registerInstantiatedUnion(self: *CodeGen, mangled_name: []const u8, union_decl: ast.UnionDecl) !Type {
fn registerInstantiatedTaggedEnum(self: *CodeGen, mangled_name: []const u8, union_decl: ast.EnumDecl) !Type {
const build = try self.buildUnionFields(mangled_name, union_decl.variant_types);
try self.union_types.put(mangled_name, .{
try self.tagged_enum_types.put(mangled_name, .{
.variant_names = union_decl.variant_names,
.variant_types = build.variant_sx_types,
.llvm_type = build.llvm_type,
@@ -1248,8 +1270,27 @@ pub const CodeGen = struct {
return findDeclInBody(ast.StructDecl, .struct_decl, body);
}
fn findUnionInBody(_: *CodeGen, body: *Node) ?ast.UnionDecl {
return findDeclInBody(ast.UnionDecl, .union_decl, body);
fn findUnionInBody(_: *CodeGen, body: *Node) ?ast.EnumDecl {
// Tagged enums with payloads are now stored as .enum_decl with variant_types populated
const isTaggedEnum = struct {
fn check(node: *Node) ?ast.EnumDecl {
if (node.data == .enum_decl and node.data.enum_decl.variant_types.len > 0) {
return node.data.enum_decl;
}
return null;
}
};
if (isTaggedEnum.check(body)) |ed| return ed;
const stmts = if (body.data == .block) body.data.block.stmts else return null;
for (stmts) |stmt| {
if (stmt.data == .return_stmt) {
if (stmt.data.return_stmt.value) |val| {
if (isTaggedEnum.check(val)) |ed| return ed;
}
}
if (isTaggedEnum.check(stmt)) |ed| return ed;
}
return null;
}
fn buildFnType(self: *CodeGen, params: []const ast.Param, return_type: ?*Node, name: []const u8) !c.LLVMTypeRef {
@@ -1555,8 +1596,15 @@ pub const CodeGen = struct {
}
},
.enum_decl => |ed| {
const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, ed.name });
try self.enum_types.put(qualified, ed.variants);
if (ed.variant_types.len > 0) {
// Tagged enum with payloads
try self.registerTaggedEnum(ed);
const qualified_u = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, ed.name });
try self.type_aliases.put(qualified_u, ed.name);
} else {
const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, ed.name });
try self.enum_types.put(qualified, ed.variant_names);
}
},
.struct_decl => |sd| {
try self.registerStructType(sd);
@@ -1663,7 +1711,7 @@ pub const CodeGen = struct {
if (name_ptr != null) {
const name = std.mem.span(name_ptr);
if (self.struct_types.contains(name)) return .{ .struct_type = name };
if (self.union_types.contains(name)) return .{ .union_type = name };
if (self.tagged_enum_types.contains(name)) return .{ .union_type = name };
}
}
// Check for array types
@@ -1970,8 +2018,8 @@ pub const CodeGen = struct {
try self.registerStructType(sd);
return null;
},
.union_decl => |ud| {
try self.registerUnionType(ud);
.union_decl => {
// C-style union — registration handled in type pass
return null;
},
.assignment => |asgn| {
@@ -2111,11 +2159,31 @@ pub const CodeGen = struct {
return null;
}
// Union-typed variable
// Union-typed variable (tagged enum or C-style union)
if (sx_ty.isUnion()) {
const uname = self.type_aliases.get(sx_ty.union_type) orelse sx_ty.union_type;
sx_ty = .{ .union_type = uname };
const info = self.union_types.get(uname) orelse return self.emitErrorFmt("unknown union type '{s}'", .{uname});
// C-style (untagged) union
if (self.union_types.get(uname)) |info| {
const name_z = try self.allocator.dupeZ(u8, vd.name);
const alloca = self.buildEntryBlockAlloca(info.llvm_type, name_z.ptr);
if (vd.value == null) {
_ = c.LLVMBuildStore(self.builder, c.LLVMConstNull(info.llvm_type), alloca);
} else if (vd.value.?.data == .undef_literal) {
_ = c.LLVMBuildStore(self.builder, c.LLVMGetUndef(info.llvm_type), alloca);
} else {
return self.emitErrorFmt("union '{s}' must be initialized with '---' or field assignment", .{uname});
}
try self.saveShadowed(vd.name);
try self.named_values.put(vd.name, .{ .ptr = alloca, .ty = sx_ty });
return null;
}
// Tagged enum
const info = self.tagged_enum_types.get(uname) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{uname});
const name_z = try self.allocator.dupeZ(u8, vd.name);
const alloca = self.buildEntryBlockAlloca(info.llvm_type, name_z.ptr);
@@ -2124,19 +2192,9 @@ pub const CodeGen = struct {
_ = c.LLVMBuildStore(self.builder, c.LLVMConstNull(info.llvm_type), alloca);
} else if (vd.value.?.data == .undef_literal) {
_ = c.LLVMBuildStore(self.builder, c.LLVMGetUndef(info.llvm_type), alloca);
} else if (vd.value.?.data == .union_literal) {
const lit_alloca = try self.genUnionLiteral(vd.value.?.data.union_literal, uname);
try self.saveShadowed(vd.name);
try self.named_values.put(vd.name, .{ .ptr = lit_alloca, .ty = sx_ty });
return null;
} else if (vd.value.?.data == .enum_literal) {
// Void variant: .none assigned to union variable
const ul = ast.UnionLiteral{
.union_name = uname,
.variant_name = vd.value.?.data.enum_literal.name,
.payload = null,
};
const lit_alloca = try self.genUnionLiteral(ul, uname);
const el = vd.value.?.data.enum_literal;
const lit_alloca = try self.genTaggedEnumLiteral(el, uname);
try self.saveShadowed(vd.name);
try self.named_values.put(vd.name, .{ .ptr = lit_alloca, .ty = sx_ty });
return null;
@@ -2330,7 +2388,7 @@ pub const CodeGen = struct {
sx_ty = self.inferType(cd.value);
}
// Union-typed constant: delegate to genExprAsType which handles enum_literal + union_literal
// Enum-typed constant: delegate to genExprAsType which handles enum_literal
if (sx_ty.isUnion()) {
const val = try self.genExprAsType(cd.value, sx_ty);
try self.saveShadowed(cd.name);
@@ -2452,14 +2510,16 @@ pub const CodeGen = struct {
return null;
}
// Union reassignment: s = .circle(3.14) or s = .none
// Tagged enum reassignment: s = .circle(3.14) or s = .none
if (entry.ty.isUnion() and asgn.op == .assign) {
const new_alloca = try self.genExprAsType(asgn.value, entry.ty);
// Copy from new alloca to existing alloca
const info = self.union_types.get(entry.ty.union_type).?;
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, new_alloca, "union_load");
_ = c.LLVMBuildStore(self.builder, loaded, entry.ptr);
return null;
if (self.tagged_enum_types.get(entry.ty.union_type)) |info| {
const new_alloca = try self.genExprAsType(asgn.value, entry.ty);
// Copy from new alloca to existing alloca
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, new_alloca, "union_load");
_ = c.LLVMBuildStore(self.builder, loaded, entry.ptr);
return null;
}
// C-style union: full assignment not supported, use field assignment
}
const new_val = try self.genExpr(asgn.value);
@@ -2507,6 +2567,42 @@ pub const CodeGen = struct {
return self.emitError("field assignment through pointer requires a struct pointee");
}
// C-style union field assignment
if (entry.ty.isUnion()) {
const uname = entry.ty.union_type;
if (self.union_types.get(uname)) |info| {
if (self.findUnionFieldIndex(info, fa.field)) |fidx| {
const field_ty = info.field_types[fidx];
const rhs = try self.genExprAsType(asgn.value, field_ty);
if (asgn.op == .assign) {
_ = c.LLVMBuildStore(self.builder, rhs, entry.ptr);
} else {
const field_llvm_ty = self.typeToLLVM(field_ty);
const cur = c.LLVMBuildLoad2(self.builder, field_llvm_ty, entry.ptr, "ucur");
_ = c.LLVMBuildStore(self.builder, self.genCompoundOp(asgn.op, cur, rhs, field_ty), entry.ptr);
}
return null;
}
// Check promoted fields from anonymous structs
if (info.promoted_fields.get(fa.field)) |pf| {
const sinfo = self.struct_types.get(pf.struct_name) orelse
return self.emitErrorFmt("unknown promoted struct '{s}'", .{pf.struct_name});
const gep = c.LLVMBuildStructGEP2(self.builder, sinfo.llvm_type, entry.ptr, @intCast(pf.field_index), "promoted_ptr");
const rhs = try self.genExprAsType(asgn.value, pf.field_type);
if (asgn.op == .assign) {
_ = c.LLVMBuildStore(self.builder, rhs, gep);
} else {
const field_llvm_ty = self.typeToLLVM(pf.field_type);
const cur = c.LLVMBuildLoad2(self.builder, field_llvm_ty, gep, "ucur");
_ = c.LLVMBuildStore(self.builder, self.genCompoundOp(asgn.op, cur, rhs, pf.field_type), gep);
}
return null;
}
return self.emitErrorFmt("no field '{s}' in union '{s}'", .{ fa.field, uname });
}
return self.emitErrorFmt("field assignment not supported on tagged enum '{s}'", .{uname});
}
if (!entry.ty.isStruct()) return self.emitErrorFmt("field access on non-struct variable '{s}'", .{obj_name});
const sname = entry.ty.struct_type;
@@ -2749,9 +2845,6 @@ pub const CodeGen = struct {
const ctx_name: ?[]const u8 = if (self.current_return_type.isStruct()) self.current_return_type.struct_type else null;
return self.genStructLiteral(sl, ctx_name);
},
.union_literal => |ul| {
return self.genUnionLiteral(ul, null);
},
.array_literal => |al| {
// Typed array/vector/slice literal: Type.[elems]
if (al.type_expr) |te| {
@@ -2905,6 +2998,19 @@ pub const CodeGen = struct {
const idx = self.findFieldIndex(info, fa.field) orelse return self.emitErrorFmt("no field '{s}' in struct '{s}'", .{ fa.field, sname });
return c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, @intCast(idx), "addr_field");
}
// &u.field where u is a C-style union — all fields at offset 0
if (entry.ty.isUnion()) {
if (self.union_types.get(entry.ty.union_type)) |info| {
if (self.findUnionFieldIndex(info, fa.field) != null) {
return entry.ptr;
}
if (info.promoted_fields.get(fa.field)) |pf| {
const sinfo = self.struct_types.get(pf.struct_name) orelse
return self.emitErrorFmt("unknown promoted struct '{s}'", .{pf.struct_name});
return c.LLVMBuildStructGEP2(self.builder, sinfo.llvm_type, entry.ptr, @intCast(pf.field_index), "addr_promoted");
}
}
}
// &p.field where p is *Struct — auto-deref through pointer
if (entry.ty.isPointer()) {
const pointee_name = entry.ty.pointer_type.pointee_name;
@@ -2995,14 +3101,21 @@ pub const CodeGen = struct {
type_node.data = .{ .type_expr = .{ .name = synthetic_name } };
},
.union_decl => |inline_ud| {
var hoisted = inline_ud;
hoisted.name = synthetic_name;
try self.registerUnionType(hoisted);
var hoisted_ud = inline_ud;
hoisted_ud.name = synthetic_name;
try self.registerUnionType(hoisted_ud);
type_node.data = .{ .type_expr = .{ .name = synthetic_name } };
},
.enum_decl => |inline_ed| {
try self.enum_types.put(synthetic_name, inline_ed.variants);
_ = try self.getAnyTypeId(synthetic_name, .{ .enum_type = synthetic_name });
if (inline_ed.variant_types.len > 0) {
// Tagged enum with payloads
var hoisted = inline_ed;
hoisted.name = synthetic_name;
try self.registerTaggedEnum(hoisted);
} else {
try self.enum_types.put(synthetic_name, inline_ed.variant_names);
_ = try self.getAnyTypeId(synthetic_name, .{ .enum_type = synthetic_name });
}
type_node.data = .{ .type_expr = .{ .name = synthetic_name } };
},
else => {},
@@ -3047,7 +3160,7 @@ pub const CodeGen = struct {
_ = try self.getAnyTypeId(sd.name, .{ .struct_type = sd.name });
}
fn registerUnionType(self: *CodeGen, ud: ast.UnionDecl) !void {
fn registerTaggedEnum(self: *CodeGen, ud: ast.EnumDecl) !void {
// Pre-pass: hoist inline type declarations from variant types
for (ud.variant_types, 0..) |vt_opt, i| {
if (vt_opt) |vt| {
@@ -3057,7 +3170,7 @@ pub const CodeGen = struct {
const build = try self.buildUnionFields(ud.name, ud.variant_types);
try self.union_types.put(ud.name, .{
try self.tagged_enum_types.put(ud.name, .{
.variant_names = ud.variant_names,
.variant_types = build.variant_sx_types,
.llvm_type = build.llvm_type,
@@ -3066,22 +3179,78 @@ pub const CodeGen = struct {
_ = try self.getAnyTypeId(ud.name, .{ .union_type = ud.name });
}
fn genUnionLiteral(self: *CodeGen, ul: ast.UnionLiteral, expected_union_name: ?[]const u8) !c.LLVMValueRef {
const uname = ul.union_name orelse expected_union_name orelse
fn registerUnionType(self: *CodeGen, ud: ast.UnionDecl) !void {
// Hoist inline type declarations from field types
for (ud.field_types, 0..) |ft, i| {
try self.hoistInlineTypeDecl(ud.name, ud.field_names[i], ft);
}
// Compute max field size and resolve field types
const data_layout = c.LLVMGetModuleDataLayout(self.module);
var field_sx_types = std.ArrayList(Type).empty;
var max_size: u64 = 0;
for (ud.field_types) |ft| {
const resolved = self.resolveType(ft);
try field_sx_types.append(self.allocator, resolved);
const llvm_ty = self.typeToLLVM(resolved);
const size = c.LLVMABISizeOfType(data_layout, llvm_ty);
if (size > max_size) max_size = size;
}
// LLVM type: byte array sized to the largest field
const byte_ty = c.LLVMInt8TypeInContext(self.context);
const llvm_type = c.LLVMArrayType(byte_ty, @intCast(max_size));
const resolved_field_types = try field_sx_types.toOwnedSlice(self.allocator);
// Build promoted fields map from anonymous struct members
var promoted = std.StringHashMap(PromotedField).init(self.allocator);
for (ud.field_names, 0..) |_, i| {
const fty = resolved_field_types[i];
if (fty.isStruct()) {
// Check if this is an anonymous struct (name contains __anon_)
const sname = fty.struct_type;
if (std.mem.indexOf(u8, sname, ".__anon_") != null) {
if (self.struct_types.get(sname)) |sinfo| {
for (sinfo.field_names, 0..) |sf_name, sf_idx| {
try promoted.put(sf_name, .{
.struct_name = sname,
.field_index = sf_idx,
.field_type = sinfo.field_types[sf_idx],
});
}
}
}
}
}
try self.union_types.put(ud.name, .{
.field_names = ud.field_names,
.field_types = resolved_field_types,
.llvm_type = llvm_type,
.total_size = max_size,
.promoted_fields = promoted,
});
// Note: C-style unions are not registered with the Any type system.
// They can't be meaningfully printed as a whole — access individual fields instead.
}
fn genTaggedEnumLiteral(self: *CodeGen, el: ast.EnumLiteral, expected_union_name: ?[]const u8) !c.LLVMValueRef {
const uname = expected_union_name orelse
(if (self.current_return_type.isUnion()) self.current_return_type.union_type else null) orelse
return self.emitError("cannot infer union type for literal");
return self.emitError("cannot infer enum type for literal");
const resolved_name = self.type_aliases.get(uname) orelse uname;
const info = self.union_types.get(resolved_name) orelse return self.emitErrorFmt("unknown union type '{s}'", .{resolved_name});
const info = self.tagged_enum_types.get(resolved_name) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved_name});
// Find variant index
var variant_idx: ?u32 = null;
for (info.variant_names, 0..) |vn, i| {
if (std.mem.eql(u8, vn, ul.variant_name)) {
if (std.mem.eql(u8, vn, el.name)) {
variant_idx = @intCast(i);
break;
}
}
const idx = variant_idx orelse return self.emitErrorFmt("no variant '{s}' in union '{s}'", .{ ul.variant_name, resolved_name });
const idx = variant_idx orelse return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ el.name, resolved_name });
// Alloca union
const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_tmp");
@@ -3092,20 +3261,13 @@ pub const CodeGen = struct {
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(i64_ty, idx, 0), tag_gep);
// Store payload (field 1) if not void
if (ul.payload) |payload_node| {
if (el.payload) |payload_node| {
const variant_ty = info.variant_types[idx];
if (variant_ty != .void_type) {
const payload_val = try self.genExprAsType(payload_node, variant_ty);
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 1, "payload");
const payload_llvm_ty = self.typeToLLVM(variant_ty);
// Bitcast payload area to the variant's type pointer and store
if (variant_ty.isStruct()) {
// Struct payload: load from alloca, store to payload area
const struct_val = c.LLVMBuildLoad2(self.builder, payload_llvm_ty, payload_val, "struct_load");
_ = c.LLVMBuildStore(self.builder, struct_val, payload_gep);
} else {
_ = c.LLVMBuildStore(self.builder, payload_val, payload_gep);
}
// genExprAsType returns a loaded value for all types (including structs)
_ = c.LLVMBuildStore(self.builder, payload_val, payload_gep);
}
}
@@ -3284,19 +3446,10 @@ pub const CodeGen = struct {
return c.LLVMBuildGlobalStringPtr(self.builder, str_z.ptr, "str");
}
// Enum literal assigned to union type: construct tag-only (void variant) union
// Enum/union literal assigned to union type: construct tagged enum
if (node.data == .enum_literal and target_ty.isUnion()) {
const ul = ast.UnionLiteral{
.union_name = null,
.variant_name = node.data.enum_literal.name,
.payload = null,
};
return self.genUnionLiteral(ul, target_ty.union_type);
}
// Union literal with target union type: pass context
if (node.data == .union_literal and target_ty.isUnion()) {
return self.genUnionLiteral(node.data.union_literal, target_ty.union_type);
const el = node.data.enum_literal;
return self.genTaggedEnumLiteral(el, target_ty.union_type);
}
// Struct literal targeting union type: .Variant.{fields} pattern
@@ -3308,8 +3461,8 @@ pub const CodeGen = struct {
if (te.data == .enum_literal) {
const variant_name = te.data.enum_literal.name;
const uname = self.type_aliases.get(target_ty.union_type) orelse target_ty.union_type;
const info = self.union_types.get(uname) orelse
return self.emitErrorFmt("unknown union type '{s}'", .{uname});
const info = self.tagged_enum_types.get(uname) orelse
return self.emitErrorFmt("unknown enum type '{s}'", .{uname});
// Find variant index
var variant_idx: ?u32 = null;
@@ -3320,7 +3473,7 @@ pub const CodeGen = struct {
}
}
const idx = variant_idx orelse
return self.emitErrorFmt("no variant '{s}' in union '{s}'", .{ variant_name, uname });
return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ variant_name, uname });
const variant_ty = info.variant_types[idx];
@@ -3554,7 +3707,7 @@ pub const CodeGen = struct {
}
if (target_ty.isUnion()) {
const uname = target_ty.union_type;
if (self.union_types.get(uname)) |info| {
if (self.tagged_enum_types.get(uname)) |info| {
const ptr = c.LLVMBuildIntToPtr(self.builder, i64_val, c.LLVMPointerTypeInContext(self.context, 0), "any_union_ptr");
return c.LLVMBuildLoad2(self.builder, info.llvm_type, ptr, "any_to_union");
}
@@ -3604,7 +3757,7 @@ pub const CodeGen = struct {
// Union → int: extract the tag field (index 0)
if (src_ty.isUnion() and target_ty.isInt()) {
const uname = src_ty.union_type;
if (self.union_types.get(uname)) |info| {
if (self.tagged_enum_types.get(uname)) |info| {
const tmp = self.buildEntryBlockAlloca(info.llvm_type, "union_cast");
_ = c.LLVMBuildStore(self.builder, val, tmp);
const tag_ptr = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, tmp, 0, "tag_ptr");
@@ -3667,6 +3820,13 @@ pub const CodeGen = struct {
return null;
}
fn findUnionFieldIndex(_: *CodeGen, info: UnionInfo, name: []const u8) ?usize {
for (info.field_names, 0..) |fn_name, i| {
if (std.mem.eql(u8, fn_name, name)) return i;
}
return null;
}
fn componentToIndex(ch: u8) ?u32 {
return switch (ch) {
'x', 'r', 'u' => 0,
@@ -3819,14 +3979,14 @@ pub const CodeGen = struct {
return c.LLVMConstInt(i64_ty, ty.vector_type.length, 0);
}
if (ty.isUnion()) {
const info = self.union_types.get(ty.union_type) orelse
return self.emitErrorFmt("unknown union type '{s}'", .{ty.union_type});
const info = self.tagged_enum_types.get(ty.union_type) orelse
return self.emitErrorFmt("unknown enum type '{s}'", .{ty.union_type});
return c.LLVMConstInt(i64_ty, info.variant_names.len, 0);
}
if (ty.isArray()) {
return c.LLVMConstInt(i64_ty, ty.array_type.length, 0);
}
return self.emitError("field_count requires a struct, enum, vector, union, or array type");
return self.emitError("field_count requires a struct, enum, vector, or array type");
}
fn genFieldName(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
@@ -3843,10 +4003,10 @@ pub const CodeGen = struct {
return self.emitErrorFmt("unknown enum type '{s}'", .{ty.enum_type});
break :blk .{ variants, ty.enum_type };
} else if (ty.isUnion()) blk: {
const info = self.union_types.get(ty.union_type) orelse
return self.emitErrorFmt("unknown union type '{s}'", .{ty.union_type});
const info = self.tagged_enum_types.get(ty.union_type) orelse
return self.emitErrorFmt("unknown enum type '{s}'", .{ty.union_type});
break :blk .{ info.variant_names, ty.union_type };
} else return self.emitError("field_name requires a struct, enum, or union type");
} else return self.emitError("field_name requires a struct or enum type");
// Build a global array of string slices
const n = names.len;
@@ -3891,10 +4051,15 @@ pub const CodeGen = struct {
return self.buildAnyValue(elem, elem_ty);
}
// Union: switch over tag, extract payload with correct type
// Payload-less enum: return void Any (no payload to extract)
if (val_ty.isEnum() and !val_ty.isUnion()) {
return self.buildAnyValue(c.LLVMConstInt(c.LLVMInt64TypeInContext(self.context), 0, 0), .void_type);
}
// Tagged enum (with payloads): switch over tag, extract payload with correct type
if (val_ty.isUnion()) {
const uinfo = self.union_types.get(val_ty.union_type) orelse
return self.emitErrorFmt("unknown union type '{s}'", .{val_ty.union_type});
const uinfo = self.tagged_enum_types.get(val_ty.union_type) orelse
return self.emitErrorFmt("unknown enum type '{s}'", .{val_ty.union_type});
const union_alloca = self.buildEntryBlockAlloca(uinfo.llvm_type, "fv_union");
_ = c.LLVMBuildStore(self.builder, val, union_alloca);
@@ -3986,7 +4151,7 @@ pub const CodeGen = struct {
// Struct: switch over field indices
const struct_val = val;
const struct_ty = val_ty;
if (!struct_ty.isStruct()) return self.emitError("field_value requires a struct, vector, union, or array value");
if (!struct_ty.isStruct()) return self.emitError("field_value requires a struct, vector, enum, or array value");
const info = self.struct_types.get(struct_ty.struct_type) orelse
return self.emitErrorFmt("unknown struct type '{s}'", .{struct_ty.struct_type});
@@ -4143,7 +4308,24 @@ pub const CodeGen = struct {
}
if (entry.ty.isUnion()) {
const uname = entry.ty.union_type;
const info = self.union_types.get(uname) orelse return self.emitErrorFmt("unknown union type '{s}'", .{uname});
// C-style (untagged) union: bitcast pointer and load
if (self.union_types.get(uname)) |info| {
if (self.findUnionFieldIndex(info, fa.field)) |fidx| {
const field_ty = info.field_types[fidx];
return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(field_ty), entry.ptr, "union_field");
}
// Check promoted fields from anonymous structs
if (info.promoted_fields.get(fa.field)) |pf| {
const sinfo = self.struct_types.get(pf.struct_name) orelse
return self.emitErrorFmt("unknown promoted struct '{s}'", .{pf.struct_name});
// GEP through union pointer as struct type, then access field
const gep = c.LLVMBuildStructGEP2(self.builder, sinfo.llvm_type, entry.ptr, @intCast(pf.field_index), "promoted_field");
return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(pf.field_type), gep, "promoted_val");
}
return self.emitErrorFmt("no field '{s}' in union '{s}'", .{ fa.field, uname });
}
// Tagged enum: GEP to payload area
const info = self.tagged_enum_types.get(uname) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{uname});
// Find variant by name to determine payload type
var vidx: ?usize = null;
for (info.variant_names, 0..) |vn, i| {
@@ -4152,7 +4334,7 @@ pub const CodeGen = struct {
break;
}
}
const idx = vidx orelse return self.emitErrorFmt("no variant '{s}' in union '{s}'", .{ fa.field, uname });
const idx = vidx orelse return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ fa.field, uname });
const variant_ty = info.variant_types[idx];
if (variant_ty == .void_type) return self.emitErrorFmt("cannot access payload of void variant '{s}'", .{fa.field});
// GEP to field 1 (payload area), load as variant type
@@ -4515,7 +4697,7 @@ pub const CodeGen = struct {
const resolved_type: ?Type = blk: {
if (fa.object.data == .identifier) {
const name = self.type_aliases.get(fa.object.data.identifier.name) orelse fa.object.data.identifier.name;
if (self.union_types.contains(name)) break :blk .{ .union_type = name };
if (self.tagged_enum_types.contains(name)) break :blk .{ .union_type = name };
if (self.struct_types.contains(name)) break :blk .{ .struct_type = name };
} else {
const ty = self.resolveType(fa.object);
@@ -4527,9 +4709,8 @@ pub const CodeGen = struct {
if (rty.isUnion()) {
const type_name = rty.union_type;
const payload_node: ?*Node = if (call_node.args.len > 0) call_node.args[0] else null;
return self.genUnionLiteral(.{
.union_name = type_name,
.variant_name = fa.field,
return self.genTaggedEnumLiteral(.{
.name = fa.field,
.payload = payload_node,
}, type_name);
}
@@ -4882,7 +5063,7 @@ pub const CodeGen = struct {
Type.fromName(name) == null and
!self.struct_types.contains(name) and
!self.enum_types.contains(name) and
!self.union_types.contains(name) and
!self.tagged_enum_types.contains(name) and
!self.type_aliases.contains(name))
{
return self.genGenericCallWithRuntimeDispatch(template, call_node, match_tags);
@@ -5323,7 +5504,7 @@ pub const CodeGen = struct {
},
.enum_type => any_i64,
.union_type => |uname| blk: {
const info = self.union_types.get(uname) orelse return self.emitErrorFmt("unknown union '{s}'", .{uname});
const info = self.tagged_enum_types.get(uname) orelse return self.emitErrorFmt("unknown enum '{s}'", .{uname});
const ptr = c.LLVMBuildIntToPtr(self.builder, any_i64, c.LLVMPointerTypeInContext(self.context, 0), "any_to_union_ptr");
break :blk c.LLVMBuildLoad2(self.builder, info.llvm_type, ptr, "any_to_union");
},
@@ -5764,13 +5945,13 @@ pub const CodeGen = struct {
const subject_val: c.LLVMValueRef = if (union_name != null) blk: {
// Union: load tag from field 0 of the alloca
const entry = self.named_values.get(match.subject.data.identifier.name).?;
const info = self.union_types.get(union_name.?).?;
const info = self.tagged_enum_types.get(union_name.?).?;
const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, 0, "tag");
break :blk c.LLVMBuildLoad2(self.builder, c.LLVMInt64TypeInContext(self.context), tag_gep, "tag_val");
} else try self.genExpr(match.subject);
const variants: ?[]const []const u8 = if (union_name) |un|
(if (self.union_types.get(un)) |info| info.variant_names else null)
(if (self.tagged_enum_types.get(un)) |info| info.variant_names else null)
else if (enum_name) |en|
self.enum_types.get(en)
else
@@ -5925,7 +6106,7 @@ pub const CodeGen = struct {
else if (std.mem.eql(u8, name, "enum"))
.enum_cat
else if (std.mem.eql(u8, name, "union"))
.union_cat
.enum_cat
else if (std.mem.eql(u8, name, "vector"))
.vector_cat
else if (std.mem.eql(u8, name, "array"))
@@ -5979,7 +6160,7 @@ pub const CodeGen = struct {
.{ .struct_type = name }
else if (self.enum_types.contains(name))
.{ .enum_type = name }
else if (self.union_types.contains(name))
else if (self.tagged_enum_types.contains(name))
.{ .union_type = name }
else
.{ .struct_type = name }; // fallback
@@ -6183,7 +6364,8 @@ pub const CodeGen = struct {
if (Type.fromName(name)) |t| return t;
// Structs
if (self.struct_types.contains(name)) return .{ .struct_type = name };
// Unions
// Unions (tagged enums and C-style)
if (self.tagged_enum_types.contains(name)) return .{ .union_type = name };
if (self.union_types.contains(name)) return .{ .union_type = name };
// Enums
if (self.enum_types.contains(name)) return .{ .enum_type = name };
@@ -6236,11 +6418,6 @@ pub const CodeGen = struct {
}
return .void_type;
},
.union_literal => |ul| {
if (ul.union_name) |uname| return .{ .union_type = uname };
if (self.current_return_type.isUnion()) return self.current_return_type;
return .void_type;
},
.enum_literal => {
if (self.current_return_type.isEnum()) return self.current_return_type;
if (self.current_return_type.isUnion()) return self.current_return_type;
@@ -6259,7 +6436,7 @@ pub const CodeGen = struct {
const obj_ty = blk: {
if (fa.object.data == .identifier) {
const name = self.type_aliases.get(fa.object.data.identifier.name) orelse fa.object.data.identifier.name;
if (self.union_types.contains(name)) break :blk Type{ .union_type = name };
if (self.tagged_enum_types.contains(name)) break :blk Type{ .union_type = name };
}
const ty = self.resolveType(fa.object);
if (ty.isUnion()) break :blk ty;
@@ -6422,6 +6599,14 @@ pub const CodeGen = struct {
}
if (obj_ty.isUnion()) {
if (self.union_types.get(obj_ty.union_type)) |info| {
if (self.findUnionFieldIndex(info, fa.field)) |idx| {
return info.field_types[idx];
}
if (info.promoted_fields.get(fa.field)) |pf| {
return pf.field_type;
}
}
if (self.tagged_enum_types.get(obj_ty.union_type)) |info| {
for (info.variant_names, 0..) |vn, i| {
if (std.mem.eql(u8, vn, fa.field)) {
return info.variant_types[i];