more forward declarations

This commit is contained in:
agra
2026-02-24 17:37:52 +02:00
parent 97475d6cfe
commit 566121c45a
13 changed files with 867 additions and 88 deletions

View File

@@ -201,6 +201,8 @@ pub const CodeGen = struct {
current_match_tags: ?[]const u64 = null,
// Functions deferred to compile after all types are registered (e.g. any_to_string)
deferred_fn_bodies: std.ArrayList(DeferredFn),
// AST nodes whose bodies were generated in Pass 4 (to avoid double generation in main)
generated_bodies: std.AutoHashMap(*const Node, void),
// Libraries to link against (from #library directives)
foreign_libraries: std.ArrayList([]const u8),
// Set of foreign function names (for ABI lowering at call sites)
@@ -436,6 +438,7 @@ pub const CodeGen = struct {
.any_type_id_map = std.StringHashMap(u64).init(allocator),
.any_type_entries = std.StringHashMap(AnyTypeEntry).init(allocator),
.deferred_fn_bodies = std.ArrayList(DeferredFn).empty,
.generated_bodies = std.AutoHashMap(*const Node, void).init(allocator),
.foreign_libraries = std.ArrayList([]const u8).empty,
.foreign_fns = std.StringHashMap(void).init(allocator),
.library_constants = std.StringHashMap([]const u8).init(allocator),
@@ -1392,7 +1395,7 @@ pub const CodeGen = struct {
switch (decl.data) {
.fn_decl => |fd| {
if (fd.body.data != .builtin_expr and fd.type_params.len == 0) {
try self.registerFnDecl(fd, fd.name);
_ = try self.registerFnDecl(fd, fd.name);
}
},
.struct_decl => |sd| try self.registerStructMethods(sd),
@@ -1457,11 +1460,13 @@ pub const CodeGen = struct {
} else {
try self.genFnBody(fd, fd.name);
}
try self.generated_bodies.put(decl, {});
}
},
.const_decl => |cd| {
if (cd.value.data == .lambda) {
try self.genLambdaBody(cd.name, cd.value.data.lambda);
try self.generated_bodies.put(decl, {});
}
},
.namespace_decl => |ns| {
@@ -2338,7 +2343,7 @@ pub const CodeGen = struct {
}
}
fn registerFnDecl(self: *CodeGen, fd: ast.FnDecl, llvm_name: []const u8) !void {
fn registerFnDecl(self: *CodeGen, fd: ast.FnDecl, llvm_name: []const u8) !c.LLVMValueRef {
const is_foreign = fd.body.data == .foreign_expr;
// For foreign functions: resolve C symbol name (rename) and validate library ref
const actual_llvm_name = if (is_foreign) blk: {
@@ -2359,7 +2364,7 @@ pub const CodeGen = struct {
} else llvm_name;
const fn_type = try self.buildFnType(fd.params, fd.return_type, fd.name, is_foreign);
const name_z = try self.allocator.dupeZ(u8, actual_llvm_name);
_ = c.LLVMAddFunction(self.module, name_z.ptr, fn_type);
const function = c.LLVMAddFunction(self.module, name_z.ptr, fn_type);
// Track foreign functions for ABI lowering at call sites (use sx name for call-site lookup)
if (is_foreign) {
try self.foreign_fns.put(llvm_name, {});
@@ -2394,6 +2399,7 @@ pub const CodeGen = struct {
break;
}
}
return function;
}
/// registerTypes helper: register type names within a namespace.
@@ -2473,7 +2479,7 @@ pub const CodeGen = struct {
}
const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, fd.name });
if (fd.body.data == .foreign_expr) {
try self.registerFnDecl(fd, fd.name);
_ = try self.registerFnDecl(fd, fd.name);
try self.foreign_fns.put(qualified, {});
const fe = fd.body.data.foreign_expr;
if (fe.c_name) |c_name| {
@@ -2490,7 +2496,7 @@ pub const CodeGen = struct {
} else if (fd.type_params.len > 0) {
try self.generic_templates.put(qualified, fd);
} else {
try self.registerFnDecl(fd, qualified);
_ = try self.registerFnDecl(fd, qualified);
}
},
.struct_decl => |sd| {
@@ -2980,8 +2986,10 @@ pub const CodeGen = struct {
if (ret_val) |val| {
const prepared = try self.prepareReturnValue(val, ret_sx_type);
self.ret(prepared);
} else {
} else if (ret_sx_type == .void_type) {
self.retVoid();
} else {
_ = c.LLVMBuildUnreachable(self.builder);
}
}
}
@@ -3014,30 +3022,45 @@ pub const CodeGen = struct {
// Infer return type from body for => lambdas without explicit annotation
const ret_sx_type = if (fd.return_type != null) self.resolveType(fd.return_type) else if (fd.is_arrow) self.inferType(fd.body) else Type.void_type;
// For arrow lambdas with inferred return type, build function manually
if (fd.is_arrow and fd.return_type == null) {
const ret_llvm_type = self.typeToLLVM(ret_sx_type);
var param_llvm_types = std.ArrayList(c.LLVMTypeRef).empty;
for (fd.params) |param| {
try param_llvm_types.append(self.allocator, self.typeToLLVM(self.resolveType(param.type_expr)));
// Build or register the LLVM function, keeping a direct reference
// (LLVMGetNamedFunction returns the first fn with that name, which
// may differ when multiple local functions share a name)
const function = blk: {
if (fd.is_arrow and fd.return_type == null) {
const ret_llvm_type = self.typeToLLVM(ret_sx_type);
var param_llvm_types = std.ArrayList(c.LLVMTypeRef).empty;
for (fd.params) |param| {
try param_llvm_types.append(self.allocator, self.typeToLLVM(self.resolveType(param.type_expr)));
}
const params_slice = try param_llvm_types.toOwnedSlice(self.allocator);
const fn_type = c.LLVMFunctionType(
ret_llvm_type,
if (params_slice.len > 0) params_slice.ptr else null,
@intCast(params_slice.len),
0,
);
const name_z2 = try self.allocator.dupeZ(u8, fd.name);
const func = c.LLVMAddFunction(self.module, name_z2.ptr, fn_type);
try self.function_return_types.put(fd.name, ret_sx_type);
break :blk func;
} else {
break :blk try self.registerFnDecl(fd, fd.name);
}
const params_slice = try param_llvm_types.toOwnedSlice(self.allocator);
const fn_type = c.LLVMFunctionType(
ret_llvm_type,
if (params_slice.len > 0) params_slice.ptr else null,
@intCast(params_slice.len),
0,
);
const name_z2 = try self.allocator.dupeZ(u8, fd.name);
_ = c.LLVMAddFunction(self.module, name_z2.ptr, fn_type);
try self.function_return_types.put(fd.name, ret_sx_type);
} else {
try self.registerFnDecl(fd, fd.name);
};
// Skip if this exact AST node was already generated in Pass 4
// (top-level fn_decls appear both in Pass 4 and main's body)
if (self.generated_bodies.contains(node)) {
self.named_values.deinit();
self.named_values = saved_named;
self.narrowed_types = saved_narrowed;
self.current_return_type = saved_ret;
self.current_function = saved_fn;
self.positionAt(saved_bb);
return null;
}
self.current_return_type = ret_sx_type;
const name_z = try self.allocator.dupeZ(u8, fd.name);
const function = c.LLVMGetNamedFunction(self.module, name_z.ptr) orelse
return self.emitErrorFmt("local function '{s}' not found", .{fd.name});
self.current_function = function;
_ = self.appendBlock(function, "entry");
@@ -3071,7 +3094,7 @@ pub const CodeGen = struct {
const ret_val = try self.prepareReturnValue(val, ret_sx_type);
self.ret(ret_val);
} else {
self.retVoid();
_ = c.LLVMBuildUnreachable(self.builder);
}
}
@@ -3081,6 +3104,25 @@ pub const CodeGen = struct {
self.current_return_type = saved_ret;
self.current_function = saved_fn;
self.positionAt(saved_bb);
// Register local function in outer scope's named_values so it
// shadows any top-level function with the same name.
{
var param_types_list = std.ArrayList(Type).empty;
for (fd.params) |param| {
try param_types_list.append(self.allocator, self.resolveType(param.type_expr));
}
const ret_type_ptr = try self.allocator.create(Type);
ret_type_ptr.* = ret_sx_type;
const fn_ty: Type = .{ .function_type = .{
.param_types = try param_types_list.toOwnedSlice(self.allocator),
.return_type = ret_type_ptr,
} };
const local_name_z = try self.allocator.dupeZ(u8, fd.name);
const fn_alloca = self.buildEntryBlockAlloca(self.ptrType(), local_name_z.ptr);
_ = c.LLVMBuildStore(self.builder, function, fn_alloca);
try self.named_values.put(fd.name, .{ .ptr = fn_alloca, .ty = fn_ty });
}
}
return null;
},
@@ -4841,7 +4883,7 @@ pub const CodeGen = struct {
try self.generic_templates.put(qualified, fd);
} else {
// Non-generic struct, non-generic method: register directly
try self.registerFnDecl(fd, qualified);
_ = try self.registerFnDecl(fd, qualified);
}
}
}
@@ -4851,6 +4893,9 @@ pub const CodeGen = struct {
fn registerProtocolDecl(self: *CodeGen, pd: ast.ProtocolDecl) !void {
try self.protocol_decls.put(pd.name, pd);
// Skip if already registered (can happen with diamond imports)
if (self.type_registry.contains(pd.name)) return;
if (pd.is_inline) {
// #inline protocol: generate struct { ctx: *void, method1: fn_ptr, method2: fn_ptr, ... }
const n_fields = 1 + pd.methods.len; // ctx + one fn-ptr per method
@@ -5188,7 +5233,7 @@ pub const CodeGen = struct {
try self.generic_templates.put(qualified, fd);
} else {
// Non-generic: register directly
try self.registerFnDecl(fd, qualified);
_ = try self.registerFnDecl(fd, qualified);
}
try self.fn_signatures.put(qualified, self.buildFnSignature(fd));
}
@@ -5211,7 +5256,7 @@ pub const CodeGen = struct {
// Synthesize a fn_decl: method_name :: (self: *ConcreteType, params...) -> R { default_body }
const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ib.target_type, method.name });
const self_fd = try self.synthesizeDefaultMethod(ib.target_type, method);
try self.registerFnDecl(self_fd, qualified);
_ = try self.registerFnDecl(self_fd, qualified);
try self.fn_signatures.put(qualified, self.buildFnSignature(self_fd));
}
}
@@ -7931,30 +7976,6 @@ pub const CodeGen = struct {
return self.genCallByName(resolved, call_node);
}
// Check if this is a generic function call
if (self.generic_templates.get(callee_name)) |template| {
return self.genGenericCall(callee_name, template, call_node);
}
// Intra-namespace fallback for generic templates
if (self.current_namespace) |ns| {
const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns, callee_name });
if (self.generic_templates.get(qualified)) |template| {
return self.genGenericCall(qualified, template, call_node);
}
}
// Check for #builtin function (only available when imported)
if (self.builtin_functions.contains(callee_name)) {
return self.dispatchBuiltin(callee_name, call_node);
}
// Intra-namespace fallback for builtins
if (self.current_namespace) |ns| {
const qualified_builtin = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns, callee_name });
if (self.builtin_functions.contains(qualified_builtin)) {
return self.dispatchBuiltin(qualified_builtin, call_node);
}
}
// Compiler intrinsics (always available, no #builtin declaration needed)
if (std.mem.eql(u8, callee_name, "sqrt")) {
return self.genMathIntrinsic(call_node, "sqrt");
@@ -7981,7 +8002,8 @@ pub const CodeGen = struct {
return self.genClosureIntrinsic(call_node);
}
// Local variable takes priority: closures and function pointers shadow LLVM named functions
// Local variables shadow imported functions: closures and function pointers
// take priority over generic templates, builtins, and LLVM named functions.
if (self.lookupValue(callee_name)) |v| {
const entry = v.asNamedValue();
if (entry) |e| {
@@ -7994,6 +8016,30 @@ pub const CodeGen = struct {
}
}
// Check if this is a generic function call
if (self.generic_templates.get(callee_name)) |template| {
return self.genGenericCall(callee_name, template, call_node);
}
// Intra-namespace fallback for generic templates
if (self.current_namespace) |ns| {
const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns, callee_name });
if (self.generic_templates.get(qualified)) |template| {
return self.genGenericCall(qualified, template, call_node);
}
}
// Check for #builtin function (only available when imported)
if (self.builtin_functions.contains(callee_name)) {
return self.dispatchBuiltin(callee_name, call_node);
}
// Intra-namespace fallback for builtins
if (self.current_namespace) |ns| {
const qualified_builtin = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns, callee_name });
if (self.builtin_functions.contains(qualified_builtin)) {
return self.dispatchBuiltin(qualified_builtin, call_node);
}
}
var nbuf: [256]u8 = undefined;
var callee_fn = c.LLVMGetNamedFunction(self.module, self.nameToCStr(callee_name, &nbuf));
// Foreign function fallback: qualified name "ns.Func" → try unqualified "Func" (the C symbol)
@@ -8417,7 +8463,23 @@ pub const CodeGen = struct {
if (ret_llvm == self.voidType()) {
_ = c.LLVMBuildRetVoid(self.builder);
} else {
_ = c.LLVMBuildRet(self.builder, result);
// Convert result type if it doesn't match the thunk's declared return type
var ret_val = result;
const result_ty = c.LLVMTypeOf(result);
if (result_ty != ret_llvm) {
const src_kind = c.LLVMGetTypeKind(result_ty);
const dst_kind = c.LLVMGetTypeKind(ret_llvm);
if (src_kind == c.LLVMIntegerTypeKind and dst_kind == c.LLVMIntegerTypeKind) {
const src_bits = c.LLVMGetIntTypeWidth(result_ty);
const dst_bits = c.LLVMGetIntTypeWidth(ret_llvm);
if (src_bits > dst_bits) {
ret_val = c.LLVMBuildTrunc(self.builder, result, ret_llvm, "thunk_trunc");
} else {
ret_val = c.LLVMBuildSExt(self.builder, result, ret_llvm, "thunk_sext");
}
}
}
_ = c.LLVMBuildRet(self.builder, ret_val);
}
// Restore position

View File

@@ -1084,7 +1084,7 @@ pub const Server = struct {
};
var hints = std.ArrayList(lsp.InlayHint).empty;
collectInlayHints(self.allocator, root, sema.symbols, doc.source, &hints);
collectInlayHints(self.allocator, root, sema.symbols, sema.fn_signatures, doc.source, &hints);
self.collectCallHints(doc, root, &hints);
const result_json = try lsp.inlayHintsJson(self.allocator, hints.items);
try self.sendResponse(id_json, result_json);
@@ -1094,37 +1094,46 @@ pub const Server = struct {
allocator: std.mem.Allocator,
node: *const sx.ast.Node,
symbols: []const sx.sema.Symbol,
fn_signatures: std.StringHashMap(sx.sema.FnSignature),
source: [:0]const u8,
hints: *std.ArrayList(lsp.InlayHint),
) void {
switch (node.data) {
.root => |r| {
for (r.decls) |decl| collectInlayHints(allocator, decl, symbols, source, hints);
for (r.decls) |decl| collectInlayHints(allocator, decl, symbols, fn_signatures, source, hints);
},
.block => |b| {
for (b.stmts) |stmt| collectInlayHints(allocator, stmt, symbols, source, hints);
for (b.stmts) |stmt| collectInlayHints(allocator, stmt, symbols, fn_signatures, source, hints);
},
.fn_decl => |fd| {
collectInlayHints(allocator, fd.body, symbols, source, hints);
collectInlayHints(allocator, fd.body, symbols, fn_signatures, source, hints);
// Return type hint for arrow functions without explicit return type
if (fd.return_type == null and fd.is_arrow) {
if (fn_signatures.get(fd.name)) |sig| {
if (sig.return_type != .void_type) {
addReturnTypeHint(allocator, node.span, source, sig.return_type, hints);
}
}
}
},
.lambda => |lm| {
collectInlayHints(allocator, lm.body, symbols, source, hints);
collectInlayHints(allocator, lm.body, symbols, fn_signatures, source, hints);
},
.if_expr => |ie| {
if (ie.binding_name) |bname| {
addBindingHint(allocator, bname, node.span, symbols, source, hints);
}
collectInlayHints(allocator, ie.then_branch, symbols, source, hints);
if (ie.else_branch) |eb| collectInlayHints(allocator, eb, symbols, source, hints);
collectInlayHints(allocator, ie.then_branch, symbols, fn_signatures, source, hints);
if (ie.else_branch) |eb| collectInlayHints(allocator, eb, symbols, fn_signatures, source, hints);
},
.while_expr => |we| {
if (we.binding_name) |bname| {
addBindingHint(allocator, bname, node.span, symbols, source, hints);
}
collectInlayHints(allocator, we.body, symbols, source, hints);
collectInlayHints(allocator, we.body, symbols, fn_signatures, source, hints);
},
.for_expr => |fe| {
collectInlayHints(allocator, fe.body, symbols, source, hints);
collectInlayHints(allocator, fe.body, symbols, fn_signatures, source, hints);
},
.var_decl => |vd| {
// Only show hint when type is inferred (:= syntax)
@@ -1135,9 +1144,22 @@ pub const Server = struct {
.const_decl => |cd| {
// Skip if explicit type annotation
if (cd.type_annotation != null) return;
// Handle lambda with return type hint
if (cd.value.data == .lambda) {
const lam = cd.value.data.lambda;
collectInlayHints(allocator, lam.body, symbols, fn_signatures, source, hints);
if (lam.return_type == null) {
if (fn_signatures.get(cd.name)) |sig| {
if (sig.return_type != .void_type) {
addReturnTypeHint(allocator, cd.value.span, source, sig.return_type, hints);
}
}
}
return;
}
// Skip functions, types, structs, enums, unions, comptime, foreign, library
switch (cd.value.data) {
.lambda, .fn_decl, .type_expr, .struct_decl, .enum_decl, .union_decl,
.fn_decl, .type_expr, .struct_decl, .enum_decl, .union_decl,
.comptime_expr, .foreign_expr, .library_decl,
=> return,
else => {},
@@ -1241,6 +1263,47 @@ pub const Server = struct {
}
}
fn addReturnTypeHint(
allocator: std.mem.Allocator,
span: sx.ast.Span,
source: [:0]const u8,
return_type: sx.types.Type,
hints: *std.ArrayList(lsp.InlayHint),
) void {
// Find '(' from span start
var pos: u32 = span.start;
while (pos < source.len and source[pos] != '(') : (pos += 1) {}
if (pos >= source.len) return;
// Match nested parens to find closing ')'
var depth: u32 = 0;
while (pos < source.len) : (pos += 1) {
if (source[pos] == '(') {
depth += 1;
} else if (source[pos] == ')') {
depth -= 1;
if (depth == 0) break;
}
}
if (pos >= source.len or depth != 0) return;
// Place hint right after ')'
const loc = sx.errors.SourceLoc.compute(source, pos + 1);
if (loc.line == 0 or loc.col == 0) return;
const type_name = return_type.displayName(allocator) catch return;
const label = std.fmt.allocPrint(allocator, "-> {s}", .{type_name}) catch return;
hints.append(allocator, .{
.line = loc.line - 1,
.character = loc.col - 1,
.label = label,
.kind = 1,
.padding_left = true,
.padding_right = true,
}) catch {};
}
fn findSymbolAtSpan(symbols: []const sx.sema.Symbol, span_start: u32, name: []const u8) ?sx.sema.Symbol {
for (symbols) |sym| {
if (sym.def_span.start == span_start and std.mem.eql(u8, sym.name, name)) {

View File

@@ -129,7 +129,8 @@ pub const Analyzer = struct {
fn registerTopLevelDeclPrefixed(self: *Analyzer, node: *Node, ns_prefix: ?[]const u8) !void {
switch (node.data) {
.fn_decl => |fd| {
const ret_ty = resolveReturnType(fd);
const ret_ty = resolveReturnType(fd) orelse
if (fd.is_arrow) self.inferFnReturnType(fd.params, fd.body) else null;
try self.addSymbol(fd.name, .function, ret_ty, node.span);
// Populate fn_signatures registry
var param_types = std.ArrayList(Type).empty;
@@ -171,7 +172,10 @@ pub const Analyzer = struct {
const pt = Type.fromTypeExpr(param.type_expr) orelse Type.s(64);
try param_types.append(self.allocator, pt);
}
const ret = if (lam.return_type) |rt| Type.fromTypeExpr(rt) orelse .void_type else .void_type;
const ret = if (lam.return_type) |rt|
Type.fromTypeExpr(rt) orelse .void_type
else
self.inferFnReturnType(lam.params, lam.body) orelse .void_type;
const key = if (ns_prefix) |pfx|
try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ pfx, cd.name })
else
@@ -666,7 +670,21 @@ pub const Analyzer = struct {
fn analyzeNode(self: *Analyzer, node: *Node) !void {
switch (node.data) {
.fn_decl => |fd| {
try self.addSymbol(fd.name, .function, resolveReturnType(fd), node.span);
const local_ret_ty = resolveReturnType(fd) orelse
if (fd.is_arrow) self.inferFnReturnType(fd.params, fd.body) else null;
try self.addSymbol(fd.name, .function, local_ret_ty, node.span);
// Register fn_signatures for local functions (for return type hints + hover)
{
var param_types = std.ArrayList(Type).empty;
for (fd.params) |param| {
const pt = Type.fromTypeExpr(param.type_expr) orelse Type.s(64);
try param_types.append(self.allocator, pt);
}
try self.fn_signatures.put(fd.name, .{
.param_types = try param_types.toOwnedSlice(self.allocator),
.return_type = local_ret_ty orelse .void_type,
});
}
try self.pushScope();
try self.analyzeParams(fd.params);
try self.analyzeNode(fd.body);
@@ -967,6 +985,26 @@ pub const Analyzer = struct {
return null;
}
/// Infer return type from a function/lambda body by temporarily registering params.
fn inferFnReturnType(self: *Analyzer, params: []const ast.Param, body: *const Node) ?Type {
self.pushScope() catch return null;
for (params) |param| {
const pt = Type.fromTypeExpr(param.type_expr) orelse Type.s(64);
self.addSymbol(param.name, .param, pt, param.name_span) catch {};
}
// Arrow fn_decl wraps body in block{[expr]} — unwrap to inner expression
const expr_node = if (body.data == .block) blk: {
const stmts = body.data.block.stmts;
if (stmts.len > 0) break :blk stmts[stmts.len - 1];
break :blk body;
} else body;
const inferred = self.inferExprType(expr_node);
self.popScope();
if (inferred != .void_type) return inferred;
return null;
}
fn resolveTypeAnnotation(self: *Analyzer, type_node: ?*Node) ?Type {
if (type_node) |tn| {
if (Type.fromTypeExpr(tn)) |t| return t;