Commit f4db472

[Target] Support CUDA device function calls
This commit adds support for CUDA device function calls by:

1. Modifying the calling convention handling in CUDA codegen to support both device kernel launches and device function calls
2. Updating the function signature printing to emit the appropriate CUDA attribute (`__global__` vs `__device__`) based on the calling convention
3. Adding a test case demonstrating device function calls
4. Fixing target handling in `split_host_device_mods` to properly handle device function dictionaries
5. Adding a safety check for global symbol extraction

The changes enable proper compilation and execution of CUDA device functions that can be called from CUDA kernels. Example:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        for bx in T.thread_binding(1024, "blockIdx.x"):
            for tx in T.thread_binding(1024, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])
```
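The effect on the generated code can be checked by compiling the example and inspecting the emitted CUDA source, which is essentially the check the new test performs. A minimal sketch, assuming the `Module` above is in scope:

```python
import tvm

# Sketch: compile the example Module for CUDA and look at the emitted source.
lib = tvm.compile(Module, target="cuda")
cuda_code = lib.mod.imported_modules[0].get_source()

# The private helper is emitted as a callable __device__ function ...
assert 'extern "C" __device__ float add(float a, float b)' in cuda_code
# ... while the thread-bound entry point is still emitted as a __global__ kernel.
assert 'extern "C" __global__' in cuda_code
```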
1 parent 9cb6705 commit f4db472

File tree: 6 files changed, +60 −13 lines changed

python/tvm/tir/build.py

Lines changed: 12 additions & 7 deletions

```diff
@@ -54,18 +54,23 @@ class CallConv(enum.IntEnum):
         kDeviceKernelLaunch = 2
 
     host_mod = tvm.tir.transform.Filter(
-        lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
-        != int(CallConv.kDeviceKernelLaunch)
+        lambda f: "cpu" in str(f.attrs.get("target", "cpu"))
     )(mod)
     device_mod = tvm.tir.transform.Filter(
-        lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
-        == int(CallConv.kDeviceKernelLaunch)
+        lambda f: "cpu" not in str(f.attrs.get("target", "cpu"))
     )(mod)
+    # TODO(syfeng): Here we use str as key since target hash is not correct
+    target_str2target = {}
+    device_func_dict = {}
     device_mod_dict = {}
     for gv, func in device_mod.functions.items():
-        device_mod_dict.setdefault(func.attrs.get("target", None), dict()).update({gv: func})
-    for target, funcs in device_mod_dict.items():
-        device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
+        target = func.attrs.get("target", None)
+        target_str = str(target) if target is not None else ""
+        target_str2target[target_str] = target  # This might be overridden by the last one
+        device_func_dict.setdefault(target_str, dict()).update({gv: func})
+    for target_str in target_str2target.keys():
+        target = target_str2target[target_str]
+        device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs)
     return host_mod, device_mod_dict
 
 
```
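The split now keys on the `target` attribute rather than the calling convention because a `__device__` helper such as `add` keeps the default calling convention yet must still land in the device module. A minimal sketch of the resulting grouping, assuming every PrimFunc already carries a `target` attribute and calling the helper from this file (`split_host_device_mods`) directly, purely for illustration:

```python
# Illustrative sketch of the new grouping, not the exact TVM build pipeline.
host_mod, device_mod_dict = split_host_device_mods(mod)

# host_mod holds functions whose target string contains "cpu" (host-side launchers);
# device_mod_dict maps each distinct device Target to one IRModule. For the example
# module, both the kernel generated for `main` and the callable `add` end up under
# the single CUDA target entry.
for target, device_mod in device_mod_dict.items():
    print(target, [gv.name_hint for gv in device_mod.functions])
```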

src/target/build_common.h

Lines changed: 3 additions & 1 deletion

```diff
@@ -56,7 +56,9 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
       }
     }
     auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    fmap[static_cast<std::string>(global_symbol.value())] = info;
+    if (global_symbol) {
+      fmap[static_cast<std::string>(global_symbol.value())] = info;
+    }
   }
   return fmap;
 }
```
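The guard matters because a private PrimFunc, such as the `add` helper declared with `@T.prim_func(private=True)` in the example, carries no `global_symbol` attribute. A small sketch of that invariant, assuming the example `Module`:

```python
# Sketch: private device helpers have no "global_symbol", so ExtractFuncInfo must
# skip them rather than dereference an empty Optional.
add_func = Module["add"]
assert add_func.attrs is None or add_func.attrs.get("global_symbol") is None

# The public entry point keeps its symbol and is still recorded in the FunctionInfo map.
assert Module["main"].attrs.get("global_symbol") == "main"
```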

src/target/opt/build_cuda_on.cc

Lines changed: 6 additions & 3 deletions

```diff
@@ -134,9 +134,12 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
   for (auto [gvar, base_func] : mod->functions) {
     ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
     auto prim_func = Downcast<PrimFunc>(base_func);
-    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
+    auto calling_conv =
+        prim_func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault));
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch ||
+           calling_conv == CallingConv::kDefault)
+        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch or "
+           "CallingConv::kDefault";
     functions.Set(gvar, prim_func);
   }
 
```

src/target/source/codegen_cuda.cc

Lines changed: 13 additions & 1 deletion

```diff
@@ -140,7 +140,19 @@ void CodeGenCUDA::Init(bool output_ssa) {
   ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
 }
 
-void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
+void CodeGenCUDA::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                                         std::ostream& os) {
+  auto calling_conv =
+      func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault));
+  if (calling_conv == CallingConv::kDeviceKernelLaunch) {
+    os << "extern \"C\" __global__ ";
+  } else if (calling_conv == CallingConv::kDefault) {
+    os << "extern \"C\" __device__ ";
+  } else {
+    LOG(FATAL) << "Unsupported calling convention for cuda codegen: " << calling_conv;
+  }
+  CodeGenC::PrintFunctionSignature(function_name, func, os);
+}
 
 class ThreadIdxExtractor : public tir::StmtVisitor {
  private:
```
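In short, the new `PrintFunctionSignature` picks the CUDA qualifier from the calling convention. A sketch of that mapping on the Python side, reusing the `device_mod_dict` from the earlier sketch (CallConv values as defined in `python/tvm/tir/build.py`):

```python
# Sketch: which qualifier each device-side PrimFunc receives from the CUDA codegen.
K_DEFAULT, K_DEVICE_KERNEL_LAUNCH = 0, 2  # CallConv.kDefault / CallConv.kDeviceKernelLaunch

for target, device_mod in device_mod_dict.items():
    for gvar, func in device_mod.functions.items():
        conv = int(func.attrs.get("calling_conv", K_DEFAULT)) if func.attrs else K_DEFAULT
        # kDeviceKernelLaunch -> __global__, kDefault -> __device__; anything else is a codegen error.
        qualifier = "__global__" if conv == K_DEVICE_KERNEL_LAUNCH else "__device__"
        print(f"{gvar.name_hint}: {qualifier}")
```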

src/target/source/codegen_cuda.h

Lines changed: 2 additions & 1 deletion

```diff
@@ -46,7 +46,8 @@ class CodeGenCUDA final : public CodeGenC {
            enable_fp4_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
-  void PrintFuncPrefix(std::ostream& os) final;
+  void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                              std::ostream& os) final;
   void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;  // NOLINT(*)
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
```

tests/python/codegen/test_target_codegen_cuda.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -23,6 +23,7 @@
 import tvm.testing
 from tvm import te, topi
 from tvm.contrib.nvcc import have_bf16, have_fp16, have_int8
+from tvm.script import ir as I
 from tvm.script import tir as T
 
 
@@ -746,5 +747,28 @@ def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
     tvm.compile(func, target="cuda")
 
 
+@tvm.testing.requires_cuda
+def test_cuda_device_func_call():
+    @I.ir_module
+    class Module:
+        @T.prim_func(private=True)
+        def add(a: T.float32, b: T.float32) -> T.float32:
+            return a + b
+
+        @T.prim_func
+        def main(
+            A: T.Buffer((1024, 1024), "float32"),
+            B: T.Buffer((1024, 1024), "float32"),
+            C: T.Buffer((1024, 1024), "float32"),
+        ):
+            for bx in T.thread_binding(1024, "blockIdx.x"):
+                for tx in T.thread_binding(1024, "threadIdx.x"):
+                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])
+
+    lib = tvm.compile(Module, target="cuda")
+    cuda_code = lib.mod.imported_modules[0].get_source()
+    assert 'extern "C" __device__ float add(float a, float b) {\n return (a + b);\n}' in cuda_code
+
+
 if __name__ == "__main__":
     tvm.testing.main()
```
