xplshn
·
2025-09-10
llvm_backend.go
Go
1package codegen
2
3import (
4 "bytes"
5 "fmt"
6 "math"
7 "os"
8 "os/exec"
9 "sort"
10 "strconv"
11 "strings"
12
13 "github.com/xplshn/gbc/pkg/config"
14 "github.com/xplshn/gbc/pkg/ir"
15)
16
17type llvmBackend struct {
18 out *strings.Builder
19 prog *ir.Program
20 cfg *config.Config
21 wordType string
22 tempTypes map[string]string // maps temp name to LLVM type
23 tempIRTypes map[string]ir.Type // maps temp name to IR type
24 funcSigs map[string]string
25 currentFn *ir.Func
26}
27
28func NewLLVMBackend() Backend { return &llvmBackend{} }
29
30func (b *llvmBackend) Generate(prog *ir.Program, cfg *config.Config) (*bytes.Buffer, error) {
31 llvmIR, err := b.GenerateIR(prog, cfg)
32 if err != nil {
33 return nil, err
34 }
35
36 asm, err := b.compileLLVMIR(llvmIR)
37 if err != nil {
38 return nil, err
39 }
40 return bytes.NewBufferString(asm), nil
41}
42
43func (b *llvmBackend) GenerateIR(prog *ir.Program, cfg *config.Config) (string, error) {
44 var llvmIRBuilder strings.Builder
45 b.out = &llvmIRBuilder
46 b.prog = prog
47 b.cfg = cfg
48 b.wordType = fmt.Sprintf("i%d", cfg.WordSize*8)
49 b.tempTypes = make(map[string]string)
50 b.tempIRTypes = make(map[string]ir.Type)
51 b.funcSigs = make(map[string]string)
52
53 b.gen()
54
55 return llvmIRBuilder.String(), nil
56}
57
58func (b *llvmBackend) compileLLVMIR(llvmIR string) (string, error) {
59 llFile, err := os.CreateTemp("", "gbc-main-*.ll")
60 if err != nil {
61 return "", fmt.Errorf("failed to create temp file for LLVM IR: %w", err)
62 }
63 defer os.Remove(llFile.Name())
64 if _, err := llFile.WriteString(llvmIR); err != nil {
65 return "", fmt.Errorf("failed to write to temp file for LLVM IR: %w", err)
66 }
67 llFile.Close()
68
69 asmFile, err := os.CreateTemp("", "gbc-main-*.s")
70 if err != nil {
71 return "", fmt.Errorf("failed to create temp file for assembly: %w", err)
72 }
73 asmFile.Close()
74 defer os.Remove(asmFile.Name())
75
76 cmd := exec.Command("llc", "-O2", "-o", asmFile.Name(), llFile.Name())
77 if output, err := cmd.CombinedOutput(); err != nil {
78 return "", fmt.Errorf("llc command failed: %w\n--- LLVM IR ---\n%s\n--- Output ---\n%s", err, llvmIR, string(output))
79 }
80
81 asmBytes, err := os.ReadFile(asmFile.Name())
82 if err != nil {
83 return "", fmt.Errorf("failed to read temporary assembly file: %w", err)
84 }
85 return string(asmBytes), nil
86}
87
88func (b *llvmBackend) gen() {
89 fmt.Fprintf(b.out, "; Generated by gbc\n")
90 fmt.Fprintf(b.out, "target triple = \"%s\"\n\n", b.cfg.BackendTarget)
91
92 b.genDeclarations()
93 b.genStrings()
94 b.genGlobals()
95
96 for _, fn := range b.prog.Funcs {
97 b.genFunc(fn)
98 }
99}
100
101func (b *llvmBackend) getFuncSig(name string) (retType string) { return b.wordType }
102
103func (b *llvmBackend) genDeclarations() {
104 knownExternals := make(map[string]bool)
105
106 b.out.WriteString("declare void @llvm.memcpy.p0i8.p0i8.i64(i8*, i8*, i64, i1)\n")
107
108 if len(b.prog.ExtrnVars) > 0 {
109 b.out.WriteString("; --- External Variables ---\n")
110 for name := range b.prog.ExtrnVars {
111 if knownExternals[name] {
112 continue
113 }
114 ptrType := "i8*"
115 fmt.Fprintf(b.out, "@%s = external global %s\n", name, ptrType)
116 b.tempTypes["@"+name] = ptrType + "*"
117 knownExternals[name] = true
118 }
119 b.out.WriteString("\n")
120 }
121
122 potentialFuncs := make(map[string]bool)
123
124 for _, name := range b.prog.ExtrnFuncs {
125 potentialFuncs[name] = true
126 }
127
128 // Find additional functions that are called but not explicitly declared
129 for _, fn := range b.prog.Funcs {
130 for _, block := range fn.Blocks {
131 for _, instr := range block.Instructions {
132 if instr.Op == ir.OpCall {
133 if g, ok := instr.Args[0].(*ir.Global); ok {
134 potentialFuncs[g.Name] = true
135 }
136 }
137 }
138 }
139 }
140
141 // Remove functions that are defined in this program (not external)
142 for _, fn := range b.prog.Funcs {
143 delete(potentialFuncs, fn.Name)
144 }
145
146 var funcsToDeclare []string
147 for name := range potentialFuncs {
148 if !knownExternals[name] {
149 funcsToDeclare = append(funcsToDeclare, name)
150 }
151 }
152
153 if len(funcsToDeclare) > 0 {
154 b.out.WriteString("; --- External Functions ---\n")
155 sort.Strings(funcsToDeclare)
156 for _, name := range funcsToDeclare {
157 retType := b.getFuncSig(name)
158
159 // All external functions are declared as varargs
160 // The linker will handle the correct resolution
161 sig := fmt.Sprintf("declare %s @%s(...)\n", retType, name)
162
163 b.out.WriteString(sig)
164 b.funcSigs[name] = sig
165 }
166 b.out.WriteString("\n")
167 }
168}
169
170func (b *llvmBackend) genStrings() {
171 if len(b.prog.Strings) == 0 {
172 return
173 }
174 b.out.WriteString("; --- String Literals ---\n")
175 for s, label := range b.prog.Strings {
176 strLen := len(s) + 1
177 typeStr := fmt.Sprintf("[%d x i8]", strLen)
178
179 // Always emit as raw byte sequence for simplicity and reliability
180 b.out.WriteString(fmt.Sprintf("@%s = private unnamed_addr constant %s [", label, typeStr))
181
182 // Handle empty strings
183 if len(s) == 0 {
184 b.out.WriteString("i8 0")
185 } else {
186 for i := 0; i < len(s); i++ {
187 if i > 0 {
188 b.out.WriteString(", ")
189 }
190 b.out.WriteString(fmt.Sprintf("i8 %d", s[i]))
191 }
192 b.out.WriteString(", i8 0")
193 }
194 b.out.WriteString("]\n")
195
196 b.tempTypes["@"+label] = typeStr + "*"
197 }
198 b.out.WriteString("\n")
199}
200
201func (b *llvmBackend) genGlobals() {
202 if len(b.prog.Globals) == 0 {
203 return
204 }
205 b.out.WriteString("; --- Global Variables ---\n")
206 for _, g := range b.prog.Globals {
207 hasInitializer := false
208 totalItemCount := 0
209 var firstItemType ir.Type = -1
210
211 for _, item := range g.Items {
212 if item.Count > 0 {
213 totalItemCount += item.Count
214 } else {
215 totalItemCount++
216 hasInitializer = true
217 }
218 if firstItemType == -1 {
219 firstItemType = item.Typ
220 }
221 }
222
223 var globalType string
224 elemType := b.formatType(firstItemType)
225 if firstItemType == -1 {
226 elemType = b.wordType
227 }
228
229 if totalItemCount > 1 {
230 globalType = fmt.Sprintf("[%d x %s]", totalItemCount, elemType)
231 } else if totalItemCount == 1 {
232 globalType = elemType
233 } else {
234 continue
235 }
236
237 initializer := "zeroinitializer"
238 if hasInitializer {
239 if strings.HasPrefix(globalType, "[") {
240 var typedItems []string
241 for _, item := range g.Items {
242 if item.Count > 0 {
243 for i := 0; i < item.Count; i++ {
244 typedItems = append(typedItems, fmt.Sprintf("%s 0", elemType))
245 }
246 } else {
247 itemTypeStr := b.formatType(item.Typ)
248 valStr := b.formatGlobalInitializerValue(item.Value, itemTypeStr)
249 typedItems = append(typedItems, fmt.Sprintf("%s %s", itemTypeStr, valStr))
250 }
251 }
252 initializer = fmt.Sprintf("[ %s ]", strings.Join(typedItems, ", "))
253 } else {
254 initializer = b.formatGlobalInitializerValue(g.Items[0].Value, globalType)
255 }
256 }
257
258 fmt.Fprintf(b.out, "@%s = global %s %s, align %d\n", g.Name, globalType, initializer, g.Align)
259 b.tempTypes["@"+g.Name] = globalType + "*"
260 }
261 b.out.WriteString("\n")
262}
263
264func (b *llvmBackend) formatGlobalInitializerValue(v ir.Value, targetType string) string {
265 switch val := v.(type) {
266 case *ir.Const:
267 return fmt.Sprintf("%d", val.Value)
268 case *ir.FloatConst:
269 if targetType == "float" {
270 // For 32-bit floats, truncate to float32 precision then expand back to float64 for hex format
271 float32Val := float32(val.Value)
272 return fmt.Sprintf("0x%016X", math.Float64bits(float64(float32Val)))
273 }
274 return fmt.Sprintf("0x%016X", math.Float64bits(val.Value))
275 case *ir.Global:
276 strContent, isString := b.prog.IsStringLabel(val.Name)
277 if isString {
278 strType := fmt.Sprintf("[%d x i8]", len(strContent)+1)
279 gep := fmt.Sprintf("getelementptr inbounds (%s, %s* @%s, i64 0, i64 0)", strType, strType, val.Name)
280 if targetType != "i8*" {
281 return fmt.Sprintf("ptrtoint (i8* %s to %s)", gep, targetType)
282 }
283 return gep
284 }
285 sourceType := b.getType(val)
286 if !strings.HasSuffix(sourceType, "*") {
287 sourceType += "*"
288 }
289 return fmt.Sprintf("bitcast (%s @%s to %s)", sourceType, val.Name, targetType)
290 default:
291 return "0"
292 }
293}
294
295func (b *llvmBackend) genFunc(fn *ir.Func) {
296 b.currentFn = fn
297 globalTypes := make(map[string]string)
298 for k, v := range b.tempTypes {
299 if strings.HasPrefix(k, "@") {
300 globalTypes[k] = v
301 }
302 }
303 b.tempTypes = globalTypes
304
305 retTypeStr := b.formatType(fn.ReturnType)
306 var params []string
307 for _, p := range fn.Params {
308 pName := b.formatValue(p.Val)
309 pType := b.formatType(p.Typ)
310 if fn.Name == "main" && p.Name == "argv" {
311 pType = "i8**"
312 }
313 params = append(params, fmt.Sprintf("%s %s", pType, pName))
314 b.tempTypes[pName] = pType
315 }
316 paramStr := strings.Join(params, ", ")
317 if fn.HasVarargs {
318 if len(params) > 0 {
319 paramStr += ", "
320 }
321 paramStr += "..."
322 }
323
324 fmt.Fprintf(b.out, "define %s @%s(%s) {\n", retTypeStr, fn.Name, paramStr)
325 for i, block := range fn.Blocks {
326 labelName := block.Label.Name
327 if i == 0 {
328 labelName = "entry"
329 }
330 fmt.Fprintf(b.out, "%s:\n", labelName)
331 b.genBlock(block)
332 }
333 b.out.WriteString("}\n")
334}
335
336func (b *llvmBackend) genBlock(block *ir.BasicBlock) {
337 var deferredCasts []string
338 phiEndIndex := 0
339
340 for i, instr := range block.Instructions {
341 if instr.Op != ir.OpPhi {
342 phiEndIndex = i
343 break
344 }
345 }
346 if phiEndIndex == 0 && len(block.Instructions) > 0 && block.Instructions[0].Op == ir.OpPhi {
347 phiEndIndex = len(block.Instructions)
348 }
349
350 for _, instr := range block.Instructions[:phiEndIndex] {
351 if instr.Op == ir.OpPhi {
352 cast := b.genPhi(instr)
353 if cast != "" {
354 deferredCasts = append(deferredCasts, cast)
355 }
356 }
357 }
358
359 for _, cast := range deferredCasts {
360 b.out.WriteString(cast)
361 }
362
363 for _, instr := range block.Instructions[phiEndIndex:] {
364 b.genInstr(instr)
365 }
366}
367
368func (b *llvmBackend) genInstr(instr *ir.Instruction) {
369 if instr.Op == ir.OpPhi {
370 return
371 }
372
373 resultName := ""
374 if instr.Result != nil {
375 resultName = b.formatValue(instr.Result)
376 }
377
378 b.out.WriteString("\t")
379
380 switch instr.Op {
381 case ir.OpAlloc:
382 align := instr.Align
383 if align == 0 {
384 align = b.cfg.StackAlignment
385 }
386 sizeVal := b.prepareArg(instr.Args[0], b.wordType)
387 fmt.Fprintf(b.out, "%s = alloca i8, %s %s, align %d\n", resultName, b.wordType, sizeVal, align)
388 b.tempTypes[resultName] = "i8*"
389
390 case ir.OpLoad:
391 valType := b.formatType(instr.Typ)
392 ptrType := valType + "*"
393 ptrVal := b.prepareArg(instr.Args[0], ptrType)
394
395 // Check if we need to promote smaller signed types to word size (like QBE does)
396 needsPromotion := instr.Typ == ir.TypeSB || instr.Typ == ir.TypeSH || instr.Typ == ir.TypeUB || instr.Typ == ir.TypeUH || instr.Typ == ir.TypeB || instr.Typ == ir.TypeH || instr.Typ == ir.TypeW
397
398 if needsPromotion {
399 // Load the smaller type first
400 tempName := fmt.Sprintf("%s_small", resultName)
401 fmt.Fprintf(b.out, "%s = load %s, %s %s, align %d\n", tempName, valType, ptrType, ptrVal, ir.SizeOfType(instr.Typ, b.cfg.WordSize))
402
403 // Then extend to word size
404 wordType := b.wordType
405 var extOp string
406 switch instr.Typ {
407 case ir.TypeSB: extOp = "sext" // signed byte - sign extend
408 case ir.TypeSH: extOp = "sext" // signed half - sign extend
409 case ir.TypeUB, ir.TypeB: extOp = "zext" // unsigned byte or ambiguous byte - zero extend
410 case ir.TypeUH, ir.TypeH: extOp = "zext" // unsigned half or ambiguous half - zero extend
411 case ir.TypeW: extOp = "sext" // 32-bit word - sign extend (assuming signed by default)
412 }
413
414 fmt.Fprintf(b.out, "%s = %s %s %s to %s\n", resultName, extOp, valType, tempName, wordType)
415 b.tempTypes[resultName] = wordType
416 b.tempIRTypes[resultName] = ir.GetType(nil, b.cfg.WordSize)
417 } else {
418 fmt.Fprintf(b.out, "%s = load %s, %s %s, align %d\n", resultName, valType, ptrType, ptrVal, ir.SizeOfType(instr.Typ, b.cfg.WordSize))
419 b.tempTypes[resultName] = valType
420 b.tempIRTypes[resultName] = instr.Typ
421 }
422
423 case ir.OpStore:
424 valType := b.formatType(instr.Typ)
425 ptrType := valType + "*"
426 ptrVal := b.prepareArg(instr.Args[1], ptrType)
427 valStr := b.prepareArg(instr.Args[0], valType)
428 fmt.Fprintf(b.out, "store %s %s, %s %s, align %d\n", valType, valStr, ptrType, ptrVal, ir.SizeOfType(instr.Typ, b.cfg.WordSize))
429
430 case ir.OpAdd:
431 b.genAdd(instr)
432
433 case ir.OpCast:
434 targetType := b.formatType(instr.Typ)
435 sourceValStr := b.prepareArg(instr.Args[0], targetType)
436 if sourceValStr != resultName {
437 // Check if prepareArg already did the conversion
438 originalValStr := b.formatValue(instr.Args[0])
439 if sourceValStr != originalValStr {
440 // prepareArg already handled the conversion, just assign the result
441 if strings.HasSuffix(targetType, "*") {
442 // For pointer types, use getelementptr with 0 offset (effectively a copy)
443 fmt.Fprintf(b.out, "%s = getelementptr i8, %s %s, i64 0\n", resultName, targetType, sourceValStr)
444 } else {
445 // For integer types, use add with 0
446 fmt.Fprintf(b.out, "%s = add %s %s, 0\n", resultName, targetType, sourceValStr)
447 }
448 } else {
449 // No conversion by prepareArg, so we need to cast
450 sourceType := b.getType(instr.Args[0])
451 if strings.HasSuffix(targetType, "*") {
452 fmt.Fprintf(b.out, "%s = bitcast %s %s to %s\n", resultName, sourceType, sourceValStr, targetType)
453 } else {
454 fmt.Fprintf(b.out, "%s = add %s %s, 0\n", resultName, targetType, sourceValStr)
455 }
456 }
457 }
458 b.tempTypes[resultName] = targetType
459
460 case ir.OpCall:
461 b.genCall(instr)
462
463 case ir.OpJmp:
464 fmt.Fprintf(b.out, "br label %%%s\n", instr.Args[0].String())
465
466 case ir.OpJnz:
467 condVal := b.prepareArg(instr.Args[0], "i1")
468 fmt.Fprintf(b.out, "br i1 %s, label %%%s, label %%%s\n", condVal, instr.Args[1].String(), instr.Args[2].String())
469
470 case ir.OpRet:
471 if len(instr.Args) > 0 && instr.Args[0] != nil {
472 retType := b.formatType(b.currentFn.ReturnType)
473 var retVal string
474 if c, ok := instr.Args[0].(*ir.Const); ok && c.Value == 0 && strings.HasSuffix(retType, "*") {
475 retVal = "null"
476 } else {
477 retVal = b.prepareArg(instr.Args[0], retType)
478 }
479 fmt.Fprintf(b.out, "ret %s %s\n", retType, retVal)
480 } else {
481 fmt.Fprintf(b.out, "ret void\n")
482 }
483
484 case ir.OpCEq, ir.OpCNeq, ir.OpCLt, ir.OpCGt, ir.OpCLe, ir.OpCGe:
485 lhsType, rhsType := b.getType(instr.Args[0]), b.getType(instr.Args[1])
486
487 var valType string
488 lhsIsPtr := strings.HasSuffix(lhsType, "*") || (lhsType == "unknown" && b.isPointerValue(instr.Args[0]))
489 rhsIsPtr := strings.HasSuffix(rhsType, "*") || (rhsType == "unknown" && b.isPointerValue(instr.Args[1]))
490
491 if lhsIsPtr || rhsIsPtr {
492 valType = "i8*"
493 } else if lhsType != "unknown" && lhsType != b.wordType {
494 valType = lhsType
495 } else if rhsType != "unknown" && rhsType != b.wordType {
496 valType = rhsType
497 } else {
498 valType = b.wordType
499 }
500
501 isFloat := valType == "float" || valType == "double"
502 var opStr, predicate string
503 if isFloat {
504 opStr = "fcmp"
505 switch instr.Op {
506 case ir.OpCEq:
507 predicate = "oeq"
508 case ir.OpCNeq:
509 predicate = "one"
510 case ir.OpCLt:
511 predicate = "olt"
512 case ir.OpCGt:
513 predicate = "ogt"
514 case ir.OpCLe:
515 predicate = "ole"
516 case ir.OpCGe:
517 predicate = "oge"
518 }
519 } else {
520 opStr = "icmp"
521 switch instr.Op {
522 case ir.OpCEq:
523 predicate = "eq"
524 case ir.OpCNeq:
525 predicate = "ne"
526 case ir.OpCLt:
527 predicate = "slt"
528 case ir.OpCGt:
529 predicate = "sgt"
530 case ir.OpCLe:
531 predicate = "sle"
532 case ir.OpCGe:
533 predicate = "sge"
534 }
535 }
536
537 lhs := b.prepareArgForComparison(instr.Args[0], valType)
538 rhs := b.prepareArgForComparison(instr.Args[1], valType)
539
540 i1Temp := b.newBackendTemp()
541 fmt.Fprintf(b.out, "%s = %s %s %s %s, %s\n", i1Temp, opStr, predicate, valType, lhs, rhs)
542 b.tempTypes[i1Temp] = "i1"
543 fmt.Fprintf(b.out, "\t%s = zext i1 %s to %s\n", resultName, i1Temp, b.wordType)
544 b.tempTypes[resultName] = b.wordType
545
546 case ir.OpNegF:
547 opStr, _ := b.formatOp(instr.Op)
548 valType := b.formatType(instr.Typ)
549 arg := b.prepareArg(instr.Args[0], valType)
550 fmt.Fprintf(b.out, "%s = %s %s %s\n", resultName, opStr, valType, arg)
551 b.tempTypes[resultName] = valType
552 case ir.OpSub, ir.OpSubF, ir.OpMul, ir.OpMulF, ir.OpDiv, ir.OpDivF, ir.OpRem, ir.OpRemF, ir.OpAnd, ir.OpOr, ir.OpXor, ir.OpShl, ir.OpShr:
553 opStr, _ := b.formatOp(instr.Op)
554 valType := b.formatType(instr.Typ)
555 lhs := b.prepareArg(instr.Args[0], valType)
556 rhs := b.prepareArg(instr.Args[1], valType)
557 fmt.Fprintf(b.out, "%s = %s %s %s, %s\n", resultName, opStr, valType, lhs, rhs)
558 b.tempTypes[resultName] = valType
559
560 case ir.OpBlit:
561 if len(instr.Args) >= 2 {
562 srcPtr := b.prepareArg(instr.Args[0], "i8*")
563 dstPtr := b.prepareArg(instr.Args[1], "i8*")
564 var sizeVal string
565 if len(instr.Args) >= 3 {
566 sizeVal = b.prepareArg(instr.Args[2], b.wordType)
567 } else {
568 sizeVal = fmt.Sprintf("%d", ir.SizeOfType(instr.Typ, b.cfg.WordSize))
569 }
570 fmt.Fprintf(b.out, "call void @llvm.memcpy.p0i8.p0i8.i64(i8* %s, i8* %s, i64 %s, i1 false)\n",
571 dstPtr, srcPtr, sizeVal)
572 }
573
574 case ir.OpSWToF, ir.OpSLToF:
575 valType := b.formatType(instr.Typ)
576 srcType := b.wordType
577 if instr.Op == ir.OpSWToF {
578 srcType = "i32"
579 }
580 srcVal := b.prepareArg(instr.Args[0], srcType)
581 fmt.Fprintf(b.out, "%s = sitofp %s %s to %s\n", resultName, srcType, srcVal, valType)
582 b.tempTypes[resultName] = valType
583
584 case ir.OpFToF:
585 valType := b.formatType(instr.Typ)
586 srcType := b.getType(instr.Args[0])
587 srcVal := b.prepareArg(instr.Args[0], srcType)
588
589 var castOp string
590 if valType == "double" && srcType == "float" {
591 castOp = "fpext"
592 } else if valType == "float" && srcType == "double" {
593 castOp = "fptrunc"
594 } else {
595 castOp = "bitcast"
596 }
597 fmt.Fprintf(b.out, "%s = %s %s %s to %s\n", resultName, castOp, srcType, srcVal, valType)
598 b.tempTypes[resultName] = valType
599
600 case ir.OpFToSI, ir.OpFToUI:
601 valType := b.formatType(instr.Typ)
602 srcType := b.getType(instr.Args[0])
603 srcVal := b.prepareArg(instr.Args[0], srcType)
604 castOp := "fptosi"
605 if instr.Op == ir.OpFToUI {
606 castOp = "fptoui"
607 }
608 fmt.Fprintf(b.out, "%s = %s %s %s to %s\n", resultName, castOp, srcType, srcVal, valType)
609 b.tempTypes[resultName] = valType
610
611 case ir.OpExtSB, ir.OpExtUB, ir.OpExtSH, ir.OpExtUH, ir.OpExtSW, ir.OpExtUW:
612 valType := b.formatType(instr.Typ)
613 var srcType, castOp string
614 switch instr.Op {
615 case ir.OpExtSB, ir.OpExtUB:
616 srcType, castOp = "i8", "sext"
617 if instr.Op == ir.OpExtUB {
618 castOp = "zext"
619 }
620 case ir.OpExtSH, ir.OpExtUH:
621 srcType, castOp = "i16", "sext"
622 if instr.Op == ir.OpExtUH {
623 castOp = "zext"
624 }
625 case ir.OpExtSW, ir.OpExtUW:
626 srcType, castOp = "i32", "sext"
627 if instr.Op == ir.OpExtUW {
628 castOp = "zext"
629 }
630 }
631 srcVal := b.prepareArg(instr.Args[0], srcType)
632 fmt.Fprintf(b.out, "%s = %s %s %s to %s\n", resultName, castOp, srcType, srcVal, valType)
633 b.tempTypes[resultName] = valType
634
635 default:
636 opStr, _ := b.formatOp(instr.Op)
637 valType := b.formatType(instr.Typ)
638 lhs := b.prepareArg(instr.Args[0], valType)
639 rhs := b.prepareArg(instr.Args[1], valType)
640 fmt.Fprintf(b.out, "%s = %s %s %s, %s\n", resultName, opStr, valType, lhs, rhs)
641 b.tempTypes[resultName] = valType
642 }
643}
644
645func (b *llvmBackend) genPhi(instr *ir.Instruction) string {
646 resultName := b.formatValue(instr.Result)
647 originalResultType := b.formatType(instr.Typ)
648 phiType := originalResultType
649
650 hasPtrInput, hasIntInput := false, false
651 for i := 1; i < len(instr.Args); i += 2 {
652 argType := b.getType(instr.Args[i])
653 if strings.HasSuffix(argType, "*") {
654 hasPtrInput = true
655 } else if strings.HasPrefix(argType, "i") {
656 hasIntInput = true
657 }
658 }
659
660 if hasPtrInput && hasIntInput {
661 phiType = "i8*"
662 }
663
664 var pairs []string
665 for i := 0; i < len(instr.Args); i += 2 {
666 labelName := instr.Args[i].String()
667 if labelName == "start" {
668 labelName = "entry"
669 }
670 val := b.prepareArgForPhi(instr.Args[i+1], phiType)
671 pairs = append(pairs, fmt.Sprintf("[ %s, %%%s ]", val, labelName))
672 }
673
674 phiResultName := resultName
675 if phiType != originalResultType {
676 phiResultName = b.newBackendTemp()
677 }
678
679 fmt.Fprintf(b.out, "\t%s = phi %s %s\n", phiResultName, phiType, strings.Join(pairs, ", "))
680 b.tempTypes[phiResultName] = phiType
681
682 if phiResultName != resultName {
683 b.tempTypes[resultName] = originalResultType
684 return fmt.Sprintf("\t%s\n", b.formatCast(phiResultName, resultName, phiType, originalResultType))
685 }
686 return ""
687}
688
689func (b *llvmBackend) prepareArgForPhi(v ir.Value, targetType string) string {
690 valStr := b.formatValue(v)
691 currentType := b.getType(v)
692
693 if currentType == targetType || currentType == "unknown" {
694 return valStr
695 }
696
697 if c, isConst := v.(*ir.Const); isConst {
698 if strings.HasSuffix(targetType, "*") && c.Value == 0 {
699 return "null"
700 }
701 if strings.HasSuffix(targetType, "*") {
702 return fmt.Sprintf("inttoptr (%s %s to %s)", currentType, valStr, targetType)
703 }
704 }
705
706 if _, isGlobal := v.(*ir.Global); isGlobal {
707 return fmt.Sprintf("bitcast (%s %s to %s)", currentType, valStr, targetType)
708 }
709 return valStr
710}
711
712func (b *llvmBackend) genAdd(instr *ir.Instruction) {
713 resultName := b.formatValue(instr.Result)
714 lhs, rhs := instr.Args[0], instr.Args[1]
715
716 _, isLhsGlobal := lhs.(*ir.Global)
717 _, isRhsGlobal := rhs.(*ir.Global)
718 isLhsFunc := isLhsGlobal && (b.prog.FindFunc(lhs.String()) != nil || b.funcSigs[lhs.String()] != "")
719 isRhsFunc := isRhsGlobal && (b.prog.FindFunc(rhs.String()) != nil || b.funcSigs[rhs.String()] != "")
720
721 lhsType, rhsType := b.getType(lhs), b.getType(rhs)
722 isLhsPtr := strings.HasSuffix(lhsType, "*") || (isLhsGlobal && !isLhsFunc)
723 isRhsPtr := strings.HasSuffix(rhsType, "*") || (isRhsGlobal && !isRhsFunc)
724
725 if (isLhsPtr && !isRhsPtr) || (!isLhsPtr && isRhsPtr) {
726 var ptr ir.Value
727 var ptrType string
728 var offset ir.Value
729 if isLhsPtr {
730 ptr, ptrType, offset = lhs, lhsType, rhs
731 } else {
732 ptr, ptrType, offset = rhs, rhsType, lhs
733 }
734
735 if ptrType == "unknown" {
736 ptrType = "i8*"
737 }
738
739 i8PtrVal := b.prepareArg(ptr, "i8*")
740 offsetVal := b.prepareArg(offset, b.wordType)
741
742 gepResultTemp := b.newBackendTemp()
743 fmt.Fprintf(b.out, "%s = getelementptr i8, i8* %s, %s %s\n", gepResultTemp, i8PtrVal, b.wordType, offsetVal)
744 b.tempTypes[gepResultTemp] = "i8*"
745
746 if ptrType != "i8*" {
747 fmt.Fprintf(b.out, "\t%s = bitcast i8* %s to %s\n", resultName, gepResultTemp, ptrType)
748 } else {
749 fmt.Fprintf(b.out, "\t%s = bitcast i8* %s to i8*\n", resultName, gepResultTemp)
750 }
751 b.tempTypes[resultName] = ptrType
752 return
753 }
754
755 if isLhsPtr && isRhsPtr {
756 lhsInt := b.prepareArg(lhs, b.wordType)
757 rhsInt := b.prepareArg(rhs, b.wordType)
758 resultInt := b.newBackendTemp()
759
760 fmt.Fprintf(b.out, "%s = add %s %s, %s\n", resultInt, b.wordType, lhsInt, rhsInt)
761 b.tempTypes[resultInt] = b.wordType
762
763 fmt.Fprintf(b.out, "\t%s = inttoptr %s %s to %s\n", resultName, b.wordType, resultInt, lhsType)
764 b.tempTypes[resultName] = lhsType
765 return
766 }
767
768 resultType := b.formatType(instr.Typ)
769 if strings.HasSuffix(resultType, "*") {
770 lhsInt := b.prepareArg(lhs, b.wordType)
771 rhsInt := b.prepareArg(rhs, b.wordType)
772 resultInt := b.newBackendTemp()
773
774 fmt.Fprintf(b.out, "%s = add %s %s, %s\n", resultInt, b.wordType, lhsInt, rhsInt)
775 b.tempTypes[resultInt] = b.wordType
776
777 fmt.Fprintf(b.out, "\t%s = inttoptr %s %s to %s\n", resultName, b.wordType, resultInt, resultType)
778 b.tempTypes[resultName] = resultType
779 } else {
780 lhsVals := b.prepareArg(lhs, resultType)
781 rhsVals := b.prepareArg(rhs, resultType)
782 fmt.Fprintf(b.out, "%s = add %s %s, %s\n", resultName, resultType, lhsVals, rhsVals)
783 b.tempTypes[resultName] = resultType
784 b.tempIRTypes[resultName] = instr.Typ
785 }
786}
787
788func (b *llvmBackend) genCall(instr *ir.Instruction) {
789 resultName := ""
790 if instr.Result != nil {
791 resultName = b.formatValue(instr.Result)
792 }
793
794 callee := instr.Args[0]
795 calleeStr := b.formatValue(callee)
796 retType := b.getFuncSig(callee.String())
797
798 // Check if this is an external function call (declared but not defined)
799 isExternalFunc := false
800 if g, ok := callee.(*ir.Global); ok {
801 // Check if the function is not defined in this program (making it external)
802 funcIsDefined := false
803 for _, fn := range b.prog.Funcs {
804 if fn.Name == g.Name {
805 funcIsDefined = true
806 break
807 }
808 }
809 if !funcIsDefined {
810 isExternalFunc = true
811 }
812 }
813
814 var argParts []string
815 for i, arg := range instr.Args[1:] {
816 targetType := b.wordType
817
818 // Determine the source type of the argument
819 sourceType := b.getType(arg)
820 if fc, ok := arg.(*ir.FloatConst); ok {
821 sourceType = b.formatType(fc.Typ)
822 }
823
824 if instr.ArgTypes != nil && i < len(instr.ArgTypes) {
825 requestedType := b.formatType(instr.ArgTypes[i])
826
827 // - float -> double
828 // - small integers (i8, i16) -> word type (int promotion)
829 // - preserve pointers and larger types
830 if isExternalFunc {
831 if requestedType == "float" {
832 targetType = "double"
833 } else if requestedType == "i8" || requestedType == "i16" {
834 targetType = b.wordType
835 } else {
836 targetType = requestedType
837 }
838 } else {
839 targetType = requestedType
840 }
841 } else {
842 // No explicit ArgTypes, infer from the argument and apply C standard promotions for external functions
843 if isExternalFunc {
844 if sourceType == "float" {
845 targetType = "double"
846 } else if sourceType == "i8" || sourceType == "i16" {
847 targetType = b.wordType
848 } else if sourceType != "unknown" {
849 targetType = sourceType
850 } else if g, ok := arg.(*ir.Global); ok {
851 if _, isString := b.prog.IsStringLabel(g.Name); isString {
852 targetType = "i8*"
853 }
854 }
855 } else if g, ok := arg.(*ir.Global); ok {
856 if _, isString := b.prog.IsStringLabel(g.Name); isString {
857 targetType = "i8*"
858 }
859 }
860 }
861
862 valStr := b.prepareArg(arg, targetType)
863 argParts = append(argParts, fmt.Sprintf("%s %s", targetType, valStr))
864 }
865
866 if _, isGlobal := callee.(*ir.Global); !isGlobal {
867 funcPtrType := fmt.Sprintf("%s (...)*", retType)
868 calleeStr = b.prepareArg(callee, funcPtrType)
869 }
870
871 callStr := fmt.Sprintf("call %s %s(%s)", retType, calleeStr, strings.Join(argParts, ", "))
872
873 if resultName != "" && retType != "void" {
874 fmt.Fprintf(b.out, "%s = %s\n", resultName, callStr)
875 b.tempTypes[resultName] = retType
876 } else {
877 fmt.Fprintf(b.out, "%s\n", callStr)
878 }
879}
880
881func (b *llvmBackend) prepareArg(v ir.Value, targetType string) string {
882 valStr := b.formatValue(v)
883 if g, ok := v.(*ir.Global); ok {
884 isFunc := b.prog.FindFunc(g.Name) != nil || b.funcSigs[g.Name] != ""
885 if isFunc {
886 if strings.HasPrefix(targetType, "i") && !strings.HasSuffix(targetType, "*") {
887 funcSig := b.getFuncSig(g.Name) + " (...)*"
888 castTemp := b.newBackendTemp()
889 fmt.Fprintf(b.out, "\t%s = ptrtoint %s @%s to %s\n", castTemp, funcSig, g.Name, targetType)
890 b.tempTypes[castTemp] = targetType
891 return castTemp
892 }
893 return "@" + g.Name
894 }
895 }
896
897 if _, ok := v.(*ir.Const); ok {
898 return valStr
899 }
900 if _, ok := v.(*ir.FloatConst); ok {
901 return valStr
902 }
903
904 currentType := b.getType(v)
905 if currentType == targetType || currentType == "unknown" {
906 return valStr
907 }
908
909 // Get the IR type to determine signedness for casting
910 sourceIRType := b.getIRType(v)
911 castTemp := b.newBackendTemp()
912 b.out.WriteString("\t")
913 b.out.WriteString(b.formatCastWithSignedness(valStr, castTemp, currentType, targetType, sourceIRType))
914 b.out.WriteString("\n")
915 b.tempTypes[castTemp] = targetType
916 return castTemp
917}
918
919func (b *llvmBackend) formatCast(sourceName, targetName, sourceType, targetType string) string {
920 return b.formatCastWithSignedness(sourceName, targetName, sourceType, targetType, ir.TypeNone)
921}
922
923func (b *llvmBackend) formatCastWithSignedness(sourceName, targetName, sourceType, targetType string, sourceIRType ir.Type) string {
924 isSourcePtr, isTargetPtr := strings.HasSuffix(sourceType, "*"), strings.HasSuffix(targetType, "*")
925 isSourceInt, isTargetInt := strings.HasPrefix(sourceType, "i") && !isSourcePtr, strings.HasPrefix(targetType, "i") && !isTargetPtr
926 isSourceFloat, isTargetFloat := sourceType == "float" || sourceType == "double", targetType == "float" || targetType == "double"
927
928 var castOp string
929 switch {
930 case sourceType == "i1" && isTargetInt:
931 castOp = "zext"
932 case isSourceInt && targetType == "i1":
933 return fmt.Sprintf("%s = icmp ne %s %s, 0", targetName, sourceType, sourceName)
934 case isSourceInt && isTargetPtr:
935 castOp = "inttoptr"
936 case isSourcePtr && isTargetInt:
937 castOp = "ptrtoint"
938 case isSourcePtr && isTargetPtr:
939 castOp = "bitcast"
940 case isSourceInt && isTargetInt:
941 sourceBits, _ := strconv.Atoi(strings.TrimPrefix(sourceType, "i"))
942 targetBits, _ := strconv.Atoi(strings.TrimPrefix(targetType, "i"))
943 if sourceBits > targetBits {
944 castOp = "trunc"
945 } else {
946 // Choose sext vs zext based on source IR type signedness
947 switch sourceIRType {
948 case ir.TypeUB, ir.TypeUH:
949 castOp = "zext" // unsigned types get zero extension
950 default:
951 castOp = "sext" // signed types (and ambiguous ones) get sign extension
952 }
953 }
954 case isSourceInt && isTargetFloat:
955 castOp = "sitofp"
956 case isSourceFloat && isTargetInt:
957 castOp = "fptosi"
958 case isSourceFloat && isTargetFloat:
959 castOp = "fpext"
960 if sourceType == "double" {
961 castOp = "fptrunc"
962 }
963 default:
964 castOp = "bitcast"
965 }
966 return fmt.Sprintf("%s = %s %s %s to %s", targetName, castOp, sourceType, sourceName, targetType)
967}
968
969func (b *llvmBackend) getType(v ir.Value) string {
970 valStr := b.formatValue(v)
971 if t, ok := b.tempTypes[valStr]; ok {
972 return t
973 }
974 if _, ok := v.(*ir.Const); ok {
975 return b.wordType
976 }
977 if fc, ok := v.(*ir.FloatConst); ok {
978 return b.formatType(fc.Typ)
979 }
980 if g, ok := v.(*ir.Global); ok {
981 if _, isString := b.prog.IsStringLabel(g.Name); isString {
982 return "i8*"
983 }
984 }
985 return "unknown"
986}
987
988func (b *llvmBackend) getIRType(v ir.Value) ir.Type {
989 valStr := b.formatValue(v)
990 if t, ok := b.tempIRTypes[valStr]; ok {
991 return t
992 }
993 if _, ok := v.(*ir.Const); ok {
994 return ir.GetType(nil, b.prog.WordSize)
995 }
996 if fc, ok := v.(*ir.FloatConst); ok {
997 return fc.Typ
998 }
999 return ir.TypeNone
1000}
1001
1002func (b *llvmBackend) newBackendTemp() string {
1003 name := fmt.Sprintf("%%.b%d", b.prog.GetBackendTempCount())
1004 b.prog.IncBackendTempCount()
1005 return name
1006}
1007
1008func (b *llvmBackend) formatValue(v ir.Value) string {
1009 if v == nil {
1010 return "void"
1011 }
1012 switch val := v.(type) {
1013 case *ir.Const:
1014 return fmt.Sprintf("%d", val.Value)
1015 case *ir.FloatConst:
1016 if val.Typ == ir.TypeS {
1017 // For 32-bit floats, truncate to float32 precision then expand back to float64 for hex format
1018 float32Val := float32(val.Value)
1019 return fmt.Sprintf("0x%016X", math.Float64bits(float64(float32Val)))
1020 } else {
1021 return fmt.Sprintf("0x%016X", math.Float64bits(val.Value))
1022 }
1023 case *ir.Global:
1024 return "@" + val.Name
1025 case *ir.Temporary:
1026 safeName := strings.NewReplacer(".", "_", "[", "_", "]", "_").Replace(val.Name)
1027 if val.ID == -1 {
1028 return "%" + safeName
1029 }
1030 if safeName != "" {
1031 return fmt.Sprintf("%%.%s_%d", safeName, val.ID)
1032 }
1033 return fmt.Sprintf("%%t%d", val.ID)
1034 case *ir.Label:
1035 return "%" + val.Name
1036 case *ir.CastValue:
1037 return b.formatValue(val.Value)
1038 default:
1039 return ""
1040 }
1041}
1042
1043func (b *llvmBackend) formatType(t ir.Type) string {
1044 switch t {
1045 case ir.TypeB, ir.TypeSB, ir.TypeUB:
1046 return "i8"
1047 case ir.TypeH, ir.TypeSH, ir.TypeUH:
1048 return "i16"
1049 case ir.TypeW:
1050 return "i32"
1051 case ir.TypeL:
1052 return "i64"
1053 case ir.TypeS:
1054 return "float"
1055 case ir.TypeD:
1056 return "double"
1057 case ir.TypeNone:
1058 return "void"
1059 case ir.TypePtr:
1060 return "i8*"
1061 default:
1062 return b.wordType
1063 }
1064}
1065
1066func (b *llvmBackend) formatOp(op ir.Op) (string, string) {
1067 switch op {
1068 case ir.OpAdd:
1069 return "add", ""
1070 case ir.OpSub:
1071 return "sub", ""
1072 case ir.OpMul:
1073 return "mul", ""
1074 case ir.OpDiv:
1075 return "sdiv", ""
1076 case ir.OpRem:
1077 return "srem", ""
1078 case ir.OpAddF:
1079 return "fadd", ""
1080 case ir.OpSubF:
1081 return "fsub", ""
1082 case ir.OpMulF:
1083 return "fmul", ""
1084 case ir.OpDivF:
1085 return "fdiv", ""
1086 case ir.OpRemF:
1087 return "frem", ""
1088 case ir.OpNegF:
1089 return "fneg", ""
1090 case ir.OpAnd:
1091 return "and", ""
1092 case ir.OpOr:
1093 return "or", ""
1094 case ir.OpXor:
1095 return "xor", ""
1096 case ir.OpShl:
1097 return "shl", ""
1098 case ir.OpShr:
1099 return "ashr", ""
1100 case ir.OpCEq:
1101 return "icmp", "eq"
1102 case ir.OpCNeq:
1103 return "icmp", "ne"
1104 case ir.OpCLt:
1105 return "icmp", "slt"
1106 case ir.OpCGt:
1107 return "icmp", "sgt"
1108 case ir.OpCLe:
1109 return "icmp", "sle"
1110 case ir.OpCGe:
1111 return "icmp", "sge"
1112 default:
1113 return "unknown_op", ""
1114 }
1115}
1116
1117func (b *llvmBackend) isPointerValue(v ir.Value) bool {
1118 if g, ok := v.(*ir.Global); ok {
1119 if _, isString := b.prog.IsStringLabel(g.Name); isString {
1120 return true
1121 }
1122 return b.prog.FindFunc(g.Name) == nil && b.funcSigs[g.Name] == ""
1123 }
1124 return false
1125}
1126
1127func (b *llvmBackend) prepareArgForComparison(v ir.Value, targetType string) string {
1128 if c, isConst := v.(*ir.Const); isConst && c.Value == 0 && strings.HasSuffix(targetType, "*") {
1129 return "null"
1130 }
1131 return b.prepareArg(v, targetType)
1132}