19 | 19 | from typing import List
20 | 20 |
21 | 21 | import numpy as np
| 22 | +from parameterized import parameterized
22 | 23 | from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
23 | 24 |
24 | 25 | from diffusers import (

29 | 30 |     TorchAoConfig,
30 | 31 | )
31 | 32 | from diffusers.models.attention_processor import Attention
| 33 | +from diffusers.quantizers import PipelineQuantizationConfig
32 | 34 | from diffusers.utils.testing_utils import (
33 | 35 |     backend_empty_cache,
34 | 36 |     backend_synchronize,

44 | 46 |     torch_device,
45 | 47 | )
46 | 48 |
| 49 | +from ..test_torch_compile_utils import QuantCompileTests
| 50 | +
47 | 51 |
48 | 52 | enable_full_determinism()
49 | 53 |

@@ -625,6 +629,53 @@ def test_int_a16w8_cpu(self):
625 | 629 |         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
626 | 630 |
627 | 631 |
| 632 | +@require_torchao_version_greater_or_equal("0.7.0")
| 633 | +class TorchAoCompileTest(QuantCompileTests):
| 634 | +    @property
| 635 | +    def quantization_config(self):
| 636 | +        return PipelineQuantizationConfig(
| 637 | +            quant_mapping={
| 638 | +                "transformer": TorchAoConfig(quant_type="int8_weight_only"),
| 639 | +            },
| 640 | +        )
| 641 | +
| 642 | +    def test_torch_compile(self):
| 643 | +        super()._test_torch_compile(quantization_config=self.quantization_config)
| 644 | +
| 645 | +    @unittest.skip(
| 646 | +        "Changing the device of an AQT tensor with module._apply (called from module.to() in accelerate) does not work "
| 647 | +        "when compiling."
| 648 | +    )
| 649 | +    def test_torch_compile_with_cpu_offload(self):
| 650 | +        # RuntimeError: _apply(): Couldn't swap Linear.weight
| 651 | +        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
| 652 | +
| 653 | +    @parameterized.expand([False, True])
| 654 | +    @unittest.skip(
| 655 | +        """
| 656 | +        For `use_stream=False`:
| 657 | +            Changing the device of an AQT tensor with `param.data = param.data.to(device)`, as done in the group
| 658 | +            offloading implementation, is unsupported in TorchAO. When compiling, a FakeTensor device mismatch causes failure.
| 659 | +        For `use_stream=True`:
| 660 | +            Using a non-default stream requires the ability to pin tensors. AQT does not seem to support this yet in TorchAO.
| 661 | +        """
| 662 | +    )
| 663 | +    def test_torch_compile_with_group_offload_leaf(self, use_stream):
| 664 | +        # For use_stream=False:
| 665 | +        # If we run group offloading without compilation, we will see:
| 666 | +        #   RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
| 667 | +        # When running with compilation, the error ends up being different:
| 668 | +        #   Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
| 669 | +        #   requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
| 670 | +        # This looks like something that will have to be addressed upstream:
| 671 | +        # for linear layers, weight.tensor_impl reports a cuda device, but
| 672 | +        # weight.tensor_impl.{data,scale,zero_point}.device will be cpu.
| 673 | +
| 674 | +        # For use_stream=True:
| 675 | +        #   NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
| 676 | +        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config, use_stream=use_stream)
| 677 | +
| 678 | +
628 | 679 | # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
629 | 680 | @require_torch
630 | 681 | @require_torch_accelerator
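For context, this is roughly how the pipeline-level config defined by TorchAoCompileTest.quantization_config could be used outside the test harness; QuantCompileTests drives the equivalent load/quantize/compile steps internally. The checkpoint, device placement, and generation settings below are illustrative only and are not part of this change:

import torch

from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig

# Quantize only the transformer with int8 weight-only TorchAO quantization,
# mirroring the config returned by TorchAoCompileTest.quantization_config.
quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(quant_type="int8_weight_only")},
)

# Any pipeline exposing a `transformer` component works; this checkpoint is just an example.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Compile the quantized transformer. The offload variants exercised (and currently skipped)
# in the tests above would use CPU or group offloading instead of .to("cuda") before compiling.
pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)

image = pipe("a photo of a cat", num_inference_steps=28).images[0]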