
Commit 701d826

Add Ascend NPU Support (#1521)

* Add Ascend NPU support for generate and chat
* update
* Use torch.accelerator for device selection
* Modify npu nightly link
* Fix device selection issues

1 parent ecdb4e3 commit 701d826

File tree

7 files changed: +42 -30 lines changed


install/install_torch.sh (+7)

@@ -66,6 +66,13 @@ then
         torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
         #torchtune=="0.7.0" # no 0.6.0 on xpu nightly
     )
+elif [[ -x "$(command -v npu-smi)" ]];
+then
+    REQUIREMENTS_TO_INSTALL=(
+        torch=="2.7.0.dev20250310+cpu"
+        torchvision=="0.22.0.dev20250310"
+        torchtune=="0.6.0"
+    )
 else
     REQUIREMENTS_TO_INSTALL=(
         torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"

torchchat/cli/builder.py (+5 -7)

@@ -29,7 +29,7 @@
 from torchchat.utils.build_utils import (
     device_sync,
     is_cpu_device,
-    is_cuda_or_cpu_or_xpu_device,
+    is_supported_device,
     name_to_dtype,
 )
 from torchchat.utils.measure_time import measure_time

@@ -74,10 +74,8 @@ class BuilderArgs:
 
     def __post_init__(self):
         if self.device is None:
-            if torch.cuda.is_available():
-                self.device = "cuda"
-            elif torch.xpu.is_available():
-                self.device = "xpu"
+            if torch.accelerator.is_available():
+                self.device = torch.accelerator.current_accelerator().type
             else:
                 self.device = "cpu"
 
@@ -539,7 +537,7 @@ def _initialize_model(
         _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
 
     if builder_args.dso_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
             )

@@ -573,7 +571,7 @@ def do_nothing(max_batch_size, max_seq_length):
         raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
 
     elif builder_args.aoti_package_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
             )
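
The new default in BuilderArgs.__post_init__ leans on the torch.accelerator API (present in recent PyTorch releases), which reports whichever accelerator backend the current build detects instead of probing each backend by hand; a standalone sketch of the same selection:

    import torch

    def default_device() -> str:
        # Mirrors BuilderArgs.__post_init__ above: take whatever accelerator
        # PyTorch detects (cuda, xpu, npu, mps, ...), otherwise fall back to CPU.
        if torch.accelerator.is_available():
            return torch.accelerator.current_accelerator().type
        return "cpu"

    print(default_device())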

torchchat/cli/cli.py (+2 -2)

@@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
         "--device",
         type=str,
         default=None,
-        choices=["fast", "cpu", "cuda", "mps", "xpu"],
-        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
+        choices=["fast", "cpu", "cuda", "mps", "xpu", "npu"],
+        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu, npu",
     )
     model_config_parser.add_argument(
         "--attention-backend",

torchchat/generate.py (+9 -4)

@@ -1213,8 +1213,10 @@ def callback(x, *, done_generating=False):
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
                 elif self.builder_args.device == "cuda":
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-                else:
+                elif self.builder_args.device == "xpu":
                     print(prof.key_averages().table(sort_by="self_xpu_time_total"))
+                elif self.builder_args.device == "npu":
+                    print(prof.key_averages().table(sort_by="self_npu_time_total"))
                 prof.export_chrome_trace(f"{self.profile}.json")
 
             if start_pos >= max_seq_length:

@@ -1299,8 +1301,10 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-        if torch.xpu.is_available():
+        elif torch.xpu.is_available():
             print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")

@@ -1595,7 +1599,6 @@ def sample(
 
     return idx_next, probs
 
-
 def run_generator(
     args,
     rank: Optional[int] =None

@@ -1628,8 +1631,10 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-    if torch.xpu.is_available():
+    elif torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass
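
Note the hasattr(torch, "npu") guard: unlike torch.cuda and torch.xpu, the npu module is only attached to torch once the torch_npu extension has been imported, so probing the attribute first keeps non-Ascend builds from raising AttributeError. The same defensive pattern in isolation:

    import torch

    def reset_peak_memory_stats() -> None:
        # Same dispatch as run_generator above; torch.npu does not exist on
        # builds without torch_npu, hence the hasattr probe before touching it.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        elif torch.xpu.is_available():
            torch.xpu.reset_peak_memory_stats()
        elif hasattr(torch, "npu") and torch.npu.is_available():
            torch.npu.reset_peak_memory_stats()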

torchchat/utils/build_utils.py (+15 -14)

@@ -233,6 +233,8 @@ def device_sync(device="cpu"):
         torch.cuda.synchronize(device)
     elif "xpu" in device:
         torch.xpu.synchronize(device)
+    elif "npu" in device:
+        torch.npu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:

@@ -275,33 +277,32 @@ def is_mps_available() -> bool:
     # MPS, is that you?
     return True
 
+def select_device() -> str:
+    if torch.accelerator.is_available():
+        device = torch.accelerator.current_accelerator().type
+        if device == "mps" and not is_mps_available():
+            return "cpu"
+        return device
+    else:
+        return "cpu"
 
 def get_device_str(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device()
         return device
     else:
         return str(device)
 
 
 def get_device(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device()
     return torch.device(device)
 
 
 def is_cpu_device(device) -> bool:
     return device == "" or str(device) == "cpu"
 
-def is_cuda_or_cpu_or_xpu_device(device) -> bool:
-    return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))
+def is_supported_device(device) -> bool:
+    device_str = str(device)
+    return is_cpu_device(device) or any(dev in device_str for dev in ('cuda', 'xpu', 'npu'))
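
A short usage sketch for the refactored helpers (imports per this repo; the "fast" alias now resolves through select_device() rather than the old chained ternary):

    from torchchat.utils.build_utils import get_device, is_supported_device

    device = get_device("fast")  # torch.device for the detected accelerator, else cpu
    print(device, is_supported_device(device.type))  # e.g. "npu True" on an Ascend host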

torchchat/utils/device_info.py (+3 -2)

@@ -9,12 +9,11 @@
 
 import torch
 
-
 def get_device_info(device: str) -> str:
     """Returns a human-readable description of the hardware based on a torch.device.type
 
     Args:
-        device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}.
+        device: A torch.device.type string: one of {"cpu", "cuda", "xpu", "npu"}.
     Returns:
         str: A human-readable description of the hardware or an empty string if the device type is unhandled.

@@ -46,4 +45,6 @@ def get_device_info(device: str) -> str:
             .split("\n")[0]
             .split("Device Name:")[1]
         )
+    if device == "npu":
+        return torch.npu.get_device_name(0)
     return ""

torchchat/utils/quantize.py (+1 -1)

@@ -123,7 +123,7 @@ def quantize_model(
         raise RuntimeError(f"unknown quantizer {quantizer} specified")
     else:
         # Use tensor subclass API for int4 weight only.
-        if (device == "cuda" or device == "xpu") and quantizer == "linear:int4":
+        if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4":
             quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
             if not support_tensor_subclass:
                 unwrap_tensor_subclass(model)
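
The int4 branch goes through torchao's tensor-subclass API, now permitted on NPU as well; a minimal sketch with an illustrative toy module (int4 weight-only generally expects bfloat16 weights on a supported accelerator):

    import torch
    from torchao.quantization import quantize_, int4_weight_only

    # Toy stand-in for the loaded model in quantize_model above.
    model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16)
    quantize_(model, int4_weight_only(group_size=128))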
