@ enum type
This commit is contained in:
118
src/codegen.zig
118
src/codegen.zig
@@ -116,6 +116,8 @@ pub const CodeGen = struct {
|
||||
flags_enum_types: std.StringHashMap(void),
|
||||
// Enum variant values: maps enum name → resolved i64 values per variant
|
||||
enum_variant_values: std.StringHashMap([]const i64),
|
||||
// Enum backing types: maps enum name → LLVM type for the backing integer (default i64)
|
||||
enum_backing_types: std.StringHashMap(c.LLVMTypeRef),
|
||||
// Built-in functions (printf, etc.)
|
||||
builtins: ?Builtins,
|
||||
// Current function being generated (for alloca insertion)
|
||||
@@ -300,6 +302,7 @@ pub const CodeGen = struct {
|
||||
.union_types = std.StringHashMap(UnionInfo).init(allocator),
|
||||
.flags_enum_types = std.StringHashMap(void).init(allocator),
|
||||
.enum_variant_values = std.StringHashMap([]const i64).init(allocator),
|
||||
.enum_backing_types = std.StringHashMap(c.LLVMTypeRef).init(allocator),
|
||||
.builtins = null,
|
||||
.current_function = null,
|
||||
.scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty,
|
||||
@@ -333,6 +336,7 @@ pub const CodeGen = struct {
|
||||
self.tagged_enum_types.deinit();
|
||||
self.union_types.deinit();
|
||||
self.comptime_globals.deinit();
|
||||
self.enum_backing_types.deinit();
|
||||
self.generic_templates.deinit();
|
||||
self.generic_instances.deinit();
|
||||
self.generic_struct_templates.deinit();
|
||||
@@ -384,7 +388,7 @@ pub const CodeGen = struct {
|
||||
.void_type => c.LLVMVoidTypeInContext(self.context),
|
||||
.boolean => c.LLVMInt1TypeInContext(self.context),
|
||||
.string_type, .slice_type => self.getStringStructType(), // slices use same {ptr, i32} layout
|
||||
.enum_type => c.LLVMInt64TypeInContext(self.context),
|
||||
.enum_type => |name| self.getEnumLLVMType(name),
|
||||
.struct_type => |name| if (self.struct_types.get(name)) |info| info.llvm_type else unreachable,
|
||||
.union_type => |name| if (self.tagged_enum_types.get(name)) |info| info.llvm_type else if (self.union_types.get(name)) |info| info.llvm_type else unreachable,
|
||||
.array_type => |info| {
|
||||
@@ -401,6 +405,11 @@ pub const CodeGen = struct {
|
||||
};
|
||||
}
|
||||
|
||||
fn getEnumLLVMType(self: *CodeGen, enum_name: []const u8) c.LLVMTypeRef {
|
||||
if (self.enum_backing_types.get(enum_name)) |llvm_ty| return llvm_ty;
|
||||
return c.LLVMInt64TypeInContext(self.context);
|
||||
}
|
||||
|
||||
fn getAnyStructType(self: *CodeGen) c.LLVMTypeRef {
|
||||
if (self.any_struct_type) |t| return t;
|
||||
var field_types = [_]c.LLVMTypeRef{
|
||||
@@ -558,9 +567,14 @@ pub const CodeGen = struct {
|
||||
_ = c.LLVMBuildStore(self.builder, val, alloca);
|
||||
break :blk c.LLVMBuildPtrToInt(self.builder, alloca, i64_ty, "any_struct");
|
||||
},
|
||||
.enum_type => blk: {
|
||||
// Enum is i32 tag — extend to i64
|
||||
break :blk c.LLVMBuildZExt(self.builder, val, i64_ty, "any_enum");
|
||||
.enum_type => |ename| blk: {
|
||||
// Enum — extend to i64 for Any storage (no-op if already i64)
|
||||
const enum_llvm_ty = self.getEnumLLVMType(ename);
|
||||
const enum_bits = c.LLVMGetIntTypeWidth(enum_llvm_ty);
|
||||
if (enum_bits < 64)
|
||||
break :blk c.LLVMBuildZExt(self.builder, val, i64_ty, "any_enum")
|
||||
else
|
||||
break :blk val;
|
||||
},
|
||||
.union_type => |uname| blk: {
|
||||
// Union — store to alloca, pass pointer as i64
|
||||
@@ -720,6 +734,12 @@ pub const CodeGen = struct {
|
||||
try self.flags_enum_types.put(ed.name, {});
|
||||
}
|
||||
|
||||
// Register backing type if specified
|
||||
if (ed.backing_type) |bt_node| {
|
||||
const bt = self.resolveType(bt_node);
|
||||
try self.enum_backing_types.put(ed.name, self.typeToLLVM(bt));
|
||||
}
|
||||
|
||||
// Compute and store variant values
|
||||
const values = try self.allocator.alloc(i64, ed.variant_names.len);
|
||||
for (ed.variant_names, 0..) |_, i| {
|
||||
@@ -1634,6 +1654,10 @@ pub const CodeGen = struct {
|
||||
} else {
|
||||
const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, ed.name });
|
||||
try self.enum_types.put(qualified, ed.variant_names);
|
||||
if (ed.backing_type) |bt_node| {
|
||||
const bt = self.resolveType(bt_node);
|
||||
try self.enum_backing_types.put(qualified, self.typeToLLVM(bt));
|
||||
}
|
||||
}
|
||||
},
|
||||
.struct_decl => |sd| {
|
||||
@@ -3108,10 +3132,10 @@ pub const CodeGen = struct {
|
||||
|
||||
const name_z = try self.allocator.dupeZ(u8, name);
|
||||
const union_ty = c.LLVMStructCreateNamed(self.context, name_z.ptr);
|
||||
const i64_ty = c.LLVMInt64TypeInContext(self.context);
|
||||
const tag_ty = self.getEnumLLVMType(name);
|
||||
const i8_ty = c.LLVMInt8TypeInContext(self.context);
|
||||
const payload_array_ty = c.LLVMArrayType2(i8_ty, max_payload_size);
|
||||
var union_fields = [2]c.LLVMTypeRef{ i64_ty, payload_array_ty };
|
||||
var union_fields = [2]c.LLVMTypeRef{ tag_ty, payload_array_ty };
|
||||
c.LLVMStructSetBody(union_ty, &union_fields, 2, 0);
|
||||
|
||||
return .{
|
||||
@@ -3145,6 +3169,10 @@ pub const CodeGen = struct {
|
||||
} else {
|
||||
try self.enum_types.put(synthetic_name, inline_ed.variant_names);
|
||||
_ = try self.getAnyTypeId(synthetic_name, .{ .enum_type = synthetic_name });
|
||||
if (inline_ed.backing_type) |bt_node| {
|
||||
const bt = self.resolveType(bt_node);
|
||||
try self.enum_backing_types.put(synthetic_name, self.typeToLLVM(bt));
|
||||
}
|
||||
}
|
||||
type_node.data = .{ .type_expr = .{ .name = synthetic_name } };
|
||||
},
|
||||
@@ -3198,6 +3226,12 @@ pub const CodeGen = struct {
|
||||
}
|
||||
}
|
||||
|
||||
// Register backing type before buildUnionFields (which uses getEnumLLVMType)
|
||||
if (ud.backing_type) |bt_node| {
|
||||
const bt = self.resolveType(bt_node);
|
||||
try self.enum_backing_types.put(ud.name, self.typeToLLVM(bt));
|
||||
}
|
||||
|
||||
const build = try self.buildUnionFields(ud.name, ud.variant_types);
|
||||
|
||||
try self.tagged_enum_types.put(ud.name, .{
|
||||
@@ -3284,11 +3318,11 @@ pub const CodeGen = struct {
|
||||
|
||||
// Alloca union
|
||||
const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_tmp");
|
||||
const i64_ty = c.LLVMInt64TypeInContext(self.context);
|
||||
const tag_ty = self.getEnumLLVMType(resolved_name);
|
||||
|
||||
// Store tag (field 0)
|
||||
const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 0, "tag");
|
||||
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(i64_ty, idx, 0), tag_gep);
|
||||
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_ty, idx, 0), tag_gep);
|
||||
|
||||
// Store payload (field 1) if not void
|
||||
if (el.payload) |payload_node| {
|
||||
@@ -3526,9 +3560,9 @@ pub const CodeGen = struct {
|
||||
|
||||
// Alloca union, store tag
|
||||
const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_lit");
|
||||
const i32_ty = c.LLVMInt32TypeInContext(self.context);
|
||||
const tag_llvm_ty = self.getEnumLLVMType(uname);
|
||||
const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 0, "tag");
|
||||
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(i32_ty, idx, 0), tag_gep);
|
||||
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_llvm_ty, idx, 0), tag_gep);
|
||||
|
||||
// Store struct payload
|
||||
if (variant_ty != .void_type) {
|
||||
@@ -3750,6 +3784,9 @@ pub const CodeGen = struct {
|
||||
}
|
||||
}
|
||||
if (target_ty.isEnum()) {
|
||||
const enum_llvm_ty = self.getEnumLLVMType(target_ty.enum_type);
|
||||
const enum_bits = c.LLVMGetIntTypeWidth(enum_llvm_ty);
|
||||
if (enum_bits < 64) return c.LLVMBuildTrunc(self.builder, i64_val, enum_llvm_ty, "any_to_enum");
|
||||
return i64_val;
|
||||
}
|
||||
if (target_ty.isUnion()) {
|
||||
@@ -3805,12 +3842,14 @@ pub const CodeGen = struct {
|
||||
if (src_ty.isUnion() and target_ty.isInt()) {
|
||||
const uname = src_ty.union_type;
|
||||
if (self.tagged_enum_types.get(uname)) |info| {
|
||||
const tag_llvm_ty = self.getEnumLLVMType(uname);
|
||||
const tag_bits = c.LLVMGetIntTypeWidth(tag_llvm_ty);
|
||||
const tmp = self.buildEntryBlockAlloca(info.llvm_type, "union_cast");
|
||||
_ = c.LLVMBuildStore(self.builder, val, tmp);
|
||||
const tag_ptr = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, tmp, 0, "tag_ptr");
|
||||
const tag_val = c.LLVMBuildLoad2(self.builder, c.LLVMInt32TypeInContext(self.context), tag_ptr, "tag_val");
|
||||
if (target_ty.bitWidth() == 32) return tag_val;
|
||||
if (target_ty.bitWidth() > 32) return c.LLVMBuildSExt(self.builder, tag_val, target_llvm, "tag_ext");
|
||||
const tag_val = c.LLVMBuildLoad2(self.builder, tag_llvm_ty, tag_ptr, "tag_val");
|
||||
if (target_ty.bitWidth() == tag_bits) return tag_val;
|
||||
if (target_ty.bitWidth() > tag_bits) return c.LLVMBuildSExt(self.builder, tag_val, target_llvm, "tag_ext");
|
||||
return c.LLVMBuildTrunc(self.builder, tag_val, target_llvm, "tag_trunc");
|
||||
}
|
||||
}
|
||||
@@ -3844,6 +3883,31 @@ pub const CodeGen = struct {
|
||||
return c.LLVMBuildExtractValue(self.builder, val, 0, "slice_to_ptr");
|
||||
}
|
||||
|
||||
// Enum → int: extend or truncate from backing type to target int
|
||||
if (src_ty.isEnum() and target_ty.isInt()) {
|
||||
const enum_bits = c.LLVMGetIntTypeWidth(self.getEnumLLVMType(src_ty.enum_type));
|
||||
const target_bits = target_ty.bitWidth();
|
||||
if (target_bits > enum_bits) {
|
||||
return c.LLVMBuildZExt(self.builder, val, target_llvm, "enum_to_int");
|
||||
} else if (target_bits < enum_bits) {
|
||||
return c.LLVMBuildTrunc(self.builder, val, target_llvm, "enum_to_int");
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
// Int → enum: extend or truncate from source int to backing type
|
||||
if (src_ty.isInt() and target_ty.isEnum()) {
|
||||
const enum_llvm_ty = self.getEnumLLVMType(target_ty.enum_type);
|
||||
const enum_bits = c.LLVMGetIntTypeWidth(enum_llvm_ty);
|
||||
const src_bits = src_ty.bitWidth();
|
||||
if (enum_bits > src_bits) {
|
||||
return c.LLVMBuildZExt(self.builder, val, enum_llvm_ty, "int_to_enum");
|
||||
} else if (enum_bits < src_bits) {
|
||||
return c.LLVMBuildTrunc(self.builder, val, enum_llvm_ty, "int_to_enum");
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
// *[N]T → [*]T: pointer to array decays to many-pointer (both opaque ptrs, no-op)
|
||||
if (src_ty.isPointer() and target_ty.isManyPointer()) {
|
||||
return val;
|
||||
@@ -4113,7 +4177,7 @@ pub const CodeGen = struct {
|
||||
|
||||
// Read tag (field 0)
|
||||
const tag_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 0, "fv_tag_ptr");
|
||||
const tag_val = c.LLVMBuildLoad2(self.builder, c.LLVMInt64TypeInContext(self.context), tag_ptr, "fv_tag");
|
||||
const tag_val = c.LLVMBuildLoad2(self.builder, self.getEnumLLVMType(val_ty.union_type), tag_ptr, "fv_tag");
|
||||
const payload_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 1, "fv_payload_ptr");
|
||||
|
||||
const n = uinfo.variant_names.len;
|
||||
@@ -4126,9 +4190,10 @@ pub const CodeGen = struct {
|
||||
var phi_vals = std.ArrayList(c.LLVMValueRef).empty;
|
||||
var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty;
|
||||
|
||||
const tag_llvm_ty = self.getEnumLLVMType(val_ty.union_type);
|
||||
for (uinfo.variant_types, 0..) |vty, vi| {
|
||||
const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fv_ucase");
|
||||
c.LLVMAddCase(sw, c.LLVMConstInt(c.LLVMInt64TypeInContext(self.context), @intCast(vi), 0), case_bb);
|
||||
c.LLVMAddCase(sw, c.LLVMConstInt(tag_llvm_ty, @intCast(vi), 0), case_bb);
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, case_bb);
|
||||
|
||||
const any_val = if (vty == .void_type) blk: {
|
||||
@@ -5608,7 +5673,14 @@ pub const CodeGen = struct {
|
||||
const ptr = c.LLVMBuildIntToPtr(self.builder, any_i64, c.LLVMPointerTypeInContext(self.context, 0), "any_to_struct_ptr");
|
||||
break :blk c.LLVMBuildLoad2(self.builder, info.llvm_type, ptr, "any_to_struct");
|
||||
},
|
||||
.enum_type => any_i64,
|
||||
.enum_type => |ename| blk: {
|
||||
const enum_llvm_ty = self.getEnumLLVMType(ename);
|
||||
const enum_bits = c.LLVMGetIntTypeWidth(enum_llvm_ty);
|
||||
if (enum_bits < 64)
|
||||
break :blk c.LLVMBuildTrunc(self.builder, any_i64, enum_llvm_ty, "any_to_enum")
|
||||
else
|
||||
break :blk any_i64;
|
||||
},
|
||||
.union_type => |uname| blk: {
|
||||
const info = self.tagged_enum_types.get(uname) orelse return self.emitErrorFmt("unknown enum '{s}'", .{uname});
|
||||
const ptr = c.LLVMBuildIntToPtr(self.builder, any_i64, c.LLVMPointerTypeInContext(self.context, 0), "any_to_union_ptr");
|
||||
@@ -6018,16 +6090,16 @@ pub const CodeGen = struct {
|
||||
}
|
||||
|
||||
fn genEnumLiteral(self: *CodeGen, variant_name: []const u8, enum_type_name: []const u8) c.LLVMValueRef {
|
||||
const i64_type = c.LLVMInt64TypeInContext(self.context);
|
||||
const variants = self.enum_types.get(enum_type_name) orelse return c.LLVMConstInt(i64_type, 0, 0);
|
||||
const enum_ty = self.getEnumLLVMType(enum_type_name);
|
||||
const variants = self.enum_types.get(enum_type_name) orelse return c.LLVMConstInt(enum_ty, 0, 0);
|
||||
const values = self.enum_variant_values.get(enum_type_name);
|
||||
for (variants, 0..) |v, i| {
|
||||
if (std.mem.eql(u8, v, variant_name)) {
|
||||
const val: u64 = if (values) |vals| @bitCast(vals[i]) else @intCast(i);
|
||||
return c.LLVMConstInt(i64_type, val, 0);
|
||||
return c.LLVMConstInt(enum_ty, val, 0);
|
||||
}
|
||||
}
|
||||
return c.LLVMConstInt(i64_type, 0, 0);
|
||||
return c.LLVMConstInt(enum_ty, 0, 0);
|
||||
}
|
||||
|
||||
fn lookupVariantValue(self: *CodeGen, enum_name: ?[]const u8, variants: ?[]const []const u8, name: []const u8) u64 {
|
||||
@@ -6064,7 +6136,7 @@ pub const CodeGen = struct {
|
||||
const entry = self.named_values.get(match.subject.data.identifier.name).?;
|
||||
const info = self.tagged_enum_types.get(union_name.?).?;
|
||||
const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, 0, "tag");
|
||||
break :blk c.LLVMBuildLoad2(self.builder, c.LLVMInt64TypeInContext(self.context), tag_gep, "tag_val");
|
||||
break :blk c.LLVMBuildLoad2(self.builder, self.getEnumLLVMType(union_name.?), tag_gep, "tag_val");
|
||||
} else try self.genExpr(match.subject);
|
||||
|
||||
const variants: ?[]const []const u8 = if (union_name) |un|
|
||||
@@ -6076,6 +6148,8 @@ pub const CodeGen = struct {
|
||||
|
||||
const function = self.current_function;
|
||||
const i64_type = c.LLVMInt64TypeInContext(self.context);
|
||||
// Enum/union case constants use the backing type; Any dispatch uses i64
|
||||
const case_int_type = if (enum_name) |en| self.getEnumLLVMType(en) else if (union_name) |un| self.getEnumLLVMType(un) else i64_type;
|
||||
const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "match_end");
|
||||
|
||||
// Create case basic blocks
|
||||
@@ -6103,7 +6177,7 @@ pub const CodeGen = struct {
|
||||
const pat = arm.pattern orelse continue; // skip else arm
|
||||
if (pat.data == .enum_literal) {
|
||||
const idx = self.lookupVariantValue(enum_name orelse union_name, variants, pat.data.enum_literal.name);
|
||||
const case_val = c.LLVMConstInt(i64_type, idx, 0);
|
||||
const case_val = c.LLVMConstInt(case_int_type, idx, 0);
|
||||
c.LLVMAddCase(sw, case_val, case_bbs.items[i]);
|
||||
} else if (pat.data == .type_expr) {
|
||||
// Type-match: resolve type name to Any tag value(s)
|
||||
|
||||
Reference in New Issue
Block a user