File tree Expand file tree Collapse file tree 3 files changed +6
-8
lines changed 
mmseg/models/decode_heads Expand file tree Collapse file tree 3 files changed +6
-8
lines changed Original file line number Diff line number Diff line change 88from  mmengine .model  import  BaseModule 
99from  torch  import  Tensor 
1010
11+ from  mmseg .registry  import  MODELS 
1112from  mmseg .structures  import  build_pixel_sampler 
1213from  mmseg .utils  import  ConfigType , SampleList 
13- from  ..builder  import  build_loss 
1414from  ..losses  import  accuracy 
1515from  ..utils  import  resize 
1616
@@ -140,11 +140,11 @@ def __init__(self,
140140        self .threshold  =  threshold 
141141
142142        if  isinstance (loss_decode , dict ):
143-             self .loss_decode  =  build_loss (loss_decode )
143+             self .loss_decode  =  MODELS . build (loss_decode )
144144        elif  isinstance (loss_decode , (list , tuple )):
145145            self .loss_decode  =  nn .ModuleList ()
146146            for  loss  in  loss_decode :
147-                 self .loss_decode .append (build_loss (loss ))
147+                 self .loss_decode .append (MODELS . build (loss ))
148148        else :
149149            raise  TypeError (f'loss_decode must be a dict or sequence of dict,\  
150150                 but got { type (loss_decode )}  ' )
Original file line number Diff line number Diff line change 99
1010from  mmseg .registry  import  MODELS 
1111from  mmseg .utils  import  ConfigType , SampleList 
12- from  ..builder  import  build_loss 
1312from  ..utils  import  Encoding , resize 
1413from  .decode_head  import  BaseDecodeHead 
1514
@@ -128,7 +127,7 @@ def __init__(self,
128127            norm_cfg = self .norm_cfg ,
129128            act_cfg = self .act_cfg )
130129        if  self .use_se_loss :
131-             self .loss_se_decode  =  build_loss (loss_se_decode )
130+             self .loss_se_decode  =  MODELS . build (loss_se_decode )
132131            self .se_layer  =  nn .Linear (self .channels , self .num_classes )
133132
134133    def  forward (self , inputs ):
Original file line number Diff line number Diff line change 1010
1111from  mmseg .registry  import  MODELS 
1212from  mmseg .utils  import  SampleList 
13- from  ..builder  import  build_loss 
1413from  ..utils  import  resize 
1514from  .decode_head  import  BaseDecodeHead 
1615
@@ -184,11 +183,11 @@ def __init__(
184183
185184        # build loss 
186185        if  isinstance (loss_decode , dict ):
187-             self .loss_decode  =  build_loss (loss_decode )
186+             self .loss_decode  =  MODELS . build (loss_decode )
188187        elif  isinstance (loss_decode , (list , tuple )):
189188            self .loss_decode  =  nn .ModuleList ()
190189            for  loss  in  loss_decode :
191-                 self .loss_decode .append (build_loss (loss ))
190+                 self .loss_decode .append (MODELS . build (loss ))
192191        else :
193192            raise  TypeError (f'loss_decode must be a dict or sequence of dict,\  
194193                 but got { type (loss_decode )}  ' )
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments