This commit is contained in:
agra
2026-02-11 01:43:30 +02:00
parent 25e1372731
commit 89fc6427c4
6 changed files with 236 additions and 27 deletions

View File

@@ -4,11 +4,15 @@ const c = llvm.c;
pub const Builtins = struct {
printf_fn: c.LLVMValueRef,
calloc_fn: c.LLVMValueRef,
malloc_fn: c.LLVMValueRef,
free_fn: c.LLVMValueRef,
memcpy_fn: c.LLVMValueRef,
pub fn init(module: c.LLVMModuleRef, ctx: c.LLVMContextRef) Builtins {
const ptr_type = c.LLVMPointerTypeInContext(ctx, 0);
const i64_type = c.LLVMInt64TypeInContext(ctx);
const i32_type = c.LLVMInt32TypeInContext(ctx);
const void_type = c.LLVMVoidTypeInContext(ctx);
// Declare: int printf(const char*, ...)
var printf_params = [_]c.LLVMTypeRef{ptr_type};
@@ -20,6 +24,27 @@ pub const Builtins = struct {
const calloc_type = c.LLVMFunctionType(ptr_type, &calloc_params, 2, 0);
const calloc_fn = c.LLVMAddFunction(module, "calloc", calloc_type);
return .{ .printf_fn = printf_fn, .calloc_fn = calloc_fn };
// Declare: void* malloc(size_t size)
var malloc_params = [_]c.LLVMTypeRef{i64_type};
const malloc_type = c.LLVMFunctionType(ptr_type, &malloc_params, 1, 0);
const malloc_fn = c.LLVMAddFunction(module, "malloc", malloc_type);
// Declare: void free(void* ptr)
var free_params = [_]c.LLVMTypeRef{ptr_type};
const free_type = c.LLVMFunctionType(void_type, &free_params, 1, 0);
const free_fn = c.LLVMAddFunction(module, "free", free_type);
// Declare: void* memcpy(void* dst, const void* src, size_t n)
var memcpy_params = [_]c.LLVMTypeRef{ ptr_type, ptr_type, i64_type };
const memcpy_type = c.LLVMFunctionType(ptr_type, &memcpy_params, 3, 0);
const memcpy_fn = c.LLVMAddFunction(module, "memcpy", memcpy_type);
return .{
.printf_fn = printf_fn,
.calloc_fn = calloc_fn,
.malloc_fn = malloc_fn,
.free_fn = free_fn,
.memcpy_fn = memcpy_fn,
};
}
};

View File

@@ -139,6 +139,9 @@ pub const CodeGen = struct {
field_defaults: []const ?*Node,
llvm_type: c.LLVMTypeRef,
display_name: ?[]const u8 = null, // pretty name for generic instances
type_param_names: []const []const u8 = &.{}, // original type param names (e.g. ["T"])
type_param_types: []const Type = &.{}, // resolved types (e.g. [s32])
template_name: ?[]const u8 = null, // original template name (e.g. "List")
};
const UnionInfo = struct {
@@ -320,7 +323,15 @@ pub const CodeGen = struct {
/// works in any_to_string even before buildAnyValue is called for this type.
fn preRegisterAnyType(self: *CodeGen, sx_type: Type) !void {
switch (sx_type) {
.struct_type => |name| _ = try self.getAnyTypeId(name, sx_type),
.struct_type => |name| {
_ = try self.getAnyTypeId(name, sx_type);
// Recursively register struct field types
if (self.struct_types.get(name)) |info| {
for (info.field_types) |ft| {
try self.preRegisterAnyType(ft);
}
}
},
.enum_type => |name| _ = try self.getAnyTypeId(name, sx_type),
.union_type => |name| _ = try self.getAnyTypeId(name, sx_type),
.vector_type => |info| _ = try self.getAnyTypeId(try std.fmt.allocPrint(self.allocator, "vec[{d}]{s}", .{ info.length, info.element_name }), sx_type),
@@ -988,12 +999,28 @@ pub const CodeGen = struct {
try display_buf.append(self.allocator, ')');
const display_name = try display_buf.toOwnedSlice(self.allocator);
// Collect type param names and resolved types for later extraction
var tp_names = std.ArrayList([]const u8).empty;
var tp_types = std.ArrayList(Type).empty;
for (sd.type_params) |tp| {
const constraint_name = if (tp.constraint.data == .type_expr) tp.constraint.data.type_expr.name else "";
if (std.mem.eql(u8, constraint_name, "Type")) {
if (type_bindings.get(tp.name)) |ty| {
try tp_names.append(self.allocator, tp.name);
try tp_types.append(self.allocator, ty);
}
}
}
try self.struct_types.put(mangled_name, .{
.field_names = sd.field_names,
.field_types = try field_sx_types.toOwnedSlice(self.allocator),
.field_defaults = resolved_defaults,
.llvm_type = struct_ty,
.display_name = display_name,
.type_param_names = try tp_names.toOwnedSlice(self.allocator),
.type_param_types = try tp_types.toOwnedSlice(self.allocator),
.template_name = template_name,
});
_ = try self.getAnyTypeId(mangled_name, .{ .struct_type = mangled_name });
@@ -2118,7 +2145,21 @@ pub const CodeGen = struct {
c.LLVMPointerTypeInContext(self.context, 0), entry.ptr, "ptr_load");
const gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, loaded_ptr, @intCast(fi), "pfield_ptr");
const rhs = try self.genExprAsType(asgn.value, field_ty);
_ = c.LLVMBuildStore(self.builder, rhs, gep);
if (asgn.op == .assign) {
_ = c.LLVMBuildStore(self.builder, rhs, gep);
} else {
const field_llvm_ty = self.typeToLLVM(field_ty);
const cur = c.LLVMBuildLoad2(self.builder, field_llvm_ty, gep, "pcur");
const store_val = switch (asgn.op) {
.add_assign => if (field_ty.isFloat()) c.LLVMBuildFAdd(self.builder, cur, rhs, "paddtmp") else c.LLVMBuildAdd(self.builder, cur, rhs, "paddtmp"),
.sub_assign => if (field_ty.isFloat()) c.LLVMBuildFSub(self.builder, cur, rhs, "psubtmp") else c.LLVMBuildSub(self.builder, cur, rhs, "psubtmp"),
.mul_assign => if (field_ty.isFloat()) c.LLVMBuildFMul(self.builder, cur, rhs, "pmultmp") else c.LLVMBuildMul(self.builder, cur, rhs, "pmultmp"),
.div_assign => if (field_ty.isFloat()) c.LLVMBuildFDiv(self.builder, cur, rhs, "pdivtmp") else if (field_ty.isUnsigned()) c.LLVMBuildUDiv(self.builder, cur, rhs, "pdivtmp") else c.LLVMBuildSDiv(self.builder, cur, rhs, "pdivtmp"),
.mod_assign => if (field_ty.isFloat()) c.LLVMBuildFRem(self.builder, cur, rhs, "pmodtmp") else if (field_ty.isUnsigned()) c.LLVMBuildURem(self.builder, cur, rhs, "pmodtmp") else c.LLVMBuildSRem(self.builder, cur, rhs, "pmodtmp"),
.assign => unreachable,
};
_ = c.LLVMBuildStore(self.builder, store_val, gep);
}
return null;
}
return self.emitError("field assignment through pointer requires a struct pointee");
@@ -2985,6 +3026,25 @@ pub const CodeGen = struct {
}
}
// Implicit address-of: passing T where *T is expected → auto &
if (target_ty.isPointer()) {
const src_ty = self.inferType(node);
const pointee_name = target_ty.pointer_type.pointee_name;
const src_matches = if (src_ty.isStruct())
std.mem.eql(u8, src_ty.struct_type, pointee_name) or
(if (self.type_aliases.get(src_ty.struct_type)) |alias| std.mem.eql(u8, alias, pointee_name) else false) or
(if (self.type_aliases.get(pointee_name)) |alias| std.mem.eql(u8, alias, src_ty.struct_type) else false)
else
false;
if (src_matches) {
if (node.data == .identifier) {
if (self.named_values.get(node.data.identifier.name)) |entry| {
return entry.ptr;
}
}
}
}
const val = try self.genExpr(node);
const src_ty = self.inferType(node);
@@ -3496,6 +3556,37 @@ pub const CodeGen = struct {
return self.buildStringSliceRT(ptr, size_val);
}
fn genMalloc(self: *CodeGen, args: []const *Node) !c.LLVMValueRef {
if (args.len != 1) return self.emitError("malloc expects exactly 1 argument: malloc(size)");
const builtins = self.builtins orelse return self.emitError("builtins not available");
const size_val = try self.genExpr(args[0]);
const fn_ty = c.LLVMGlobalGetValueType(builtins.malloc_fn);
var call_args = [_]c.LLVMValueRef{size_val};
return c.LLVMBuildCall2(self.builder, fn_ty, builtins.malloc_fn, &call_args, 1, "malloc_ptr");
}
fn genFree(self: *CodeGen, args: []const *Node) !c.LLVMValueRef {
if (args.len != 1) return self.emitError("free expects exactly 1 argument: free(ptr)");
const builtins = self.builtins orelse return self.emitError("builtins not available");
const ptr_val = try self.genExpr(args[0]);
const fn_ty = c.LLVMGlobalGetValueType(builtins.free_fn);
var call_args = [_]c.LLVMValueRef{ptr_val};
_ = c.LLVMBuildCall2(self.builder, fn_ty, builtins.free_fn, &call_args, 1, "");
return null;
}
fn genMemcpy(self: *CodeGen, args: []const *Node) !c.LLVMValueRef {
if (args.len != 3) return self.emitError("memcpy expects 3 arguments: memcpy(dst, src, size)");
const builtins = self.builtins orelse return self.emitError("builtins not available");
const dst = try self.genExpr(args[0]);
const src = try self.genExpr(args[1]);
const size_val = try self.genExpr(args[2]);
const fn_ty = c.LLVMGlobalGetValueType(builtins.memcpy_fn);
var call_args = [_]c.LLVMValueRef{ dst, src, size_val };
_ = c.LLVMBuildCall2(self.builder, fn_ty, builtins.memcpy_fn, &call_args, 3, "");
return null;
}
fn genVectorExtract(self: *CodeGen, vec_val: c.LLVMValueRef, field: []const u8) !c.LLVMValueRef {
if (field.len == 1) {
const idx_val = componentToIndex(field[0]) orelse return self.emitErrorFmt("invalid vector component '{c}'", .{field[0]});
@@ -4034,6 +4125,15 @@ pub const CodeGen = struct {
if (std.mem.eql(u8, callee_name, "cast")) {
return self.genCast(call_node);
}
if (std.mem.eql(u8, callee_name, "malloc")) {
return self.genMalloc(call_node.args);
}
if (std.mem.eql(u8, callee_name, "free")) {
return self.genFree(call_node.args);
}
if (std.mem.eql(u8, callee_name, "memcpy")) {
return self.genMemcpy(call_node.args);
}
const name_z = try self.allocator.dupeZ(u8, callee_name);
var callee_fn = c.LLVMGetNamedFunction(self.module, name_z.ptr);
@@ -4238,25 +4338,62 @@ pub const CodeGen = struct {
var bindings = std.StringHashMap(Type).init(self.allocator);
for (fd.params, 0..) |param, i| {
if (param.is_comptime) continue;
// Direct type param: (a: $T) or (a: T)
// Direct type param: (a: $T) introduces/widens, (a: T) only binds if not yet bound
if (param.type_expr.data == .type_expr) {
const type_name = param.type_expr.data.type_expr.name;
const is_introducing = param.type_expr.data.type_expr.is_generic;
// Check if this type name is a type parameter
for (fd.type_params) |tp| {
if (std.mem.eql(u8, tp.name, type_name)) {
if (i < call_node.args.len) {
const arg_ty = self.inferType(call_node.args[i]);
if (bindings.get(type_name)) |existing| {
// Widen to the broader type to avoid data loss
try bindings.put(type_name, Type.widen(existing, arg_ty));
} else {
try bindings.put(type_name, arg_ty);
if (is_introducing) {
const arg_ty = self.inferType(call_node.args[i]);
if (bindings.get(type_name)) |existing| {
try bindings.put(type_name, Type.widen(existing, arg_ty));
} else {
try bindings.put(type_name, arg_ty);
}
} else if (!bindings.contains(type_name)) {
// Plain T reference with no prior binding — infer from arg
try bindings.put(type_name, self.inferType(call_node.args[i]));
}
}
break;
}
}
}
// Pointer to parameterized type: (p: *Foo($T)) — extract T from concrete struct
if (param.type_expr.data == .pointer_type_expr) {
const pointee = param.type_expr.data.pointer_type_expr.pointee_type;
if (pointee.data == .parameterized_type_expr) {
const pte = pointee.data.parameterized_type_expr;
if (i < call_node.args.len) {
const arg_ty = self.inferType(call_node.args[i]);
// arg should be *StructName — get the struct's stored type param bindings
const struct_name = if (arg_ty.isPointer())
arg_ty.pointer_type.pointee_name
else if (arg_ty.isStruct())
arg_ty.struct_type
else
"";
if (self.struct_types.get(struct_name)) |info| {
if (info.template_name) |tmpl_name| {
if (std.mem.eql(u8, tmpl_name, pte.name)) {
// Match generic args against stored type param bindings
for (pte.args, 0..) |arg, ai| {
if (arg.data == .type_expr and arg.data.type_expr.is_generic) {
const gen_name = arg.data.type_expr.name;
if (ai < info.type_param_types.len) {
try bindings.put(gen_name, info.type_param_types[ai]);
}
}
}
}
}
}
}
}
}
// Slice type param: (items: []$T) — infer T from array or slice element type
if (param.type_expr.data == .slice_type_expr) {
const elem_node = param.type_expr.data.slice_type_expr.element_type;
@@ -5330,6 +5467,9 @@ pub const CodeGen = struct {
if (std.mem.eql(u8, base, "size_of")) return self.genSizeOf(call_node);
if (std.mem.eql(u8, base, "cast")) return self.genCast(call_node);
if (std.mem.eql(u8, base, "alloc")) return self.genAlloc(call_node.args);
if (std.mem.eql(u8, base, "malloc")) return self.genMalloc(call_node.args);
if (std.mem.eql(u8, base, "free")) return self.genFree(call_node.args);
if (std.mem.eql(u8, base, "memcpy")) return self.genMemcpy(call_node.args);
if (std.mem.eql(u8, base, "type_of")) return self.genTypeOf(call_node);
if (std.mem.eql(u8, base, "type_name")) return self.genTypeName(call_node);
if (std.mem.eql(u8, base, "field_count")) return self.genFieldCount(call_node);
@@ -5583,6 +5723,9 @@ pub const CodeGen = struct {
}
// Built-in: alloc returns string
if (std.mem.eql(u8, base_name, "alloc")) return .string_type;
if (std.mem.eql(u8, base_name, "malloc")) return .{ .pointer_type = .{ .pointee_name = "void" } };
if (std.mem.eql(u8, base_name, "free")) return .void_type;
if (std.mem.eql(u8, base_name, "memcpy")) return .void_type;
// Check generic templates — infer return type from widened bindings
const template = self.generic_templates.get(callee_name) orelse blk: {
// Intra-namespace fallback

View File

@@ -547,6 +547,23 @@ pub const Parser = struct {
return try params.toOwnedSlice(self.allocator);
}
/// Recursively find all generic type names ($T) in a type expression tree.
fn collectGenericNames(node: *Node, list: *std.ArrayList([]const u8), allocator: std.mem.Allocator) void {
switch (node.data) {
.type_expr => |te| {
if (te.is_generic) list.append(allocator, te.name) catch {};
},
.pointer_type_expr => |pte| collectGenericNames(pte.pointee_type, list, allocator),
.many_pointer_type_expr => |mpte| collectGenericNames(mpte.element_type, list, allocator),
.slice_type_expr => |ste| collectGenericNames(ste.element_type, list, allocator),
.array_type_expr => |ate| collectGenericNames(ate.element_type, list, allocator),
.parameterized_type_expr => |pte| {
for (pte.args) |arg| collectGenericNames(arg, list, allocator);
},
else => {},
}
}
/// Collect generic type params and comptime value params from parameter annotations.
fn collectTypeParams(self: *Parser, params: []const ast.Param) ![]const ast.StructTypeParam {
var type_params = std.ArrayList(ast.StructTypeParam).empty;
@@ -563,25 +580,20 @@ pub const Parser = struct {
try type_params.append(self.allocator, .{ .name = param.name, .constraint = param.type_expr });
}
} else {
// Check for generic type param: direct $T or nested inside []$T
const generic_type_expr: ?*Node = if (param.type_expr.data == .type_expr and param.type_expr.data.type_expr.is_generic)
param.type_expr
else if (param.type_expr.data == .slice_type_expr) blk: {
const elem = param.type_expr.data.slice_type_expr.element_type;
break :blk if (elem.data == .type_expr and elem.data.type_expr.is_generic) elem else null;
} else null;
if (generic_type_expr) |gte| {
// Collect all generic type params found anywhere in the type expression
var generic_names = std.ArrayList([]const u8).empty;
collectGenericNames(param.type_expr, &generic_names, self.allocator);
for (generic_names.items) |gen_name| {
var found = false;
for (type_params.items) |existing| {
if (std.mem.eql(u8, existing.name, gte.data.type_expr.name)) {
if (std.mem.eql(u8, existing.name, gen_name)) {
found = true;
break;
}
}
if (!found) {
const type_constraint = try self.createNode(param.type_expr.span.start, .{ .type_expr = .{ .name = "Type" } });
try type_params.append(self.allocator, .{ .name = gte.data.type_expr.name, .constraint = type_constraint });
const type_constraint = self.createNode(param.type_expr.span.start, .{ .type_expr = .{ .name = "Type" } }) catch continue;
type_params.append(self.allocator, .{ .name = gen_name, .constraint = type_constraint }) catch {};
}
}
}

View File

@@ -230,9 +230,10 @@ pub const Type = union(enum) {
if (self.isSlice() and target.isSlice()) {
return std.mem.eql(u8, self.slice_type.element_name, target.slice_type.element_name);
}
// Pointer types: compare pointee names by content, null (*void) → any pointer
// Pointer types: compare pointee names by content, *void is universal (both directions)
if (self.isPointer() and target.isPointer()) {
if (std.mem.eql(u8, self.pointer_type.pointee_name, "void")) return true;
if (std.mem.eql(u8, target.pointer_type.pointee_name, "void")) return true;
return std.mem.eql(u8, self.pointer_type.pointee_name, target.pointer_type.pointee_name);
}
// Many-pointer types: compare element names by content
@@ -240,9 +241,15 @@ pub const Type = union(enum) {
return std.mem.eql(u8, self.many_pointer_type.element_name, target.many_pointer_type.element_name);
}
// *T → [*]T: pointer to element is implicitly convertible to many-pointer
// null (*void) → [*]T is also allowed
if (self.isPointer() and target.isManyPointer()) {
if (std.mem.eql(u8, self.pointer_type.pointee_name, "void")) return true;
return std.mem.eql(u8, self.pointer_type.pointee_name, target.many_pointer_type.element_name);
}
// [*]T → *void: any many-pointer converts to void pointer
if (self.isManyPointer() and target.isPointer()) {
return std.mem.eql(u8, target.pointer_type.pointee_name, "void");
}
const src_float = self.isFloat();
const dst_float = target.isFloat();