Skip to content

Commit ced29fc

Browse files
authored
[Refactor] Handle case where device is neither CPU nor CUDA in HamHead (#2868)
1 parent 969f504 commit ced29fc

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

mmseg/models/decode_heads/ham_head.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch.nn as nn
66
import torch.nn.functional as F
77
from mmcv.cnn import ConvModule
8+
from mmengine.device import get_device
89

910
from mmseg.registry import MODELS
1011
from ..utils import resize
@@ -52,7 +53,7 @@ def __init__(self,
5253

5354
self.rand_init = rand_init
5455

55-
def _build_bases(self, B, S, D, R, cuda=False):
56+
def _build_bases(self, B, S, D, R, device=None):
5657
raise NotImplementedError
5758

5859
def local_step(self, x, bases, coef):
@@ -80,14 +81,13 @@ def forward(self, x, return_bases=False):
8081
D = C // self.S
8182
N = H * W
8283
x = x.view(B * self.S, D, N)
83-
cuda = 'cuda' in str(x.device)
8484
if not self.rand_init and not hasattr(self, 'bases'):
85-
bases = self._build_bases(1, self.S, D, self.R, cuda=cuda)
85+
bases = self._build_bases(1, self.S, D, self.R, device=x.device)
8686
self.register_buffer('bases', bases)
8787

8888
# (S, D, R) -> (B * S, D, R)
8989
if self.rand_init:
90-
bases = self._build_bases(B, self.S, D, self.R, cuda=cuda)
90+
bases = self._build_bases(B, self.S, D, self.R, device=x.device)
9191
else:
9292
bases = self.bases.repeat(B, 1, 1)
9393

@@ -116,13 +116,11 @@ def __init__(self, args=dict()):
116116

117117
self.inv_t = 1
118118

119-
def _build_bases(self, B, S, D, R, cuda=False):
119+
def _build_bases(self, B, S, D, R, device=None):
120120
"""Build bases in initialization."""
121-
if cuda:
122-
bases = torch.rand((B * S, D, R)).cuda()
123-
else:
124-
bases = torch.rand((B * S, D, R))
125-
121+
if device is None:
122+
device = get_device()
123+
bases = torch.rand((B * S, D, R)).to(device)
126124
bases = F.normalize(bases, dim=1)
127125

128126
return bases

0 commit comments

Comments
 (0)