
[Target] Support CUDA device function calls #18055


Open · Hzfengsy wants to merge 1 commit into main from device_call

Conversation

Hzfengsy (Member) commented:

This commit adds support for CUDA device function calls by:

  1. Modifying the calling convention handling in CUDA codegen to support both device kernel launches and device function calls
  2. Updating the function signature printing to emit the appropriate CUDA attribute (`__global__` vs `__device__`) based on the calling convention
  3. Adding a test case demonstrating device function calls
  4. Fixing target handling in split_host_device_mods to properly handle device function dictionaries
  5. Adding a safety check for global symbol extraction

The changes enable proper compilation and execution of CUDA device functions that can be called from CUDA kernels.

Example:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        for bx in T.thread_binding(1024, "blockIdx.x"):
            for tx in T.thread_binding(1024, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])
```
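
For context, a minimal sketch of how the module above might be compiled and run end to end. This is not part of the PR; the exact entry point (`tvm.build` vs `tvm.compile`) depends on the TVM version, and the device/NumPy setup below is an assumption:

```python
import numpy as np
import tvm

# Assumes the Module defined above is in scope and a CUDA device is available.
dev = tvm.cuda(0)
rt_mod = tvm.build(Module, target="cuda")  # entry point name may differ by TVM version

a_np = np.random.rand(1024, 1024).astype("float32")
b_np = np.random.rand(1024, 1024).astype("float32")
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)

rt_mod["main"](a, b, c)
np.testing.assert_allclose(c.numpy(), a_np + b_np, rtol=1e-5)
```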

@LeiWang1999 (Contributor) commented:

super helpful enhancement, Thanks!

@Hzfengsy force-pushed the device_call branch 2 times, most recently from f4db472 to 4bfdbf7 on June 13, 2025 03:50
A Contributor commented on the diff near `def main`:

I think right now it is quite implicit which function is the kernel function and which is the device function. It might be clearer if we could mark device functions explicitly with @T.prim_func(kind="device") (see the sketch after this comment).

Moreover, we could add a test case where the functions are not wrapped in a Module, and the kernel function is compiled directly instead of the whole Module.
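
A rough, hypothetical sketch of the first suggestion. `kind="device"` is the marking proposed in this comment and is not an existing `T.prim_func` parameter, so the runnable form below keeps `private=True` and only notes the proposed spelling in a comment:

```python
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class ModuleExplicit:
    # Proposed spelling: @T.prim_func(kind="device")  (hypothetical, not implemented)
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b
```

The second suggestion would pass a bare @T.prim_func kernel (not wrapped in an @I.ir_module) directly to the build entry point and check that device calls still compile.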

A Member replied:


This is an interesting point. I added a follow-up comment on #18055 (comment) about the distinction before/after SplitHostDevice. It is probably fine before SplitHostDevice in this case, but it would be good to clarify that in a comment.

Review thread on the target handling in split_host_device_mods:

```python
device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
target = func.attrs.get("target", None)
target_str = str(target) if target is not None else ""
target_str2target[target_str] = target  # This might be overridden by the last one
```
A Contributor commented:


We want to make sure we know in which cases different Target objects might have the same string representation target_str.

A Member replied:


For now the string uniquely identifies a target, so it is OK here, but it would be good to document this invariant and, as @Hzfengsy commented, we can fix it properly once target hashing is supported.
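
A small illustration of the invariant being relied on: keying the dictionary by str(target) is safe only while equal strings imply interchangeable Target objects (illustrative only, not code from this PR):

```python
import tvm

# Two Target objects constructed from the same string have the same str() form,
# so they collapse into a single dictionary entry; the "last one wins" override
# noted in the diff is harmless only under this assumption.
t1 = tvm.target.Target("cuda")
t2 = tvm.target.Target("cuda")

target_str2target = {}
target_str2target[str(t1)] = t1
target_str2target[str(t2)] = t2  # overrides the entry written for t1

assert str(t1) == str(t2)
assert len(target_str2target) == 1
```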

@tqchen (Member) commented on Jun 14, 2025:

After reading the comments so far on the host/device function split and the compiler phases:

  • S0: In the beginning (before SplitHostDevice), we don't distinguish host/device functions; a function can contain kernels.
  • S1: The host/device function split becomes clear after the SplitHostDevice pass. Currently, in the case of a single device launch:
    • global kernels are annotated with the DeviceKernelLaunch calling convention
    • host functions are annotated otherwise

After we enable the compiler to handle device functions, the first thing we need to establish is the behavior after S1. It would be useful to clarify this in the PR with comments.

Summarizing the logic so far:

  • Before the split (S0), the decision seems to be not to distinguish host from device functions and to keep it implicit.
  • The distinction should become clear after S1, by checking the target annotation of each function, which marks the default calling convention.

Here is an example of such a case:

```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        # bind a temp var on the host side
        temp_var = T.float32()
        with T.LetStmt(
            Module.add(T.float32(1), T.float32(2)),
            var=temp_var,
        ):
            for bx in T.thread_binding(1024, "blockIdx.x"):
                for tx in T.thread_binding(1024, "threadIdx.x"):
                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) + temp_var
```

Because of the implicitness, we may need to cross-check the current behavior of SplitHostDevice for rare cases where, say, both the host and the device call the same function. In such cases we may either:
- S0a: place a constraint and report an error
- S0b: have the SplitHostDevice pass duplicate such functions and mark the target

In both cases, it would be good to enhance the SplitHostDevice test cases to ensure the target field is clear after S1.
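
A hedged sketch of the kind of SplitHostDevice test case suggested here: after the split (S1), every function should carry an explicit target attribute. The pass names exist in tvm.tir.transform, but the exact prerequisite passes and attribute checks below are assumptions rather than code from this PR:

```python
import tvm


def check_targets_after_split(mod: tvm.IRModule, target: tvm.target.Target):
    # Bind the target and split host/device; the precise prerequisite passes
    # may differ between TVM versions.
    mod = tvm.tir.transform.BindTarget(target)(mod)
    mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
    mod = tvm.tir.transform.SplitHostDevice()(mod)
    for gvar, func in mod.functions.items():
        # After S1 the host/device distinction should be explicit in the
        # per-function "target" attribute rather than implicit in the call graph.
        func_target = func.attrs.get("target", None) if func.attrs is not None else None
        assert func_target is not None, f"{gvar} has no explicit target after the split"
    return mod
```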
