|
5 | 5 | import torch.nn as nn |
6 | 6 | import torch.nn.functional as F |
7 | 7 | from mmcv.cnn import ConvModule |
| 8 | +from mmengine.device import get_device |
8 | 9 |
|
9 | 10 | from mmseg.registry import MODELS |
10 | 11 | from ..utils import resize |
@@ -52,7 +53,7 @@ def __init__(self, |
52 | 53 |
|
53 | 54 | self.rand_init = rand_init |
54 | 55 |
|
55 | | - def _build_bases(self, B, S, D, R, cuda=False): |
| 56 | + def _build_bases(self, B, S, D, R, device=None): |
56 | 57 | raise NotImplementedError |
57 | 58 |
|
58 | 59 | def local_step(self, x, bases, coef): |
@@ -80,14 +81,13 @@ def forward(self, x, return_bases=False): |
80 | 81 | D = C // self.S |
81 | 82 | N = H * W |
82 | 83 | x = x.view(B * self.S, D, N) |
83 | | - cuda = 'cuda' in str(x.device) |
84 | 84 | if not self.rand_init and not hasattr(self, 'bases'): |
85 | | - bases = self._build_bases(1, self.S, D, self.R, cuda=cuda) |
| 85 | + bases = self._build_bases(1, self.S, D, self.R, device=x.device) |
86 | 86 | self.register_buffer('bases', bases) |
87 | 87 |
|
88 | 88 | # (S, D, R) -> (B * S, D, R) |
89 | 89 | if self.rand_init: |
90 | | - bases = self._build_bases(B, self.S, D, self.R, cuda=cuda) |
| 90 | + bases = self._build_bases(B, self.S, D, self.R, device=x.device) |
91 | 91 | else: |
92 | 92 | bases = self.bases.repeat(B, 1, 1) |
93 | 93 |
|
@@ -116,13 +116,11 @@ def __init__(self, args=dict()): |
116 | 116 |
|
117 | 117 | self.inv_t = 1 |
118 | 118 |
|
119 | | - def _build_bases(self, B, S, D, R, cuda=False): |
| 119 | + def _build_bases(self, B, S, D, R, device=None): |
120 | 120 | """Build bases in initialization.""" |
121 | | - if cuda: |
122 | | - bases = torch.rand((B * S, D, R)).cuda() |
123 | | - else: |
124 | | - bases = torch.rand((B * S, D, R)) |
125 | | - |
| 121 | + if device is None: |
| 122 | + device = get_device() |
| 123 | + bases = torch.rand((B * S, D, R)).to(device) |
126 | 124 | bases = F.normalize(bases, dim=1) |
127 | 125 |
|
128 | 126 | return bases |
|
0 commit comments