Skip to content

Commit 9947a39

Browse files
authored
fix dnl_head export onnx inference difference type Cast error (open-mmlab#1161)
* fix export onnx inference difference type Cast error * fix export onnx inference difference type Cast error. * use yapf format * use same device type with pairwise_weight
1 parent 69d5cc5 commit 9947a39

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

mmseg/models/decode_heads/dnl_head.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ def embedded_gaussian(self, theta_x, phi_x):
2626
pairwise_weight = torch.matmul(theta_x, phi_x)
2727
if self.use_scale:
2828
# theta_x.shape[-1] is `self.inter_channels`
29-
pairwise_weight /= theta_x.shape[-1]**0.5
30-
pairwise_weight /= self.temperature
29+
pairwise_weight /= torch.tensor(
30+
theta_x.shape[-1],
31+
dtype=torch.float,
32+
device=pairwise_weight.device)**torch.tensor(
33+
0.5, device=pairwise_weight.device)
34+
pairwise_weight /= torch.tensor(
35+
self.temperature, device=pairwise_weight.device)
3136
pairwise_weight = pairwise_weight.softmax(dim=-1)
3237
return pairwise_weight
3338

0 commit comments

Comments
 (0)