58
58
# Accelerator architectures
59
59
CPU = "cpu"
60
60
CPU_AARCH64 = "cpu-aarch64"
61
+ CUDA_AARCH64 = "cuda-aarch64"
61
62
CUDA = "cuda"
62
63
ROCM = "rocm"
63
64
80
81
LINUX_GPU_RUNNER = "linux.g5.4xlarge.nvidia.gpu"
81
82
LINUX_CPU_RUNNER = "linux.2xlarge"
82
83
LINUX_AARCH64_RUNNER = "linux.arm64.2xlarge"
84
+ LINUX_AARCH64_GPU_RUNNER = "linux.arm64.m7g.4xlarge"
83
85
WIN_GPU_RUNNER = "windows.8xlarge.nvidia.gpu"
84
86
WIN_CPU_RUNNER = "windows.4xlarge"
85
87
MACOS_M1_RUNNER = "macos-m1-stable"
@@ -103,6 +105,8 @@ def arch_type(arch_version: str) -> str:
103
105
return ROCM
104
106
elif arch_version == CPU_AARCH64 :
105
107
return CPU_AARCH64
108
+ elif arch_version == CUDA_AARCH64 :
109
+ return CUDA_AARCH64
106
110
else : # arch_version should always be CPU in this case
107
111
return CPU
108
112
@@ -114,7 +118,10 @@ def validation_runner(arch_type: str, os: str) -> str:
114
118
else :
115
119
return LINUX_CPU_RUNNER
116
120
elif os == LINUX_AARCH64 :
117
- return LINUX_AARCH64_RUNNER
121
+ if arch_type == CUDA_AARCH64 :
122
+ return LINUX_AARCH64_GPU_RUNNER
123
+ else :
124
+ return LINUX_AARCH64_RUNNER
118
125
elif os == WINDOWS :
119
126
if arch_type == CUDA :
120
127
return WIN_GPU_RUNNER
@@ -154,6 +161,7 @@ def initialize_globals(channel: str, build_python_only: bool) -> None:
154
161
},
155
162
CPU : "pytorch/manylinux-builder:cpu" ,
156
163
CPU_AARCH64 : "pytorch/manylinuxaarch64-builder:cpu-aarch64" ,
164
+ CUDA_AARCH64 : "pytorch/manylinuxaarch64-builder:cuda12.4" ,
157
165
}
158
166
CONDA_CONTAINER_IMAGES = {
159
167
** {
@@ -188,6 +196,7 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
188
196
return {
189
197
CPU : "cpu" ,
190
198
CPU_AARCH64 : CPU ,
199
+ CUDA_AARCH64 : "cu124" ,
191
200
CUDA : f"cu{ gpu_arch_version .replace ('.' , '' )} " ,
192
201
ROCM : f"rocm{ gpu_arch_version } " ,
193
202
}.get (gpu_arch_type , gpu_arch_version )
@@ -490,7 +499,7 @@ def generate_wheels_matrix(
490
499
if os == LINUX_AARCH64 :
491
500
# Only want the one arch as the CPU type is different and
492
501
# uses different build/test scripts
493
- arches = [CPU_AARCH64 ]
502
+ arches = [CPU_AARCH64 , CUDA_AARCH64 ]
494
503
495
504
if with_cuda == ENABLE :
496
505
upload_to_base_bucket = "no"
0 commit comments