Skip to content

Commit 364823a

Browse files
committed
continuing
1 parent e2ffe87 commit 364823a

File tree

3 files changed

+90
-8
lines changed

3 files changed

+90
-8
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,11 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) {
376376
return wrap(res);
377377
}
378378

379+
#include "llvm/IRReader/IRReader.h"
379380
extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) {
380381
LLVMContext Context;
381-
auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Context);
382+
SMDiagnostic Err;
383+
auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context);
382384
mlir::MLIRContext &context = *unwrap(cctx);
383385
auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release();
384386
return wrap(res);

deps/ReactantExtra/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,8 @@ cc_library(
450450
"@llvm-project//mlir:SCFDialect",
451451
"@llvm-project//mlir:TransformDialect",
452452
"@llvm-project//mlir:Transforms",
453+
454+
"@llvm-project//llvm:IRReader",
453455
"@llvm-project//llvm:Support",
454456
"@llvm-project//llvm:AArch64AsmParser",
455457
"@llvm-project//llvm:AArch64CodeGen",

ext/ReactantCUDAExt.jl

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,20 +214,88 @@ struct LLVMFunc{F,tt}
214214
end
215215

216216

217+
const GPUCompiler = CUDA.GPUCompiler
218+
const LLVM = GPUCompiler.LLVM
219+
220+
221+
GPULowerCPUFeaturesPass() = LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!)
222+
GPULowerPTLSPass() = LLVM.NewPMModulePass("GPULowerPTLS", GPUCompiler.lower_ptls!)
223+
GPULowerGCFramePass() = LLVM.NewPMFunctionPass("GPULowerGCFrame", GPUCompiler.lower_gc_frame!)
224+
function noop_pass(x)
225+
return false
226+
end
227+
function kern_pass(mod)
228+
for fname in ("julia.gpu.state_getter",)
229+
if LLVM.haskey(LLVM.functions(mod), fname)
230+
fn = LLVM.functions(mod)[fname]
231+
insts = LLVM.Instruction[]
232+
for u in LLVM.uses(fn)
233+
u = LLVM.user(u)
234+
LLVM.replace_uses!(u, LLVM.UndefValue(LLVM.value_type(u)))
235+
push!(insts, u)
236+
end
237+
for inst in insts
238+
Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst)
239+
end
240+
Reactant.Enzyme.Compiler.eraseInst(mod, fn)
241+
end
242+
end
243+
244+
return true
245+
end
246+
AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass)
247+
LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass)
248+
CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass)
249+
217250
# compile to executable machine code
218251
function compile(job)
252+
219253
# lower to PTX
220254
# TODO: on 1.9, this actually creates a context. cache those.
221-
modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx
222-
asm, meta = CUDA.GPUCompiler.compile(:asm, job)
223-
mod = meta.ir
224-
255+
modstr, image, entry = GPUCompiler.JuliaContext() do ctx
256+
mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
257+
GPUCompiler.optimize_module!(job, mod)
258+
opt_level = 2
259+
tm = GPUCompiler.llvm_machine(job.config.target)
260+
LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
261+
LLVM.register!(pb, GPULowerCPUFeaturesPass())
262+
LLVM.register!(pb, GPULowerPTLSPass())
263+
LLVM.register!(pb, GPULowerGCFramePass())
264+
LLVM.register!(pb, AddKernelStatePass())
265+
LLVM.register!(pb, LowerKernelStatePass())
266+
LLVM.register!(pb, CleanupKernelStatePass())
267+
268+
LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
269+
GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level)
270+
end
271+
LLVM.run!(pb, mod, tm)
272+
end
273+
GPUCompiler.optimize_module!(job, mod)
274+
LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm)
275+
276+
277+
for fname in ("gpu_report_exception", "gpu_signal_exception")
278+
if LLVM.haskey(LLVM.functions(mod), fname)
279+
fn = LLVM.functions(mod)[fname]
280+
insts = LLVM.Instruction[]
281+
for u in LLVM.uses(fn)
282+
push!(insts, LLVM.user(u))
283+
end
284+
for inst in insts
285+
Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst)
286+
end
287+
Reactant.Enzyme.Compiler.eraseInst(mod, fn)
288+
end
289+
end
290+
291+
LLVM.strip_debuginfo!(mod)
225292
modstr = string(mod)
226293

227294
# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
228295
# it is probably safer to reparse a string using the right llvm module api, so we will do that.
229296

230-
mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMToMLIR(mod::CUDA.LLVM.API.LLVMModuleRef, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
297+
println(string(modstr))
298+
mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
231299
@show mmod
232300

233301
# check if we'll need the device runtime
@@ -461,8 +529,18 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
461529
cache = compiler_cache(MLIR.IR.context())
462530
source = CUDA.methodinstance(F, tt)
463531

464-
cuda = CUDA.active_state()
465-
config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
532+
# cuda = CUDA.active_state()
533+
device = nothing # cuda.device
534+
# config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
535+
cuda_cap=v"5.0"
536+
cuda_ptx=v"6.3"
537+
llvm_cap=v"5.0"
538+
llvm_ptx=v"6.3"
539+
kernel=true
540+
always_inline=false
541+
name=nothing
542+
debuginfo=false
543+
config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
466544
CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
467545
end
468546
res

0 commit comments

Comments
 (0)