This commit is contained in:
agra
2026-02-14 14:03:16 +02:00
parent 025b790411
commit fe7efeadb0
10 changed files with 320 additions and 10 deletions

View File

@@ -112,6 +112,10 @@ pub const CodeGen = struct {
tagged_enum_types: std.StringHashMap(TaggedEnumInfo),
// Union registry: maps name to field info + LLVM type (untagged, C-style)
union_types: std.StringHashMap(UnionInfo),
// Flags enum registry: tracks which enum names are flags
flags_enum_types: std.StringHashMap(void),
// Enum variant values: maps enum name → resolved i64 values per variant
enum_variant_values: std.StringHashMap([]const i64),
// Built-in functions (printf, etc.)
builtins: ?Builtins,
// Current function being generated (for alloca insertion)
@@ -294,6 +298,8 @@ pub const CodeGen = struct {
.struct_types = std.StringHashMap(StructInfo).init(allocator),
.tagged_enum_types = std.StringHashMap(TaggedEnumInfo).init(allocator),
.union_types = std.StringHashMap(UnionInfo).init(allocator),
.flags_enum_types = std.StringHashMap(void).init(allocator),
.enum_variant_values = std.StringHashMap([]const i64).init(allocator),
.builtins = null,
.current_function = null,
.scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty,
@@ -709,6 +715,30 @@ pub const CodeGen = struct {
// Payload-less enum
try self.enum_types.put(ed.name, ed.variant_names);
_ = try self.getAnyTypeId(ed.name, .{ .enum_type = ed.name });
if (ed.is_flags) {
try self.flags_enum_types.put(ed.name, {});
}
// Compute and store variant values
const values = try self.allocator.alloc(i64, ed.variant_names.len);
for (ed.variant_names, 0..) |_, i| {
if (ed.variant_values.len > i and ed.variant_values[i] != null) {
// Explicit value: evaluate comptime int literal
const val_node = ed.variant_values[i].?;
values[i] = switch (val_node.data) {
.int_literal => |il| il.value,
else => @as(i64, @intCast(i)), // fallback
};
} else if (ed.is_flags) {
// Auto power-of-2: 1, 2, 4, 8, ...
values[i] = @as(i64, 1) << @intCast(i);
} else {
// Regular enum: sequential 0, 1, 2, ...
values[i] = @intCast(i);
}
}
try self.enum_variant_values.put(ed.name, values);
}
},
.struct_decl => |sd| try self.registerStructType(sd),
@@ -3446,6 +3476,23 @@ pub const CodeGen = struct {
return c.LLVMBuildGlobalStringPtr(self.builder, str_z.ptr, "str");
}
// Enum literal assigned to enum type: resolve variant value
if (node.data == .enum_literal and target_ty.isEnum()) {
return self.genEnumLiteral(node.data.enum_literal.name, target_ty.enum_type);
}
// Bitwise op on enum type: recursively generate both sides with enum context
if (node.data == .binary_op and (node.data.binary_op.op == .bit_or or node.data.binary_op.op == .bit_and) and target_ty.isEnum()) {
const binop = node.data.binary_op;
const lhs = try self.genExprAsType(binop.lhs, target_ty);
const rhs = try self.genExprAsType(binop.rhs, target_ty);
const b = self.builder;
return if (binop.op == .bit_or)
c.LLVMBuildOr(b, lhs, rhs, "bortmp")
else
c.LLVMBuildAnd(b, lhs, rhs, "bandtmp");
}
// Enum/union literal assigned to union type: construct tagged enum
if (node.data == .enum_literal and target_ty.isUnion()) {
const el = node.data.enum_literal;
@@ -4203,6 +4250,62 @@ pub const CodeGen = struct {
return phi;
}
fn genIsFlags(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
if (call_node.args.len != 1) return self.emitError("is_flags expects exactly 1 argument");
const ty = self.resolveType(call_node.args[0]);
const i1_type = c.LLVMInt1TypeInContext(self.context);
if (ty.isEnum()) {
const is_flags = self.flags_enum_types.contains(ty.enum_type);
return c.LLVMConstInt(i1_type, @intFromBool(is_flags), 0);
}
return c.LLVMConstInt(i1_type, 0, 0);
}
fn genFieldValueInt(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
if (call_node.args.len != 2) return self.emitError("field_value_int expects 2 arguments: field_value_int(T, idx)");
const ty = self.resolveType(call_node.args[0]);
const i64_type = c.LLVMInt64TypeInContext(self.context);
// For non-enum types (e.g. tagged enums compiled via dead code), return the index as value
if (!ty.isEnum()) {
return try self.genExpr(call_node.args[1]);
}
const enum_name = ty.enum_type;
const values = self.enum_variant_values.get(enum_name);
const variants = self.enum_types.get(enum_name) orelse return try self.genExpr(call_node.args[1]);
const n = variants.len;
const idx = try self.genExpr(call_node.args[1]);
const function = self.current_function;
const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_merge");
const default_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_default");
const sw = c.LLVMBuildSwitch(self.builder, idx, default_bb, @intCast(n));
var phi_vals = std.ArrayList(c.LLVMValueRef).empty;
var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty;
for (0..n) |i| {
const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_case");
c.LLVMAddCase(sw, c.LLVMConstInt(i64_type, i, 0), case_bb);
c.LLVMPositionBuilderAtEnd(self.builder, case_bb);
const val: u64 = if (values) |vals| @bitCast(vals[i]) else i;
try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, val, 0));
try phi_bbs.append(self.allocator, case_bb);
_ = c.LLVMBuildBr(self.builder, merge_bb);
}
c.LLVMPositionBuilderAtEnd(self.builder, default_bb);
try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, 0, 0));
try phi_bbs.append(self.allocator, default_bb);
_ = c.LLVMBuildBr(self.builder, merge_bb);
c.LLVMPositionBuilderAtEnd(self.builder, merge_bb);
const vals_slice = try phi_vals.toOwnedSlice(self.allocator);
const bbs_slice = try phi_bbs.toOwnedSlice(self.allocator);
const phi = c.LLVMBuildPhi(self.builder, i64_type, "fvi_result");
c.LLVMAddIncoming(phi, vals_slice.ptr, bbs_slice.ptr, @intCast(vals_slice.len));
return phi;
}
fn genCast(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
if (call_node.args.len != 2) return self.emitError("cast expects: cast(Type) expr");
const target_ty = self.resolveType(call_node.args[0]);
@@ -4591,6 +4694,8 @@ pub const CodeGen = struct {
.lte => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOLE, lhs, rhs, "letmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntULE, lhs, rhs, "letmp") else c.LLVMBuildICmp(b, c.LLVMIntSLE, lhs, rhs, "letmp"),
.gt => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOGT, lhs, rhs, "gttmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntUGT, lhs, rhs, "gttmp") else c.LLVMBuildICmp(b, c.LLVMIntSGT, lhs, rhs, "gttmp"),
.gte => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOGE, lhs, rhs, "getmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntUGE, lhs, rhs, "getmp") else c.LLVMBuildICmp(b, c.LLVMIntSGE, lhs, rhs, "getmp"),
.bit_and => c.LLVMBuildAnd(b, lhs, rhs, "bandtmp"),
.bit_or => c.LLVMBuildOr(b, lhs, rhs, "bortmp"),
.and_op, .or_op => unreachable,
};
}
@@ -5198,6 +5303,7 @@ pub const CodeGen = struct {
try self.instantiateGeneric(fd, bindings, mangled);
// Generate arguments with type conversion to match parameter types
const saved_call_bindings = self.type_param_bindings;
self.type_param_bindings = bindings;
var arg_vals = std.ArrayList(c.LLVMValueRef).empty;
for (call_node.args, 0..) |arg, i| {
@@ -5208,7 +5314,7 @@ pub const CodeGen = struct {
try arg_vals.append(self.allocator, try self.genExpr(arg));
}
}
self.type_param_bindings = null;
self.type_param_bindings = saved_call_bindings;
const args_slice = try arg_vals.toOwnedSlice(self.allocator);
const fn_type = c.LLVMGlobalGetValueType(callee_fn);
@@ -5584,9 +5690,10 @@ pub const CodeGen = struct {
self.scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty;
self.defer_stack = std.ArrayList(std.ArrayList(*Node)).empty;
// Set type param bindings
// Set type param bindings (save/restore to support nested generic instantiation)
const saved_bindings = self.type_param_bindings;
self.type_param_bindings = bindings;
defer self.type_param_bindings = null;
defer self.type_param_bindings = saved_bindings;
// Build the specialized function type
const fn_type = try self.buildFnType(fd.params, fd.return_type, mangled);
@@ -5913,18 +6020,28 @@ pub const CodeGen = struct {
fn genEnumLiteral(self: *CodeGen, variant_name: []const u8, enum_type_name: []const u8) c.LLVMValueRef {
const i64_type = c.LLVMInt64TypeInContext(self.context);
const variants = self.enum_types.get(enum_type_name) orelse return c.LLVMConstInt(i64_type, 0, 0);
const values = self.enum_variant_values.get(enum_type_name);
for (variants, 0..) |v, i| {
if (std.mem.eql(u8, v, variant_name)) {
return c.LLVMConstInt(i64_type, @intCast(i), 0);
const val: u64 = if (values) |vals| @bitCast(vals[i]) else @intCast(i);
return c.LLVMConstInt(i64_type, val, 0);
}
}
return c.LLVMConstInt(i64_type, 0, 0);
}
fn lookupVariantIndex(variants: ?[]const []const u8, name: []const u8) u64 {
fn lookupVariantValue(self: *CodeGen, enum_name: ?[]const u8, variants: ?[]const []const u8, name: []const u8) u64 {
if (variants) |vs| {
for (vs, 0..) |v, i| {
if (std.mem.eql(u8, v, name)) return i;
if (std.mem.eql(u8, v, name)) {
// Use resolved values if available (flags enums, explicit values)
if (enum_name) |en| {
if (self.enum_variant_values.get(en)) |vals| {
return @bitCast(vals[i]);
}
}
return i;
}
}
}
return 0;
@@ -5985,7 +6102,7 @@ pub const CodeGen = struct {
for (match.arms, 0..) |arm, i| {
const pat = arm.pattern orelse continue; // skip else arm
if (pat.data == .enum_literal) {
const idx = lookupVariantIndex(variants, pat.data.enum_literal.name);
const idx = self.lookupVariantValue(enum_name orelse union_name, variants, pat.data.enum_literal.name);
const case_val = c.LLVMConstInt(i64_type, idx, 0);
c.LLVMAddCase(sw, case_val, case_bbs.items[i]);
} else if (pat.data == .type_expr) {
@@ -6235,6 +6352,8 @@ pub const CodeGen = struct {
if (std.mem.eql(u8, base, "field_count")) return self.genFieldCount(call_node);
if (std.mem.eql(u8, base, "field_name")) return self.genFieldName(call_node);
if (std.mem.eql(u8, base, "field_value")) return self.genFieldValue(call_node);
if (std.mem.eql(u8, base, "is_flags")) return self.genIsFlags(call_node);
if (std.mem.eql(u8, base, "field_value_int")) return self.genFieldValueInt(call_node);
return self.emitErrorFmt("unknown builtin function '{s}'", .{name});
}
@@ -6466,6 +6585,10 @@ pub const CodeGen = struct {
if (std.mem.eql(u8, base_name, "field_name")) return .string_type;
// Built-in: field_value returns Any
if (std.mem.eql(u8, base_name, "field_value")) return .{ .any_type = {} };
// Built-in: is_flags returns bool
if (std.mem.eql(u8, base_name, "is_flags")) return .boolean;
// Built-in: field_value_int returns s64
if (std.mem.eql(u8, base_name, "field_value_int")) return Type.s(64);
// Built-in: cast returns the target type (first arg)
if (std.mem.eql(u8, base_name, "cast")) {
if (call_node.args.len > 0) return self.resolveType(call_node.args[0]);