@@ -214,20 +214,88 @@ struct LLVMFunc{F,tt}
214214end
215215
216216
217+ const GPUCompiler = CUDA. GPUCompiler
218+ const LLVM = GPUCompiler. LLVM
219+
220+
# NewPM wrappers that expose GPUCompiler's lowering callbacks as named custom
# passes. The string is the identifier each pass is registered under, so it must
# be spelled exactly, with no surrounding whitespace; a padded name would be a
# different pass name.
# NOTE(review): presumably these names are the ones referenced by the pipeline
# built via `GPUCompiler.buildNewPMPipeline!` — confirm against GPUCompiler.
function GPULowerCPUFeaturesPass()
    return LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!)
end

function GPULowerPTLSPass()
    return LLVM.NewPMModulePass("GPULowerPTLS", GPUCompiler.lower_ptls!)
end

function GPULowerGCFramePass()
    return LLVM.NewPMFunctionPass("GPULowerGCFrame", GPUCompiler.lower_gc_frame!)
end
"""
    noop_pass(x)

Do nothing with `x` and return `false`.

Used as a placeholder pass callback: the argument is ignored and the result is
always `false`.
"""
noop_pass(x) = false
"""
    kern_pass(mod)

Module pass callback that removes the `julia.gpu.state_getter` function and
everything that uses it from `mod`.

For every user of the function, the user's own uses are redirected to an `undef`
value of the user's type, the user is erased from its parent, and finally the
function itself is erased from the module.

Returns `true` if the module was modified, `false` if the function was not
present (nothing to do).
"""
function kern_pass(mod)
    changed = false
    for fname in ("julia.gpu.state_getter",)
        funcs = LLVM.functions(mod)
        haskey(funcs, fname) || continue
        fn = funcs[fname]

        # Collect the users first and erase them afterwards, so the use list of
        # `fn` is not modified while it is being walked.
        insts = LLVM.Instruction[]
        for use in LLVM.uses(fn)
            inst = LLVM.user(use)
            # Anything consuming the user's result now sees `undef` instead.
            LLVM.replace_uses!(inst, LLVM.UndefValue(LLVM.value_type(inst)))
            push!(insts, inst)
        end
        for inst in insts
            Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst)
        end
        Reactant.Enzyme.Compiler.eraseInst(mod, fn)
        changed = true
    end

    return changed
end
# Replacement kernel-state passes, registered under exact (unpadded) names.
# `AddKernelStatePass` runs `kern_pass`; the other two are wired to `noop_pass`
# and therefore do nothing.
# NOTE(review): presumably these override the same-named GPUCompiler passes when
# registered on the pass builder — confirm against GPUCompiler.
AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass)

function LowerKernelStatePass()
    return LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass)
end

function CleanupKernelStatePass()
    return LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass)
end
249+
217250# compile to executable machine code
218251function compile (job)
252+
219253 # lower to PTX
220254 # TODO : on 1.9, this actually creates a context. cache those.
221- modstr, image, entry = CUDA. GPUCompiler. JuliaContext () do ctx
222- asm, meta = CUDA. GPUCompiler. compile (:asm , job)
223- mod = meta. ir
224-
255+ modstr, image, entry = GPUCompiler. JuliaContext () do ctx
256+ mod, meta = GPUCompiler. compile (:llvm , job; optimize= false , cleanup= false , validate= false )
257+ GPUCompiler. optimize_module! (job, mod)
258+ opt_level = 2
259+ tm = GPUCompiler. llvm_machine (job. config. target)
260+ LLVM. @dispose pb= LLVM. NewPMPassBuilder () begin
261+ LLVM. register! (pb, GPULowerCPUFeaturesPass ())
262+ LLVM. register! (pb, GPULowerPTLSPass ())
263+ LLVM. register! (pb, GPULowerGCFramePass ())
264+ LLVM. register! (pb, AddKernelStatePass ())
265+ LLVM. register! (pb, LowerKernelStatePass ())
266+ LLVM. register! (pb, CleanupKernelStatePass ())
267+
268+ LLVM. add! (pb, LLVM. NewPMModulePassManager ()) do mpm
269+ GPUCompiler. buildNewPMPipeline! (mpm, job, opt_level)
270+ end
271+ LLVM. run! (pb, mod, tm)
272+ end
273+ GPUCompiler. optimize_module! (job, mod)
274+ LLVM. run! (CUDA. GPUCompiler. DeadArgumentEliminationPass (), mod, tm)
275+
276+
277+ for fname in (" gpu_report_exception" , " gpu_signal_exception" )
278+ if LLVM. haskey (LLVM. functions (mod), fname)
279+ fn = LLVM. functions (mod)[fname]
280+ insts = LLVM. Instruction[]
281+ for u in LLVM. uses (fn)
282+ push! (insts, LLVM. user (u))
283+ end
284+ for inst in insts
285+ Reactant. Enzyme. Compiler. eraseInst (LLVM. parent (inst), inst)
286+ end
287+ Reactant. Enzyme. Compiler. eraseInst (mod, fn)
288+ end
289+ end
290+
291+ LLVM. strip_debuginfo! (mod)
225292 modstr = string (mod)
226293
227294 # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
228295 # it is probably safer to reparse a string using the right llvm module api, so we will do that.
229296
230- mmod = MLIR. IR. Module (@ccall MLIR. API. mlir_c. ConvertLLVMToMLIR (mod:: CUDA.LLVM.API.LLVMModuleRef , MLIR. IR. context ():: MLIR.API.MlirContext ):: MLIR.API.MlirModule )
297+ println (string (modstr))
298+ mmod = MLIR. IR. Module (@ccall MLIR. API. mlir_c. ConvertLLVMStrToMLIR (modstr:: Cstring , MLIR. IR. context ():: MLIR.API.MlirContext ):: MLIR.API.MlirModule )
231299 @show mmod
232300
233301 # check if we'll need the device runtime
@@ -461,8 +529,18 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
461529 cache = compiler_cache (MLIR. IR. context ())
462530 source = CUDA. methodinstance (F, tt)
463531
464- cuda = CUDA. active_state ()
465- config = CUDA. compiler_config (cuda. device; kwargs... ):: CUDA.CUDACompilerConfig
532+ # cuda = CUDA.active_state()
533+ device = nothing # cuda.device
534+ # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
535+ cuda_cap= v " 5.0"
536+ cuda_ptx= v " 6.3"
537+ llvm_cap= v " 5.0"
538+ llvm_ptx= v " 6.3"
539+ kernel= true
540+ always_inline= false
541+ name= nothing
542+ debuginfo= false
543+ config = CUDA. CompilerConfig (CUDA. PTXCompilerTarget (; cap= llvm_cap, ptx= llvm_ptx, debuginfo), CUDA. CUDACompilerParams (; cap= cuda_cap, ptx= cuda_ptx); kernel, name, always_inline)
466544 CUDA. GPUCompiler. cached_compilation (cache, source, config, compile, link)
467545 end
468546 res
0 commit comments