
Commit cdaf84a

TorchAO compile + offloading tests (#11697)
* update
* update
* update
* update
* update
* user property instead
1 parent e8e44a5 commit cdaf84a

File tree: 4 files changed, +85 −23 lines changed

tests/quantization/bnb/test_4bit.py
tests/quantization/bnb/test_mixed_int8.py
tests/quantization/test_torch_compile_utils.py
tests/quantization/torchao/test_torchao.py

tests/quantization/bnb/test_4bit.py

Lines changed: 15 additions & 11 deletions
@@ -866,15 +866,17 @@ def test_fp4_double_safe(self):
 
 @require_torch_version_greater("2.7.1")
 class Bnb4BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={
-            "load_in_4bit": True,
-            "bnb_4bit_quant_type": "nf4",
-            "bnb_4bit_compute_dtype": torch.bfloat16,
-        },
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={
+                "load_in_4bit": True,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_compute_dtype": torch.bfloat16,
+            },
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -883,5 +885,7 @@ def test_torch_compile(self):
     def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
 
-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, use_stream=True
+        )

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 10 additions & 8 deletions
@@ -831,11 +831,13 @@ def test_serialization_sharded(self):
 
 @require_torch_version_greater_equal("2.6.0")
 class Bnb8BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={"load_in_8bit": True},
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={"load_in_8bit": True},
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -847,7 +849,7 @@ def test_torch_compile_with_cpu_offload(self):
         )
 
     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
         )

tests/quantization/test_torch_compile_utils.py

Lines changed: 9 additions & 4 deletions
@@ -24,7 +24,11 @@
 @require_torch_gpu
 @slow
 class QuantCompileTests(unittest.TestCase):
-    quantization_config = None
+    @property
+    def quantization_config(self):
+        raise NotImplementedError(
+            "This property should be implemented in the subclass to return the appropriate quantization config."
+        )
 
     def setUp(self):
         super().setUp()
@@ -64,16 +68,17 @@ def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
+    def _test_torch_compile_with_group_offload_leaf(
+        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
+    ):
         torch._dynamo.config.cache_size_limit = 10000
 
         pipe = self._init_pipeline(quantization_config, torch_dtype)
         group_offload_kwargs = {
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
             "offload_type": "leaf_level",
-            "use_stream": True,
-            "non_blocking": True,
+            "use_stream": use_stream,
         }
         pipe.transformer.enable_group_offload(**group_offload_kwargs)
         pipe.transformer.compile()
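
For context, here is a minimal sketch (not part of this commit) of how a backend-specific test now plugs into this base class: `quantization_config` is overridden as a property so each access returns a fresh config, and the shared helpers are called with it. The class name is hypothetical; the backend and kwargs mirror the bitsandbytes 8-bit tests above.

# Hypothetical subclass following the pattern introduced in this commit.
import torch

from diffusers.quantizers import PipelineQuantizationConfig

from ..test_torch_compile_utils import QuantCompileTests


class MyBackendCompileTests(QuantCompileTests):
    @property
    def quantization_config(self):
        # A fresh config object per access, instead of a shared class attribute.
        return PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
            components_to_quantize=["transformer", "text_encoder_2"],
        )

    def test_torch_compile(self):
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload_leaf(self):
        # use_stream=True exercises the stream-based leaf-level offloading path.
        super()._test_torch_compile_with_group_offload_leaf(
            quantization_config=self.quantization_config, use_stream=True
        )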

tests/quantization/torchao/test_torchao.py

Lines changed: 51 additions & 0 deletions
@@ -19,6 +19,7 @@
 from typing import List
 
 import numpy as np
+from parameterized import parameterized
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
 from diffusers import (
@@ -29,6 +30,7 @@
     TorchAoConfig,
 )
 from diffusers.models.attention_processor import Attention
+from diffusers.quantizers import PipelineQuantizationConfig
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     backend_synchronize,
@@ -44,6 +46,8 @@
     torch_device,
 )
 
+from ..test_torch_compile_utils import QuantCompileTests
+
 
 enable_full_determinism()
 
@@ -625,6 +629,53 @@ def test_int_a16w8_cpu(self):
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
 
+@require_torchao_version_greater_or_equal("0.7.0")
+class TorchAoCompileTest(QuantCompileTests):
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": TorchAoConfig(quant_type="int8_weight_only"),
+            },
+        )
+
+    def test_torch_compile(self):
+        super()._test_torch_compile(quantization_config=self.quantization_config)
+
+    @unittest.skip(
+        "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
+        "when compiling."
+    )
+    def test_torch_compile_with_cpu_offload(self):
+        # RuntimeError: _apply(): Couldn't swap Linear.weight
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    @unittest.skip(
+        """
+        For `use_stream=False`:
+        - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in the group offloading
+          implementation, is unsupported in TorchAO. When compiling, a FakeTensor device mismatch causes failure.
+        For `use_stream=True`:
+        - Using a non-default stream requires the ability to pin tensors. AQT does not seem to support this yet in TorchAO.
+        """
+    )
+    @parameterized.expand([False, True])
+    def test_torch_compile_with_group_offload_leaf(self):
+        # For use_stream=False:
+        # If we run group offloading without compilation, we will see:
+        #   RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
+        # When running with compilation, the error ends up being different:
+        #   Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)..., scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)..., zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)..., _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
+        #   requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
+        # Looks like something that will have to be looked into upstream.
+        # For linear layers, weight.tensor_impl shows cuda, but
+        # weight.tensor_impl.{data,scale,zero_point}.device will be cpu.
+
+        # For use_stream=True:
+        #   NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
+        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+
+
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator
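
As a usage sketch (not part of the diff), this is roughly the end-to-end scenario the new, non-skipped `test_torch_compile` covers: quantize only the transformer with TorchAO int8 weight-only and compile it, without CPU or group offloading, which the skipped tests document as not yet working with AQT tensors. The checkpoint id is a placeholder, not necessarily what the test harness loads.

# Hedged sketch of the TorchAO compile scenario exercised by TorchAoCompileTest.
import torch

from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(quant_type="int8_weight_only")},
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.compile()
# Small resolution and few steps keep this a quick smoke test, mirroring the base helper.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)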
