diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 8037858151..acaeeb5ab2 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -213,22 +213,28 @@ def slice_scatter_decomposition(
         return src_tensor
 
     # Ensure start, end, and step are all integers
-    assert isinstance(start, int), "start must be an integer"
-    assert isinstance(end, int), "end must be an integer"
-    assert isinstance(step, int), "step must be an integer"
-
-    cat_tensors = []
-    index_tensor_shape = []
-    for i, src_each_dim in enumerate(list(src_dim)):
-        if i != dim:
-            index_tensor_shape.append(src_each_dim)
-    for index in range(start, end, step):
-        cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64))
-    index_tensor = torch.stack(cat_tensors, dim)
-    index_tensor = index_tensor.to(device_input_tensor)
-    index_tensor_64 = index_tensor.to(torch.int64)
-    output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor)
-    return output_tensor
+    # Under dynamic shapes they may be SymInts rather than plain Python ints
+    assert isinstance(start, (int, torch.SymInt)), "start must be an int or SymInt"
+    assert isinstance(end, (int, torch.SymInt)), "end must be an int or SymInt"
+    assert isinstance(step, (int, torch.SymInt)), "step must be an int or SymInt"
+
+    # step == 0 is rejected by torch itself, so it is not handled here;
+    # src_tensor's size along `dim` must equal the number of sliced indices
+
+    # The slice spans the whole dimension, so the result is just src_tensor
+    if start == 0 and end == dim_size and step == 1:
+        return src_tensor
+
+    # Build the scatter index: arange over the sliced positions, reshaped to
+    # broadcast along every dimension except `dim`, then expanded to src's shape
+    indices = torch.arange(
+        start, end, step, device=device_input_tensor, dtype=torch.int64
+    )
+    index_tensor = indices.view(
+        [-1 if i == dim else 1 for i in range(input_tensor.dim())]
+    )
+    index_tensor = index_tensor.expand_as(src_tensor)
+    return torch.scatter(input_tensor.clone(), dim, index_tensor, src_tensor)
 
 
 @register_torch_trt_decomposition(
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index b63e0f3bf7..9f0f53a4d8 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -812,6 +812,79 @@ def forward(self, x, src, dim, start, end, step):
             f"Slice_scatter TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_slice_scatter_dynamic_module(self):
+        class sliceScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, start=None, end=None, step=1):
+                y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
+                return y
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.scatter.src,
+        }
+        unexpected_ops = {torch.ops.aten.slice_scatter}
+
+        a = torch.zeros(8, 8).cuda()
+        b = torch.ones(8, 5).cuda()  # matches the sliced region a[:, 1:6]
+
+        # 0-D tensors for dynamic scalar values
+        start = torch.tensor(1, dtype=torch.int64).cuda()
+        end = torch.tensor(6, dtype=torch.int64).cuda()
+        step = torch.tensor(1, dtype=torch.int64).cuda()
+
+        # Mark scalar tensors as dynamic (note: shape = ())
+        torch._dynamo.mark_dynamic(start, (), min=1, max=3)
+        torch._dynamo.mark_dynamic(end, (), min=4, max=6)
+
+        inputs = (a, b, 1, start, end, step)
+        fx_graph = torch.fx.symbolic_trace(sliceScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Slice_scatter TRT outputs don't match with the original model.",
+        )
+
     def test_lowering_select_scatter_dimZero_module(self):
         class selectScatter(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None: