This commit is contained in:
agra
2026-02-14 19:33:33 +02:00
parent d61c6488f3
commit 0e777e9d2e
7 changed files with 957 additions and 72 deletions

View File

@@ -247,8 +247,9 @@ pub const CodeGen = struct {
const TaggedEnumInfo = struct {
variant_names: []const []const u8,
variant_types: []const Type, // void_type for void variants
llvm_type: c.LLVMTypeRef, // { i32, [max_payload_size x i8] }
llvm_type: c.LLVMTypeRef, // layout struct or { tag, [max_payload_size x i8] }
max_payload_size: u64,
payload_field_index: c_uint = 1, // struct field index of the payload array
};
const PromotedField = struct {
@@ -281,13 +282,33 @@ pub const CodeGen = struct {
const module = c.LLVMModuleCreateWithNameInContext(module_name, ctx);
const builder = c.LLVMCreateBuilderInContext(ctx);
// Set target triple on module so it appears in IR output
if (target_config.triple) |t| {
c.LLVMSetTarget(module, t);
// Initialize LLVM targets and set data layout early so alignment queries work
llvm.initAllTargets();
const triple_owned = target_config.triple == null;
const triple = target_config.triple orelse c.LLVMGetDefaultTargetTriple();
defer if (triple_owned) c.LLVMDisposeMessage(@constCast(triple));
c.LLVMSetTarget(module, triple);
var target: c.LLVMTargetRef = null;
var err_msg: [*c]u8 = null;
if (c.LLVMGetTargetFromTriple(triple, &target, &err_msg) == 0) {
const tm = c.LLVMCreateTargetMachine(
target,
triple,
target_config.getCpu(),
target_config.getFeatures(),
target_config.opt_level.toLLVM(),
c.LLVMRelocPIC,
c.LLVMCodeModelDefault,
);
const dl = c.LLVMCreateTargetDataLayout(tm);
c.LLVMSetModuleDataLayout(module, dl);
c.LLVMDisposeTargetData(dl);
c.LLVMDisposeTargetMachine(tm);
} else {
const default_triple = c.LLVMGetDefaultTargetTriple();
c.LLVMSetTarget(module, default_triple);
c.LLVMDisposeMessage(default_triple);
if (err_msg != null) c.LLVMDisposeMessage(err_msg);
}
return .{
.context = ctx,
@@ -1285,6 +1306,7 @@ pub const CodeGen = struct {
.variant_types = build.variant_sx_types,
.llvm_type = build.llvm_type,
.max_payload_size = build.max_payload_size,
.payload_field_index = build.payload_field_index,
});
_ = try self.getAnyTypeId(mangled_name, .{ .union_type = mangled_name });
@@ -1928,6 +1950,16 @@ pub const CodeGen = struct {
const info = self.struct_types.get(sname) orelse return self.emitErrorFmt("unknown struct type '{s}'", .{sname});
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, val, "retval");
_ = c.LLVMBuildRet(self.builder, loaded);
} else if (ret_sx_type.isUnion()) {
// Tagged enum implicit return: val may be alloca or loaded value
const uname = ret_sx_type.union_type;
const resolved = self.type_aliases.get(uname) orelse uname;
const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved});
const ret_val2 = if (c.LLVMGetTypeKind(c.LLVMTypeOf(val)) == c.LLVMPointerTypeKind)
c.LLVMBuildLoad2(self.builder, info.llvm_type, val, "retval")
else
val;
_ = c.LLVMBuildRet(self.builder, ret_val2);
} else {
const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(val));
const ret_val = self.convertValue(val, src_ty, self.current_return_type);
@@ -2090,6 +2122,17 @@ pub const CodeGen = struct {
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval");
try self.emitAllDefers();
_ = c.LLVMBuildRet(self.builder, loaded);
} else if (self.current_return_type.isUnion()) {
// Tagged enum return: raw_val may be alloca (enum literal) or loaded value (identifier/call)
const uname = self.current_return_type.union_type;
const resolved = self.type_aliases.get(uname) orelse uname;
const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved});
const ret_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(raw_val)) == c.LLVMPointerTypeKind)
c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval")
else
raw_val;
try self.emitAllDefers();
_ = c.LLVMBuildRet(self.builder, ret_val);
} else {
const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val));
const val = self.convertValue(raw_val, src_ty, self.current_return_type);
@@ -2253,15 +2296,23 @@ pub const CodeGen = struct {
try self.named_values.put(vd.name, .{ .ptr = lit_alloca, .ty = sx_ty });
return null;
} else if (vd.value.?.data == .call) {
// Call returning a union (e.g., Shape.circle(3.14)) — genExpr returns alloca
const result_alloca = try self.genExpr(vd.value.?);
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result_alloca, "union_load");
_ = c.LLVMBuildStore(self.builder, loaded, alloca);
// Call returning a union — could be enum construction (alloca) or function call (value)
const result = try self.genExpr(vd.value.?);
if (c.LLVMGetTypeKind(c.LLVMTypeOf(result)) == c.LLVMPointerTypeKind) {
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result, "union_load");
_ = c.LLVMBuildStore(self.builder, loaded, alloca);
} else {
_ = c.LLVMBuildStore(self.builder, result, alloca);
}
} else {
// Other expression — try genExprAsType
const result_alloca = try self.genExprAsType(vd.value.?, sx_ty);
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result_alloca, "union_load");
_ = c.LLVMBuildStore(self.builder, loaded, alloca);
const result = try self.genExprAsType(vd.value.?, sx_ty);
if (c.LLVMGetTypeKind(c.LLVMTypeOf(result)) == c.LLVMPointerTypeKind) {
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result, "union_load");
_ = c.LLVMBuildStore(self.builder, loaded, alloca);
} else {
_ = c.LLVMBuildStore(self.builder, result, alloca);
}
}
try self.saveShadowed(vd.name);
@@ -2564,13 +2615,16 @@ pub const CodeGen = struct {
return null;
}
// Tagged enum reassignment: s = .circle(3.14) or s = .none
// Tagged enum reassignment: s = .circle(3.14) or s = .none or s = fn_call()
if (entry.ty.isUnion() and asgn.op == .assign) {
if (self.tagged_enum_types.get(entry.ty.union_type)) |info| {
const new_alloca = try self.genExprAsType(asgn.value, entry.ty);
// Copy from new alloca to existing alloca
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, new_alloca, "union_load");
_ = c.LLVMBuildStore(self.builder, loaded, entry.ptr);
const new_val = try self.genExprAsType(asgn.value, entry.ty);
// genExprAsType returns alloca for enum literals, loaded value for calls
const store_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(new_val)) == c.LLVMPointerTypeKind)
c.LLVMBuildLoad2(self.builder, info.llvm_type, new_val, "union_load")
else
new_val;
_ = c.LLVMBuildStore(self.builder, store_val, entry.ptr);
return null;
}
// C-style union: full assignment not supported, use field assignment
@@ -2860,6 +2914,33 @@ pub const CodeGen = struct {
const lhs_ty = self.inferType(binop.lhs);
const rhs_ty = self.inferType(binop.rhs);
const result_type = Type.widen(lhs_ty, rhs_ty);
// Tagged enum comparison: compare tags only
if (result_type.isUnion() and (binop.op == .eq or binop.op == .neq)) {
const uname = result_type.union_type;
const resolved = self.type_aliases.get(uname) orelse uname;
const info = self.tagged_enum_types.get(resolved) orelse return self.emitError("unknown tagged enum type");
const tag_ty = self.getEnumLLVMType(resolved);
var lhs_val = try self.genExprAsType(binop.lhs, result_type);
var rhs_val = try self.genExprAsType(binop.rhs, result_type);
// If either side is a pointer (alloca from genTaggedEnumLiteral), load it
if (c.LLVMGetTypeKind(c.LLVMTypeOf(lhs_val)) == c.LLVMPointerTypeKind) {
lhs_val = c.LLVMBuildLoad2(self.builder, info.llvm_type, lhs_val, "union_load_l");
}
if (c.LLVMGetTypeKind(c.LLVMTypeOf(rhs_val)) == c.LLVMPointerTypeKind) {
rhs_val = c.LLVMBuildLoad2(self.builder, info.llvm_type, rhs_val, "union_load_r");
}
// Extract tags (field 0) and compare
const lhs_tag = c.LLVMBuildExtractValue(self.builder, lhs_val, 0, "lhs_tag");
const rhs_tag = c.LLVMBuildExtractValue(self.builder, rhs_val, 0, "rhs_tag");
_ = tag_ty;
const pred: c_uint = if (binop.op == .eq) c.LLVMIntEQ else c.LLVMIntNE;
return c.LLVMBuildICmp(self.builder, pred, lhs_tag, rhs_tag, "tag_cmp");
}
const lhs = try self.genExprAsType(binop.lhs, result_type);
const rhs = try self.genExprAsType(binop.rhs, result_type);
return self.genBinaryOp(binop.op, lhs, rhs, result_type);
@@ -2895,6 +2976,15 @@ pub const CodeGen = struct {
.xx, .address_of => unreachable,
};
},
.enum_literal => |el| {
if (self.current_return_type.isUnion()) {
return self.genTaggedEnumLiteral(el, self.current_return_type.union_type);
}
if (self.current_return_type.isEnum()) {
return self.genEnumLiteral(el.name, self.current_return_type.enum_type);
}
return self.emitError("cannot infer enum type for literal");
},
.struct_literal => |sl| {
const ctx_name: ?[]const u8 = if (self.current_return_type.isStruct()) self.current_return_type.struct_type else null;
return self.genStructLiteral(sl, ctx_name);
@@ -2976,10 +3066,29 @@ pub const CodeGen = struct {
.return_stmt => |rs| {
if (rs.value) |val_node| {
const raw_val = try self.genExpr(val_node);
const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val));
const val = self.convertValue(raw_val, src_ty, self.current_return_type);
try self.emitAllDefers();
_ = c.LLVMBuildRet(self.builder, val);
if (self.current_return_type.isStruct()) {
const sname = self.current_return_type.struct_type;
const resolved = self.type_aliases.get(sname) orelse sname;
const sinfo = self.struct_types.get(resolved) orelse return self.emitErrorFmt("unknown struct type '{s}'", .{resolved});
const loaded = c.LLVMBuildLoad2(self.builder, sinfo.llvm_type, raw_val, "retval");
try self.emitAllDefers();
_ = c.LLVMBuildRet(self.builder, loaded);
} else if (self.current_return_type.isUnion()) {
const uname = self.current_return_type.union_type;
const resolved = self.type_aliases.get(uname) orelse uname;
const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved});
const ret_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(raw_val)) == c.LLVMPointerTypeKind)
c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval")
else
raw_val;
try self.emitAllDefers();
_ = c.LLVMBuildRet(self.builder, ret_val);
} else {
const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val));
const val = self.convertValue(raw_val, src_ty, self.current_return_type);
try self.emitAllDefers();
_ = c.LLVMBuildRet(self.builder, val);
}
} else {
try self.emitAllDefers();
_ = c.LLVMBuildRetVoid(self.builder);
@@ -3111,6 +3220,7 @@ pub const CodeGen = struct {
variant_sx_types: []const Type,
llvm_type: c.LLVMTypeRef,
max_payload_size: u64,
payload_field_index: c_uint,
};
fn buildUnionFields(self: *CodeGen, name: []const u8, variant_type_nodes: []const ?*Node) !UnionBuildResult {
@@ -3134,6 +3244,7 @@ pub const CodeGen = struct {
const union_ty = c.LLVMStructCreateNamed(self.context, name_z.ptr);
const tag_ty = self.getEnumLLVMType(name);
const i8_ty = c.LLVMInt8TypeInContext(self.context);
const payload_array_ty = c.LLVMArrayType2(i8_ty, max_payload_size);
var union_fields = [2]c.LLVMTypeRef{ tag_ty, payload_array_ty };
c.LLVMStructSetBody(union_ty, &union_fields, 2, 0);
@@ -3142,6 +3253,7 @@ pub const CodeGen = struct {
.variant_sx_types = try variant_sx_types.toOwnedSlice(self.allocator),
.llvm_type = union_ty,
.max_payload_size = max_payload_size,
.payload_field_index = 1,
};
}
@@ -3226,21 +3338,176 @@ pub const CodeGen = struct {
}
}
// Register backing type before buildUnionFields (which uses getEnumLLVMType)
if (ud.backing_type) |bt_node| {
const bt = self.resolveType(bt_node);
try self.enum_backing_types.put(ud.name, self.typeToLLVM(bt));
// Check if backing type is a struct layout specification
const layout_info = try self.resolveEnumLayout(ud);
if (layout_info) |layout| {
// Struct-backed layout: use the struct's LLVM type directly
try self.enum_backing_types.put(ud.name, layout.tag_llvm_type);
// Resolve variant sx types
var variant_sx_types = std.ArrayList(Type).empty;
for (ud.variant_types) |vt| {
if (vt) |type_node| {
try variant_sx_types.append(self.allocator, self.resolveType(type_node));
} else {
try variant_sx_types.append(self.allocator, .void_type);
}
}
try self.tagged_enum_types.put(ud.name, .{
.variant_names = ud.variant_names,
.variant_types = try variant_sx_types.toOwnedSlice(self.allocator),
.llvm_type = layout.llvm_type,
.max_payload_size = layout.payload_size,
.payload_field_index = layout.payload_field_index,
});
} else {
// Primitive backing type (e.g. enum u32 { ... })
if (ud.backing_type) |bt_node| {
const bt = self.resolveType(bt_node);
try self.enum_backing_types.put(ud.name, self.typeToLLVM(bt));
}
const build = try self.buildUnionFields(ud.name, ud.variant_types);
try self.tagged_enum_types.put(ud.name, .{
.variant_names = ud.variant_names,
.variant_types = build.variant_sx_types,
.llvm_type = build.llvm_type,
.max_payload_size = build.max_payload_size,
.payload_field_index = build.payload_field_index,
});
}
const build = try self.buildUnionFields(ud.name, ud.variant_types);
try self.tagged_enum_types.put(ud.name, .{
.variant_names = ud.variant_names,
.variant_types = build.variant_sx_types,
.llvm_type = build.llvm_type,
.max_payload_size = build.max_payload_size,
});
_ = try self.getAnyTypeId(ud.name, .{ .union_type = ud.name });
// Compute and store variant values (explicit or sequential)
const values = try self.allocator.alloc(i64, ud.variant_names.len);
for (ud.variant_names, 0..) |_, i| {
if (ud.variant_values.len > i and ud.variant_values[i] != null) {
const val_node = ud.variant_values[i].?;
values[i] = switch (val_node.data) {
.int_literal => |il| il.value,
else => @as(i64, @intCast(i)),
};
} else {
values[i] = @intCast(i);
}
}
try self.enum_variant_values.put(ud.name, values);
}
const EnumLayoutInfo = struct {
llvm_type: c.LLVMTypeRef,
tag_llvm_type: c.LLVMTypeRef,
payload_field_index: c_uint,
payload_size: u64,
};
/// Resolve a struct-backed layout for a tagged enum.
/// Returns null if the backing type is a primitive (e.g. u32), in which case
/// the caller should fall back to buildUnionFields.
///
/// The layout struct must have:
/// - A field named `tag` (integer type) — the discriminant
/// - A field named `payload` (array type) — the overlay area for variant data
/// - Any other fields are treated as padding/reserved
fn resolveEnumLayout(self: *CodeGen, ud: ast.EnumDecl) !?EnumLayoutInfo {
const bt_node = ud.backing_type orelse return null;
// Check for inline struct: enum struct { ... } { ... }
if (bt_node.data == .struct_decl) {
const layout_name = try std.fmt.allocPrint(self.allocator, "{s}.__layout", .{ud.name});
var sd = bt_node.data.struct_decl;
sd.name = layout_name;
try self.registerStructType(sd);
return try self.validateEnumLayout(ud.name, layout_name);
}
// Check for named struct reference: enum MyLayout { ... }
if (bt_node.data == .type_expr) {
const name = bt_node.data.type_expr.name;
// If it resolves to a primitive type, it's not a layout struct
if (Type.fromName(name) != null) return null;
// Check type aliases
const resolved = self.type_aliases.get(name) orelse name;
if (Type.fromName(resolved) != null) return null;
// Must be a registered struct
if (self.struct_types.contains(resolved)) {
return try self.validateEnumLayout(ud.name, resolved);
}
}
return null;
}
fn validateEnumLayout(self: *CodeGen, enum_name: []const u8, layout_name: []const u8) !EnumLayoutInfo {
const layout = self.struct_types.get(layout_name) orelse {
return self.emitErrorFmt("enum '{s}': layout type '{s}' is not a registered struct", .{ enum_name, layout_name });
};
// Find 'tag' field
var tag_index: ?usize = null;
var payload_index: ?usize = null;
for (layout.field_names, 0..) |fname, i| {
if (std.mem.eql(u8, fname, "tag")) {
tag_index = i;
} else if (std.mem.eql(u8, fname, "payload")) {
payload_index = i;
}
}
if (tag_index == null) {
return self.emitErrorFmt(
"enum '{s}': layout struct '{s}' must have a field named 'tag' (the discriminant). Expected layout: struct {{ tag: <int_type>; payload: [N]<type>; }}",
.{ enum_name, layout_name },
);
}
if (payload_index == null) {
return self.emitErrorFmt(
"enum '{s}': layout struct '{s}' must have a field named 'payload' (the variant data area). Expected layout: struct {{ tag: <int_type>; payload: [N]<type>; }}",
.{ enum_name, layout_name },
);
}
const tag_ty = layout.field_types[tag_index.?];
const payload_ty = layout.field_types[payload_index.?];
// Validate tag is an integer type
switch (tag_ty) {
.signed, .unsigned => {},
else => return self.emitErrorFmt(
"enum '{s}': layout field 'tag' must be an integer type (e.g. u32), got '{s}'",
.{ enum_name, tag_ty.displayName(self.allocator) catch "?" },
),
}
// Validate payload is an array type
const payload_size = switch (payload_ty) {
.array_type => |info| blk: {
const elem_ty = Type.fromName(info.element_name) orelse {
return self.emitErrorFmt(
"enum '{s}': layout field 'payload' has unresolved element type '{s}'",
.{ enum_name, info.element_name },
);
};
const elem_llvm = self.typeToLLVM(elem_ty);
const data_layout = c.LLVMGetModuleDataLayout(self.module);
break :blk c.LLVMStoreSizeOfType(data_layout, elem_llvm) * info.length;
},
else => return self.emitErrorFmt(
"enum '{s}': layout field 'payload' must be an array type (e.g. [30]u32), got '{s}'",
.{ enum_name, payload_ty.displayName(self.allocator) catch "?" },
),
};
return .{
.llvm_type = layout.llvm_type,
.tag_llvm_type = self.typeToLLVM(tag_ty),
.payload_field_index = @intCast(payload_index.?),
.payload_size = payload_size,
};
}
fn registerUnionType(self: *CodeGen, ud: ast.UnionDecl) !void {
@@ -3320,16 +3587,17 @@ pub const CodeGen = struct {
const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_tmp");
const tag_ty = self.getEnumLLVMType(resolved_name);
// Store tag (field 0)
// Store tag (field 0) — use explicit value if available, otherwise index
const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 0, "tag");
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_ty, idx, 0), tag_gep);
const tag_val: u64 = if (self.enum_variant_values.get(resolved_name)) |vals| @bitCast(vals[idx]) else idx;
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_ty, tag_val, 0), tag_gep);
// Store payload (field 1) if not void
if (el.payload) |payload_node| {
const variant_ty = info.variant_types[idx];
if (variant_ty != .void_type) {
const payload_val = try self.genExprAsType(payload_node, variant_ty);
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 1, "payload");
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, info.payload_field_index, "payload");
// genExprAsType returns a loaded value for all types (including structs)
_ = c.LLVMBuildStore(self.builder, payload_val, payload_gep);
}
@@ -3572,7 +3840,7 @@ pub const CodeGen = struct {
.type_expr = null,
.field_inits = sl.field_inits,
}, payload_struct_name);
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 1, "payload");
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, info.payload_field_index, "payload");
const payload_llvm_ty = self.typeToLLVM(variant_ty);
const struct_val = c.LLVMBuildLoad2(self.builder, payload_llvm_ty, payload_alloca, "struct_load");
_ = c.LLVMBuildStore(self.builder, struct_val, payload_gep);
@@ -3687,7 +3955,12 @@ pub const CodeGen = struct {
std.mem.eql(u8, src_ty.struct_type, pointee_name) or
(if (self.type_aliases.get(src_ty.struct_type)) |alias| std.mem.eql(u8, alias, pointee_name) else false) or
(if (self.type_aliases.get(pointee_name)) |alias| std.mem.eql(u8, alias, src_ty.struct_type) else false)
else if (Type.fromName(pointee_name)) |pointee_ty|
else if (src_ty.isUnion()) blk: {
const uname = src_ty.union_type;
break :blk std.mem.eql(u8, uname, pointee_name) or
(if (self.type_aliases.get(uname)) |alias| std.mem.eql(u8, alias, pointee_name) else false) or
(if (self.type_aliases.get(pointee_name)) |alias| std.mem.eql(u8, alias, uname) else false);
} else if (Type.fromName(pointee_name)) |pointee_ty|
src_ty.eql(pointee_ty)
else
false;
@@ -4178,7 +4451,7 @@ pub const CodeGen = struct {
// Read tag (field 0)
const tag_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 0, "fv_tag_ptr");
const tag_val = c.LLVMBuildLoad2(self.builder, self.getEnumLLVMType(val_ty.union_type), tag_ptr, "fv_tag");
const payload_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 1, "fv_payload_ptr");
const payload_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, uinfo.payload_field_index, "fv_payload_ptr");
const n = uinfo.variant_names.len;
const function = self.current_function;
@@ -4371,6 +4644,72 @@ pub const CodeGen = struct {
return phi;
}
fn genFieldIndex(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
if (call_node.args.len != 2) return self.emitError("field_index expects 2 arguments: field_index(T, value)");
const ty = self.resolveType(call_node.args[0]);
const i64_type = c.LLVMInt64TypeInContext(self.context);
if (!ty.isEnum()) {
_ = try self.genExpr(call_node.args[1]);
return c.LLVMConstInt(i64_type, 0, 0);
}
const enum_name = ty.enum_type;
// Flags enums don't use sequential indices
if (self.flags_enum_types.contains(enum_name)) {
_ = try self.genExpr(call_node.args[1]);
return c.LLVMConstInt(i64_type, 0, 0);
}
const values = self.enum_variant_values.get(enum_name);
const variants = self.enum_types.get(enum_name) orelse return try self.genExpr(call_node.args[1]);
const n = variants.len;
const val = try self.genExpr(call_node.args[1]);
// Ensure the switch value uses the enum's backing type
const enum_llvm_ty = self.getEnumLLVMType(enum_name);
const sw_val = if (c.LLVMTypeOf(val) != enum_llvm_ty)
c.LLVMBuildIntCast2(self.builder, val, enum_llvm_ty, 0, "fi_cast")
else
val;
const function = self.current_function;
const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_merge");
const default_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_default");
const sw = c.LLVMBuildSwitch(self.builder, sw_val, default_bb, @intCast(n));
var phi_vals = std.ArrayList(c.LLVMValueRef).empty;
var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty;
var seen_values = std.ArrayList(u64).empty;
for (0..n) |i| {
const explicit_val: u64 = if (values) |vals| @bitCast(vals[i]) else i;
// Skip duplicate values (first one wins)
var is_dup = false;
for (seen_values.items) |sv| {
if (sv == explicit_val) { is_dup = true; break; }
}
if (is_dup) continue;
try seen_values.append(self.allocator, explicit_val);
const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_case");
c.LLVMAddCase(sw, c.LLVMConstInt(enum_llvm_ty, explicit_val, 0), case_bb);
c.LLVMPositionBuilderAtEnd(self.builder, case_bb);
try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, i, 0));
try phi_bbs.append(self.allocator, case_bb);
_ = c.LLVMBuildBr(self.builder, merge_bb);
}
c.LLVMPositionBuilderAtEnd(self.builder, default_bb);
const neg_one = c.LLVMConstInt(i64_type, @bitCast(@as(i64, -1)), 0);
try phi_vals.append(self.allocator, neg_one);
try phi_bbs.append(self.allocator, default_bb);
_ = c.LLVMBuildBr(self.builder, merge_bb);
c.LLVMPositionBuilderAtEnd(self.builder, merge_bb);
const vals_slice = try phi_vals.toOwnedSlice(self.allocator);
const bbs_slice = try phi_bbs.toOwnedSlice(self.allocator);
const phi = c.LLVMBuildPhi(self.builder, i64_type, "fi_result");
c.LLVMAddIncoming(phi, vals_slice.ptr, bbs_slice.ptr, @intCast(vals_slice.len));
return phi;
}
fn genCast(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
if (call_node.args.len != 2) return self.emitError("cast expects: cast(Type) expr");
const target_ty = self.resolveType(call_node.args[0]);
@@ -4505,8 +4844,8 @@ pub const CodeGen = struct {
const idx = vidx orelse return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ fa.field, uname });
const variant_ty = info.variant_types[idx];
if (variant_ty == .void_type) return self.emitErrorFmt("cannot access payload of void variant '{s}'", .{fa.field});
// GEP to field 1 (payload area), load as variant type
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, 1, "payload");
// GEP to payload area, load as variant type
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, info.payload_field_index, "payload");
return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(variant_ty), payload_gep, "union_payload");
}
if (entry.ty.isVector()) {
@@ -6123,12 +6462,9 @@ pub const CodeGen = struct {
// Determine subject type for enum vs union dispatch
var enum_name: ?[]const u8 = null;
var union_name: ?[]const u8 = null;
if (match.subject.data == .identifier) {
if (self.named_values.get(match.subject.data.identifier.name)) |entry| {
if (entry.ty.isEnum()) enum_name = entry.ty.enum_type;
if (entry.ty.isUnion()) union_name = entry.ty.union_type;
}
}
const subject_ty = self.inferType(match.subject);
if (subject_ty.isEnum()) enum_name = subject_ty.enum_type;
if (subject_ty.isUnion()) union_name = subject_ty.union_type;
// Get the switch value: for unions, load the tag from field 0; for enums, use the value directly
const subject_val: c.LLVMValueRef = if (union_name != null) blk: {
@@ -6224,6 +6560,32 @@ pub const CodeGen = struct {
// Category/type arm with no matching types — BB is unreachable, skip body
_ = c.LLVMBuildBr(self.builder, merge_bb);
} else {
// Payload capture: bind variant payload as a local variable
if (arm.capture) |cap_name| {
if (union_name) |un| {
const uinfo = self.tagged_enum_types.get(un).?;
const pat = arm.pattern.?;
if (pat.data == .enum_literal) {
const vname = pat.data.enum_literal.name;
var vidx: ?usize = null;
for (uinfo.variant_names, 0..) |vn, vi| {
if (std.mem.eql(u8, vn, vname)) { vidx = vi; break; }
}
if (vidx) |vi| {
const variant_ty = uinfo.variant_types[vi];
if (variant_ty != .void_type) {
const subject_entry = self.named_values.get(match.subject.data.identifier.name).?;
const payload_gep = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, subject_entry.ptr, uinfo.payload_field_index, "cap_payload");
const payload_llvm_ty = self.typeToLLVM(variant_ty);
const payload_val = c.LLVMBuildLoad2(self.builder, payload_llvm_ty, payload_gep, "cap_load");
const cap_alloca = c.LLVMBuildAlloca(self.builder, payload_llvm_ty, @ptrCast(cap_name.ptr));
_ = c.LLVMBuildStore(self.builder, payload_val, cap_alloca);
try self.named_values.put(cap_name, .{ .ptr = cap_alloca, .ty = variant_ty });
}
}
}
}
}
// Set match arm context for runtime type dispatch
const saved_match_tags = self.current_match_tags;
self.current_match_tags = arm_tag_values.items[i];
@@ -6428,6 +6790,7 @@ pub const CodeGen = struct {
if (std.mem.eql(u8, base, "field_value")) return self.genFieldValue(call_node);
if (std.mem.eql(u8, base, "is_flags")) return self.genIsFlags(call_node);
if (std.mem.eql(u8, base, "field_value_int")) return self.genFieldValueInt(call_node);
if (std.mem.eql(u8, base, "field_index")) return self.genFieldIndex(call_node);
return self.emitErrorFmt("unknown builtin function '{s}'", .{name});
}
@@ -6663,6 +7026,8 @@ pub const CodeGen = struct {
if (std.mem.eql(u8, base_name, "is_flags")) return .boolean;
// Built-in: field_value_int returns s64
if (std.mem.eql(u8, base_name, "field_value_int")) return Type.s(64);
// Built-in: field_index returns s64
if (std.mem.eql(u8, base_name, "field_index")) return Type.s(64);
// Built-in: cast returns the target type (first arg)
if (std.mem.eql(u8, base_name, "cast")) {
if (call_node.args.len > 0) return self.resolveType(call_node.args[0]);