flags
This commit is contained in:
137
src/codegen.zig
137
src/codegen.zig
@@ -112,6 +112,10 @@ pub const CodeGen = struct {
|
||||
tagged_enum_types: std.StringHashMap(TaggedEnumInfo),
|
||||
// Union registry: maps name to field info + LLVM type (untagged, C-style)
|
||||
union_types: std.StringHashMap(UnionInfo),
|
||||
// Flags enum registry: tracks which enum names are flags
|
||||
flags_enum_types: std.StringHashMap(void),
|
||||
// Enum variant values: maps enum name → resolved i64 values per variant
|
||||
enum_variant_values: std.StringHashMap([]const i64),
|
||||
// Built-in functions (printf, etc.)
|
||||
builtins: ?Builtins,
|
||||
// Current function being generated (for alloca insertion)
|
||||
@@ -294,6 +298,8 @@ pub const CodeGen = struct {
|
||||
.struct_types = std.StringHashMap(StructInfo).init(allocator),
|
||||
.tagged_enum_types = std.StringHashMap(TaggedEnumInfo).init(allocator),
|
||||
.union_types = std.StringHashMap(UnionInfo).init(allocator),
|
||||
.flags_enum_types = std.StringHashMap(void).init(allocator),
|
||||
.enum_variant_values = std.StringHashMap([]const i64).init(allocator),
|
||||
.builtins = null,
|
||||
.current_function = null,
|
||||
.scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty,
|
||||
@@ -709,6 +715,30 @@ pub const CodeGen = struct {
|
||||
// Payload-less enum
|
||||
try self.enum_types.put(ed.name, ed.variant_names);
|
||||
_ = try self.getAnyTypeId(ed.name, .{ .enum_type = ed.name });
|
||||
|
||||
if (ed.is_flags) {
|
||||
try self.flags_enum_types.put(ed.name, {});
|
||||
}
|
||||
|
||||
// Compute and store variant values
|
||||
const values = try self.allocator.alloc(i64, ed.variant_names.len);
|
||||
for (ed.variant_names, 0..) |_, i| {
|
||||
if (ed.variant_values.len > i and ed.variant_values[i] != null) {
|
||||
// Explicit value: evaluate comptime int literal
|
||||
const val_node = ed.variant_values[i].?;
|
||||
values[i] = switch (val_node.data) {
|
||||
.int_literal => |il| il.value,
|
||||
else => @as(i64, @intCast(i)), // fallback
|
||||
};
|
||||
} else if (ed.is_flags) {
|
||||
// Auto power-of-2: 1, 2, 4, 8, ...
|
||||
values[i] = @as(i64, 1) << @intCast(i);
|
||||
} else {
|
||||
// Regular enum: sequential 0, 1, 2, ...
|
||||
values[i] = @intCast(i);
|
||||
}
|
||||
}
|
||||
try self.enum_variant_values.put(ed.name, values);
|
||||
}
|
||||
},
|
||||
.struct_decl => |sd| try self.registerStructType(sd),
|
||||
@@ -3446,6 +3476,23 @@ pub const CodeGen = struct {
|
||||
return c.LLVMBuildGlobalStringPtr(self.builder, str_z.ptr, "str");
|
||||
}
|
||||
|
||||
// Enum literal assigned to enum type: resolve variant value
|
||||
if (node.data == .enum_literal and target_ty.isEnum()) {
|
||||
return self.genEnumLiteral(node.data.enum_literal.name, target_ty.enum_type);
|
||||
}
|
||||
|
||||
// Bitwise op on enum type: recursively generate both sides with enum context
|
||||
if (node.data == .binary_op and (node.data.binary_op.op == .bit_or or node.data.binary_op.op == .bit_and) and target_ty.isEnum()) {
|
||||
const binop = node.data.binary_op;
|
||||
const lhs = try self.genExprAsType(binop.lhs, target_ty);
|
||||
const rhs = try self.genExprAsType(binop.rhs, target_ty);
|
||||
const b = self.builder;
|
||||
return if (binop.op == .bit_or)
|
||||
c.LLVMBuildOr(b, lhs, rhs, "bortmp")
|
||||
else
|
||||
c.LLVMBuildAnd(b, lhs, rhs, "bandtmp");
|
||||
}
|
||||
|
||||
// Enum/union literal assigned to union type: construct tagged enum
|
||||
if (node.data == .enum_literal and target_ty.isUnion()) {
|
||||
const el = node.data.enum_literal;
|
||||
@@ -4203,6 +4250,62 @@ pub const CodeGen = struct {
|
||||
return phi;
|
||||
}
|
||||
|
||||
fn genIsFlags(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
|
||||
if (call_node.args.len != 1) return self.emitError("is_flags expects exactly 1 argument");
|
||||
const ty = self.resolveType(call_node.args[0]);
|
||||
const i1_type = c.LLVMInt1TypeInContext(self.context);
|
||||
if (ty.isEnum()) {
|
||||
const is_flags = self.flags_enum_types.contains(ty.enum_type);
|
||||
return c.LLVMConstInt(i1_type, @intFromBool(is_flags), 0);
|
||||
}
|
||||
return c.LLVMConstInt(i1_type, 0, 0);
|
||||
}
|
||||
|
||||
fn genFieldValueInt(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
|
||||
if (call_node.args.len != 2) return self.emitError("field_value_int expects 2 arguments: field_value_int(T, idx)");
|
||||
const ty = self.resolveType(call_node.args[0]);
|
||||
const i64_type = c.LLVMInt64TypeInContext(self.context);
|
||||
// For non-enum types (e.g. tagged enums compiled via dead code), return the index as value
|
||||
if (!ty.isEnum()) {
|
||||
return try self.genExpr(call_node.args[1]);
|
||||
}
|
||||
const enum_name = ty.enum_type;
|
||||
const values = self.enum_variant_values.get(enum_name);
|
||||
const variants = self.enum_types.get(enum_name) orelse return try self.genExpr(call_node.args[1]);
|
||||
const n = variants.len;
|
||||
|
||||
const idx = try self.genExpr(call_node.args[1]);
|
||||
const function = self.current_function;
|
||||
const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_merge");
|
||||
const default_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_default");
|
||||
const sw = c.LLVMBuildSwitch(self.builder, idx, default_bb, @intCast(n));
|
||||
|
||||
var phi_vals = std.ArrayList(c.LLVMValueRef).empty;
|
||||
var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty;
|
||||
|
||||
for (0..n) |i| {
|
||||
const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_case");
|
||||
c.LLVMAddCase(sw, c.LLVMConstInt(i64_type, i, 0), case_bb);
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, case_bb);
|
||||
const val: u64 = if (values) |vals| @bitCast(vals[i]) else i;
|
||||
try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, val, 0));
|
||||
try phi_bbs.append(self.allocator, case_bb);
|
||||
_ = c.LLVMBuildBr(self.builder, merge_bb);
|
||||
}
|
||||
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, default_bb);
|
||||
try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, 0, 0));
|
||||
try phi_bbs.append(self.allocator, default_bb);
|
||||
_ = c.LLVMBuildBr(self.builder, merge_bb);
|
||||
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, merge_bb);
|
||||
const vals_slice = try phi_vals.toOwnedSlice(self.allocator);
|
||||
const bbs_slice = try phi_bbs.toOwnedSlice(self.allocator);
|
||||
const phi = c.LLVMBuildPhi(self.builder, i64_type, "fvi_result");
|
||||
c.LLVMAddIncoming(phi, vals_slice.ptr, bbs_slice.ptr, @intCast(vals_slice.len));
|
||||
return phi;
|
||||
}
|
||||
|
||||
fn genCast(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
|
||||
if (call_node.args.len != 2) return self.emitError("cast expects: cast(Type) expr");
|
||||
const target_ty = self.resolveType(call_node.args[0]);
|
||||
@@ -4591,6 +4694,8 @@ pub const CodeGen = struct {
|
||||
.lte => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOLE, lhs, rhs, "letmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntULE, lhs, rhs, "letmp") else c.LLVMBuildICmp(b, c.LLVMIntSLE, lhs, rhs, "letmp"),
|
||||
.gt => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOGT, lhs, rhs, "gttmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntUGT, lhs, rhs, "gttmp") else c.LLVMBuildICmp(b, c.LLVMIntSGT, lhs, rhs, "gttmp"),
|
||||
.gte => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOGE, lhs, rhs, "getmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntUGE, lhs, rhs, "getmp") else c.LLVMBuildICmp(b, c.LLVMIntSGE, lhs, rhs, "getmp"),
|
||||
.bit_and => c.LLVMBuildAnd(b, lhs, rhs, "bandtmp"),
|
||||
.bit_or => c.LLVMBuildOr(b, lhs, rhs, "bortmp"),
|
||||
.and_op, .or_op => unreachable,
|
||||
};
|
||||
}
|
||||
@@ -5198,6 +5303,7 @@ pub const CodeGen = struct {
|
||||
try self.instantiateGeneric(fd, bindings, mangled);
|
||||
|
||||
// Generate arguments with type conversion to match parameter types
|
||||
const saved_call_bindings = self.type_param_bindings;
|
||||
self.type_param_bindings = bindings;
|
||||
var arg_vals = std.ArrayList(c.LLVMValueRef).empty;
|
||||
for (call_node.args, 0..) |arg, i| {
|
||||
@@ -5208,7 +5314,7 @@ pub const CodeGen = struct {
|
||||
try arg_vals.append(self.allocator, try self.genExpr(arg));
|
||||
}
|
||||
}
|
||||
self.type_param_bindings = null;
|
||||
self.type_param_bindings = saved_call_bindings;
|
||||
const args_slice = try arg_vals.toOwnedSlice(self.allocator);
|
||||
|
||||
const fn_type = c.LLVMGlobalGetValueType(callee_fn);
|
||||
@@ -5584,9 +5690,10 @@ pub const CodeGen = struct {
|
||||
self.scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty;
|
||||
self.defer_stack = std.ArrayList(std.ArrayList(*Node)).empty;
|
||||
|
||||
// Set type param bindings
|
||||
// Set type param bindings (save/restore to support nested generic instantiation)
|
||||
const saved_bindings = self.type_param_bindings;
|
||||
self.type_param_bindings = bindings;
|
||||
defer self.type_param_bindings = null;
|
||||
defer self.type_param_bindings = saved_bindings;
|
||||
|
||||
// Build the specialized function type
|
||||
const fn_type = try self.buildFnType(fd.params, fd.return_type, mangled);
|
||||
@@ -5913,18 +6020,28 @@ pub const CodeGen = struct {
|
||||
fn genEnumLiteral(self: *CodeGen, variant_name: []const u8, enum_type_name: []const u8) c.LLVMValueRef {
|
||||
const i64_type = c.LLVMInt64TypeInContext(self.context);
|
||||
const variants = self.enum_types.get(enum_type_name) orelse return c.LLVMConstInt(i64_type, 0, 0);
|
||||
const values = self.enum_variant_values.get(enum_type_name);
|
||||
for (variants, 0..) |v, i| {
|
||||
if (std.mem.eql(u8, v, variant_name)) {
|
||||
return c.LLVMConstInt(i64_type, @intCast(i), 0);
|
||||
const val: u64 = if (values) |vals| @bitCast(vals[i]) else @intCast(i);
|
||||
return c.LLVMConstInt(i64_type, val, 0);
|
||||
}
|
||||
}
|
||||
return c.LLVMConstInt(i64_type, 0, 0);
|
||||
}
|
||||
|
||||
fn lookupVariantIndex(variants: ?[]const []const u8, name: []const u8) u64 {
|
||||
fn lookupVariantValue(self: *CodeGen, enum_name: ?[]const u8, variants: ?[]const []const u8, name: []const u8) u64 {
|
||||
if (variants) |vs| {
|
||||
for (vs, 0..) |v, i| {
|
||||
if (std.mem.eql(u8, v, name)) return i;
|
||||
if (std.mem.eql(u8, v, name)) {
|
||||
// Use resolved values if available (flags enums, explicit values)
|
||||
if (enum_name) |en| {
|
||||
if (self.enum_variant_values.get(en)) |vals| {
|
||||
return @bitCast(vals[i]);
|
||||
}
|
||||
}
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
@@ -5985,7 +6102,7 @@ pub const CodeGen = struct {
|
||||
for (match.arms, 0..) |arm, i| {
|
||||
const pat = arm.pattern orelse continue; // skip else arm
|
||||
if (pat.data == .enum_literal) {
|
||||
const idx = lookupVariantIndex(variants, pat.data.enum_literal.name);
|
||||
const idx = self.lookupVariantValue(enum_name orelse union_name, variants, pat.data.enum_literal.name);
|
||||
const case_val = c.LLVMConstInt(i64_type, idx, 0);
|
||||
c.LLVMAddCase(sw, case_val, case_bbs.items[i]);
|
||||
} else if (pat.data == .type_expr) {
|
||||
@@ -6235,6 +6352,8 @@ pub const CodeGen = struct {
|
||||
if (std.mem.eql(u8, base, "field_count")) return self.genFieldCount(call_node);
|
||||
if (std.mem.eql(u8, base, "field_name")) return self.genFieldName(call_node);
|
||||
if (std.mem.eql(u8, base, "field_value")) return self.genFieldValue(call_node);
|
||||
if (std.mem.eql(u8, base, "is_flags")) return self.genIsFlags(call_node);
|
||||
if (std.mem.eql(u8, base, "field_value_int")) return self.genFieldValueInt(call_node);
|
||||
return self.emitErrorFmt("unknown builtin function '{s}'", .{name});
|
||||
}
|
||||
|
||||
@@ -6466,6 +6585,10 @@ pub const CodeGen = struct {
|
||||
if (std.mem.eql(u8, base_name, "field_name")) return .string_type;
|
||||
// Built-in: field_value returns Any
|
||||
if (std.mem.eql(u8, base_name, "field_value")) return .{ .any_type = {} };
|
||||
// Built-in: is_flags returns bool
|
||||
if (std.mem.eql(u8, base_name, "is_flags")) return .boolean;
|
||||
// Built-in: field_value_int returns s64
|
||||
if (std.mem.eql(u8, base_name, "field_value_int")) return Type.s(64);
|
||||
// Built-in: cast returns the target type (first arg)
|
||||
if (std.mem.eql(u8, base_name, "cast")) {
|
||||
if (call_node.args.len > 0) return self.resolveType(call_node.args[0]);
|
||||
|
||||
Reference in New Issue
Block a user