Skip to content

Commit b623e13

Browse files
authored
[aarch64] add cuda aarch64 torchvision and torchaudio (#5387)
add cuda aarch64 build for torchvision and torchaudio
1 parent 4432e2c commit b623e13

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

tools/scripts/generate_binary_build_matrix.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
# Accelerator architectures
5959
CPU = "cpu"
6060
CPU_AARCH64 = "cpu-aarch64"
61+
CUDA_AARCH64 = "cuda-aarch64"
6162
CUDA = "cuda"
6263
ROCM = "rocm"
6364

@@ -80,6 +81,7 @@
8081
LINUX_GPU_RUNNER = "linux.g5.4xlarge.nvidia.gpu"
8182
LINUX_CPU_RUNNER = "linux.2xlarge"
8283
LINUX_AARCH64_RUNNER = "linux.arm64.2xlarge"
84+
LINUX_AARCH64_GPU_RUNNER = "linux.arm64.m7g.4xlarge"
8385
WIN_GPU_RUNNER = "windows.8xlarge.nvidia.gpu"
8486
WIN_CPU_RUNNER = "windows.4xlarge"
8587
MACOS_M1_RUNNER = "macos-m1-stable"
@@ -103,6 +105,8 @@ def arch_type(arch_version: str) -> str:
103105
return ROCM
104106
elif arch_version == CPU_AARCH64:
105107
return CPU_AARCH64
108+
elif arch_version == CUDA_AARCH64:
109+
return CUDA_AARCH64
106110
else: # arch_version should always be CPU in this case
107111
return CPU
108112

@@ -114,7 +118,10 @@ def validation_runner(arch_type: str, os: str) -> str:
114118
else:
115119
return LINUX_CPU_RUNNER
116120
elif os == LINUX_AARCH64:
117-
return LINUX_AARCH64_RUNNER
121+
if arch_type == CUDA_AARCH64:
122+
return LINUX_AARCH64_GPU_RUNNER
123+
else:
124+
return LINUX_AARCH64_RUNNER
118125
elif os == WINDOWS:
119126
if arch_type == CUDA:
120127
return WIN_GPU_RUNNER
@@ -154,6 +161,7 @@ def initialize_globals(channel: str, build_python_only: bool) -> None:
154161
},
155162
CPU: "pytorch/manylinux-builder:cpu",
156163
CPU_AARCH64: "pytorch/manylinuxaarch64-builder:cpu-aarch64",
164+
CUDA_AARCH64: "pytorch/manylinuxaarch64-builder:cuda12.4",
157165
}
158166
CONDA_CONTAINER_IMAGES = {
159167
**{
@@ -188,6 +196,7 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
188196
return {
189197
CPU: "cpu",
190198
CPU_AARCH64: CPU,
199+
CUDA_AARCH64: "cu124",
191200
CUDA: f"cu{gpu_arch_version.replace('.', '')}",
192201
ROCM: f"rocm{gpu_arch_version}",
193202
}.get(gpu_arch_type, gpu_arch_version)
@@ -490,7 +499,7 @@ def generate_wheels_matrix(
490499
if os == LINUX_AARCH64:
491500
# Only want the one arch as the CPU type is different and
492501
# uses different build/test scripts
493-
arches = [CPU_AARCH64]
502+
arches = [CPU_AARCH64, CUDA_AARCH64]
494503

495504
if with_cuda == ENABLE:
496505
upload_to_base_bucket = "no"

0 commit comments

Comments
 (0)