
Commit 9130ab3

ptrblck authored and facebook-github-bot committed
fix gemm call for CUDABlas for THCUNN conv, #23545 (#23552)
Summary:

* Swapped `CUBLAS_OP_N` for `'n'`
* Added a test

This PR should fix #23545. Thanks @AlphabetMan for reporting the initial issue in [the forum](https://discuss.pytorch.org/t/cuda-10-1-error-using-transposeconv2d-with-output-padding-1/51414?u=ptrblck), and @ngimel for the guidance.

Pull Request resolved: #23552

Differential Revision: D16580986

Pulled By: ezyang

fbshipit-source-id: abc0bce1e84d9c9d96d44ae0296951725adc8424
1 parent 5d130e4 commit 9130ab3
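
For context, a minimal repro sketch of the failure reported in the forum thread (it mirrors the regression test added below; assumes a CUDA build of PyTorch with a CUDA device, and disables cuDNN so the half-precision cuBLAS gemm path is exercised):

```python
import torch
import torch.nn as nn

# Disable cuDNN so ConvTranspose2d falls back to the path that ends up in
# at::cuda::blas::gemm<at::Half>.
with torch.backends.cudnn.flags(enabled=False):
    deconv = nn.ConvTranspose2d(
        1, 1, 3, stride=2, padding=1, output_padding=1).cuda().half()
    inputs = torch.randn(1, 1, 16, 16, device='cuda', dtype=torch.half)
    output = deconv(inputs)
    output.mean().backward()  # failed with a cuBLAS error before this fix
```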

File tree

2 files changed: +10 -1


aten/src/ATen/cuda/CUDABlas.cpp

+1 -1

@@ -315,7 +315,7 @@ void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) {
       incy == 1,
       "at::cuda::blas::gemv<Half>: support for incy != 1 not implemented");
   gemm<at::Half>(
-      stream, trans, CUBLAS_OP_N, m, 1, n, alpha, a, n, x, n, beta, y, m);
+      stream, trans, 'n', m, 1, n, alpha, a, n, x, n, beta, y, m);
 }

 } // namespace blas
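
Why the one-character change fixes the bug: this `gemm<at::Half>` overload takes its transpose flags as `char` values (`'n'`/`'t'`), which it converts to `cublasOperation_t` before calling cuBLAS. `CUBLAS_OP_N` is the enum constant 0, not the character `'n'`, so passing it through the `char` parameter selected no valid operation. A minimal sketch of that kind of conversion, with an illustrative helper name (the actual converter in CUDABlas.cpp may differ):

```cpp
#include <cublas_v2.h>
#include <stdexcept>

// Illustrative (hypothetical) helper: map a char transpose flag to the
// cuBLAS enum, as char-based gemm wrappers typically do.
static cublasOperation_t op_from_char(char trans) {
  switch (trans) {
    case 'n': case 'N': return CUBLAS_OP_N;
    case 't': case 'T': return CUBLAS_OP_T;
    case 'c': case 'C': return CUBLAS_OP_C;
    default:
      throw std::runtime_error("invalid transpose flag");
  }
}

// CUBLAS_OP_N == 0, so passing it where a char is expected arrives as '\0'
// and hits the error path above instead of selecting the no-transpose op.
```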

test/test_nn.py

+9

@@ -5289,6 +5289,15 @@ def test_ConvTranspose3d_correct_output_size(self):
         i = torch.rand(1, 2, 1, 1, 1)
         out = m(i, output_size=(1, 2, 2, 2, 2))

+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_ConvTranspose2d_half_cublas_gemm(self):
+        with torch.backends.cudnn.flags(enabled=False):
+            inputs = torch.randn(1, 1, 16, 16, device='cuda', dtype=torch.half)
+            deconv = nn.ConvTranspose2d(
+                1, 1, 3, stride=2, padding=1, output_padding=1).cuda().half()
+            output = deconv(inputs)
+            output.mean().backward()
+
     def _test_Conv2d_naive_groups(self, device="cpu", dtype=torch.float):
         # Check that grouped convolutions matches two half convolutions
         m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
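
To run just the new test, the usual unittest selector should work (assuming the standard `run_tests` entry point in `test_nn.py`): `python test/test_nn.py TestNN.test_ConvTranspose2d_half_cublas_gemm`. The `skipIf` guard makes it a no-op on machines without CUDA.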
