Skip to content

Commit c233a91

Browse files
committed
Fix custom ops bug for pytorch 1.12 and onwards
Adapt to newer _jit_get_operation API that changed in pytorch/pytorch#76814 for #188, #193
1 parent 407db86 commit c233a91

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torch_utils/ops/grid_sample_gradfix.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
enabled = False # Enable the custom op by setting this to true.
2424
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25+
_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12
2526

2627
#----------------------------------------------------------------------------
2728

@@ -58,6 +59,8 @@ class _GridSample2dBackward(torch.autograd.Function):
5859
@staticmethod
5960
def forward(ctx, grad_output, input, grid):
6061
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
62+
if _use_pytorch_1_12_api:
63+
op = op[0]
6164
if _use_pytorch_1_11_api:
6265
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
6366
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)

0 commit comments

Comments
 (0)