|
30 | 30 | #include "llvm/Analysis/TargetTransformInfo.h" |
31 | 31 | #include "llvm/IR/CallingConv.h" |
32 | 32 | #include "llvm/IR/Constants.h" |
| 33 | +#include "llvm/IR/IRBuilder.h" |
| 34 | +#include "llvm/IR/LegacyPassManager.h" |
33 | 35 | #include "llvm/IR/Module.h" |
| 36 | +#include "llvm/IR/Verifier.h" |
34 | 37 | #include "llvm/IRReader/IRReader.h" |
35 | 38 | #include "llvm/Linker/Linker.h" |
| 39 | +#include "llvm/MC/TargetRegistry.h" |
| 40 | +#include "llvm/Pass.h" |
36 | 41 | #include "llvm/Passes/OptimizationLevel.h" |
37 | 42 | #include "llvm/Passes/PassBuilder.h" |
| 43 | +#include "llvm/Support/CommandLine.h" |
38 | 44 | #include "llvm/Support/Error.h" |
39 | 45 | #include "llvm/Support/FormatVariadic.h" |
40 | 46 | #include "llvm/Support/SourceMgr.h" |
| 47 | +#include "llvm/Support/TargetSelect.h" |
41 | 48 | #include "llvm/Target/TargetMachine.h" |
| 49 | +#include "llvm/Transforms/IPO/AlwaysInliner.h" |
42 | 50 | #include "llvm/Transforms/InstCombine/InstCombine.h" |
| 51 | +#include <filesystem> |
43 | 52 | #include <optional> |
44 | 53 | #ifdef _WIN32 |
45 | 54 | #define WIN32_LEAN_AND_MEAN |
@@ -512,5 +521,81 @@ void addExternalLibs(mlir::ModuleOp &module, |
512 | 521 | module.getOperation()->setAttr("triton_gpu.externs", dict); |
513 | 522 | } |
514 | 523 |
|
| 524 | +// TODO: move to python |
| 525 | +static void initLLVM() { |
| 526 | + static std::once_flag init_flag; |
| 527 | + std::call_once(init_flag, []() { |
| 528 | + LLVMInitializeNVPTXTargetInfo(); |
| 529 | + LLVMInitializeNVPTXTarget(); |
| 530 | + LLVMInitializeNVPTXTargetMC(); |
| 531 | + LLVMInitializeNVPTXAsmPrinter(); |
| 532 | + |
| 533 | + LLVMInitializeAMDGPUTarget(); |
| 534 | + LLVMInitializeAMDGPUTargetInfo(); |
| 535 | + LLVMInitializeAMDGPUTargetMC(); |
| 536 | + LLVMInitializeAMDGPUAsmParser(); |
| 537 | + LLVMInitializeAMDGPUAsmPrinter(); |
| 538 | + }); |
| 539 | +} |
| 540 | + |
| 541 | +std::string translateLLVMIRToASM(llvm::Module &module, |
| 542 | + const std::string &triple, |
| 543 | + const std::string &proc, |
| 544 | + const std::string &features, |
| 545 | + const std::vector<std::string> &flags, |
| 546 | + bool enable_fp_fusion, bool isObject) { |
| 547 | + initLLVM(); |
| 548 | + // options |
| 549 | + auto options = llvm::cl::getRegisteredOptions(); |
| 550 | + for (std::string flag : flags) { |
| 551 | + auto *shortPtr = static_cast<llvm::cl::opt<bool> *>(options[flag]); |
| 552 | + assert(shortPtr); |
| 553 | + shortPtr->setValue(true); |
| 554 | + } |
| 555 | + // inline everything |
| 556 | + for (llvm::Function &f : module.functions()) |
| 557 | + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) |
| 558 | + f.addFnAttr(llvm::Attribute::AlwaysInline); |
| 559 | + // verify and store llvm |
| 560 | + llvm::legacy::PassManager pm; |
| 561 | + pm.add(llvm::createAlwaysInlinerLegacyPass()); |
| 562 | + pm.add(llvm::createVerifierPass()); |
| 563 | + pm.run(module); |
| 564 | + // module->print(llvm::outs(), nullptr); |
| 565 | + |
| 566 | + // create machine |
| 567 | + module.setTargetTriple(triple); |
| 568 | + std::string error; |
| 569 | + auto target = |
| 570 | + llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error); |
| 571 | + llvm::TargetOptions opt; |
| 572 | + if (enable_fp_fusion) |
| 573 | + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; |
| 574 | + opt.UnsafeFPMath = false; |
| 575 | + opt.NoInfsFPMath = false; |
| 576 | + opt.NoNaNsFPMath = true; |
| 577 | + opt.TrapUnreachable = true; |
| 578 | + std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine( |
| 579 | + module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, |
| 580 | + std::nullopt, llvm::CodeGenOptLevel::Aggressive)}; |
| 581 | + // set data layout |
| 582 | + module.setDataLayout(machine->createDataLayout()); |
| 583 | + // emit machine code |
| 584 | + std::string result; |
| 585 | + { |
| 586 | + llvm::raw_string_ostream stream(result); |
| 587 | + llvm::buffer_ostream pstream(stream); |
| 588 | + for (llvm::Function &f : module.functions()) |
| 589 | + f.addFnAttr(llvm::Attribute::AlwaysInline); |
| 590 | + llvm::legacy::PassManager pass; |
| 591 | + // emit |
| 592 | + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile |
| 593 | + : llvm::CodeGenFileType::AssemblyFile; |
| 594 | + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); |
| 595 | + pass.run(module); |
| 596 | + } |
| 597 | + return result; |
| 598 | +} |
| 599 | + |
515 | 600 | } // namespace triton |
516 | 601 | } // namespace mlir |
0 commit comments