[Target] Support CUDA device function calls #18055
base: main
Conversation
Super helpful enhancement, thanks!
Force-pushed from f4db472 to 4bfdbf7.
This commit adds support for CUDA device function calls by:

1. Modifying the calling convention handling in CUDA codegen to support both device kernel launches and device function calls
2. Updating the function signature printing to emit appropriate CUDA attributes (__global__ vs __device__) based on calling convention
3. Adding a test case demonstrating device function calls
4. Fixing target handling in split_host_device_mods to properly handle device function dictionaries
5. Adding a safety check for global symbol extraction

The changes enable proper compilation and execution of CUDA device functions that can be called from CUDA kernels.

Example:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        for bx in T.thread_binding(1024, "blockIdx.x"):
            for tx in T.thread_binding(1024, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])
```
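For context, a hedged end-to-end usage sketch for the example module above; the build entry point, device index, and source-inspection call are assumptions about typical TVM usage, not part of this PR:

```python
# Hedged sketch: assumes the `Module` defined above, a CUDA-enabled TVM build,
# and a GPU at device 0. The exact build entry point may differ across versions.
import numpy as np
import tvm

lib = tvm.compile(Module, target="cuda")  # tvm.build on older versions
dev = tvm.cuda(0)
a = tvm.nd.array(np.random.rand(1024, 1024).astype("float32"), dev)
b = tvm.nd.array(np.random.rand(1024, 1024).astype("float32"), dev)
c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
lib["main"](a, b, c)
np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy(), rtol=1e-5)

# The generated CUDA source should contain `add` emitted as a __device__
# function called from the __global__ kernel.
print(lib.imported_modules[0].get_source())
```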
```python
        return a + b

    @T.prim_func
    def main(
```
I think right now it's quite implicit to determine which function is the kernel function and which is the device function. It might be clearer if we could mark device functions explicitly with @T.prim_func(kind="device").
Moreover, we could enhance this by adding a test case where the functions are not wrapped in a Module, and where we compile the kernel function directly instead of compiling the Module; see the sketch below.
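A loosely sketched, hypothetical version of that suggestion; kind="device" is not an existing @T.prim_func parameter, and compiling a standalone kernel this way is the proposed behavior, not something the PR currently supports:

```python
# Hypothetical: `kind="device"` is the suggested explicit marker for device
# functions; it does not exist in @T.prim_func today.
@T.prim_func(private=True, kind="device")
def add(a: T.float32, b: T.float32) -> T.float32:
    return a + b

# Proposed additional test: the kernel is a standalone prim_func (not wrapped
# in an IRModule) and is compiled directly.
@T.prim_func
def main(
    A: T.Buffer((1024, 1024), "float32"),
    B: T.Buffer((1024, 1024), "float32"),
    C: T.Buffer((1024, 1024), "float32"),
):
    for bx in T.thread_binding(1024, "blockIdx.x"):
        for tx in T.thread_binding(1024, "threadIdx.x"):
            C[bx, tx] = add(A[bx, tx], B[bx, tx])
```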
This is an interesting point; I added a follow-up comment on #18055 (comment) about the distinction before/after SplitHostDevice. It may be fine before SplitHostDevice (as in this case), but it would be good to clarify this in a comment.
```python
device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
```

```python
target = func.attrs.get("target", None)
target_str = str(target) if target is not None else ""
target_str2target[target_str] = target  # This might be overridden by the last one
```
We want to make sure in which cases different Target objects might have the same string representation target_str.
For now, str uniquely maps a Target, so it is OK here, but it would be good to document this invariant; as @Hzfengsy commented, we can fix this after target hashing is supported.
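A hedged sketch (not from the PR) of how that invariant could be made explicit where target_str2target is populated; the helper name and logging are illustrative only:

```python
import logging

logger = logging.getLogger(__name__)


def record_target(target_str2target, func):
    """Record str(target) -> Target, assuming the string form uniquely
    identifies a Target until proper Target hashing/equality is available."""
    target = func.attrs.get("target", None)
    target_str = str(target) if target is not None else ""
    existing = target_str2target.get(target_str, None)
    if existing is not None and existing is not target:
        # Assumption: Targets sharing a string form are interchangeable, so
        # overriding the entry is harmless; revisit once Target hash lands.
        logger.debug("Overriding Target entry for key %r", target_str)
    target_str2target[target_str] = target
```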
After reading the comments so far on the host/device function split and the compiler phases:
After we enable the compiler to handle device functions, one thing we first need to ensure is the behavior after S1. It would be useful to clarify this in the PR with comments. Summarizing the logic so far:
Here is an example of such a case:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        # bound temp var on the host side
        temp_var = T.float32()
        with T.LetStmt(
            Module.add(T.float32(1), T.float32(2)),
            var=temp_var,
        ):
            for bx in T.thread_binding(1024, "blockIdx.x"):
                for tx in T.thread_binding(1024, "threadIdx.x"):
                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) + temp_var
```

Because of the implicitness, we may need to cross-check the current behavior of SplitHostDevice for rare cases where, say, both the host and the device call the same function; in such cases we could go either way. In both cases, it would be good to enhance the SplitHostDevice test cases to ensure the target field is clear after S1.