55import torch
66import torch .nn as nn
77from mmengine import ConfigDict
8- from torch .utils .data import DataLoader , Dataset
98
109from mmseg .apis import MMSegInferencer
1110from mmseg .models import EncoderDecoder
@@ -46,33 +45,8 @@ def __init__(self, **kwargs):
4645 super ().__init__ (** kwargs )
4746
4847
49- class ExampleDataset (Dataset ):
50-
51- def __init__ (self ) -> None :
52- super ().__init__ ()
53- self .pipeline = [
54- dict (type = 'LoadImageFromFile' ),
55- dict (type = 'LoadAnnotations' ),
56- dict (type = 'PackSegInputs' )
57- ]
58-
59- def __getitem__ (self , idx ):
60- return dict (img = torch .tensor ([1 ]), img_metas = dict ())
61-
62- def __len__ (self ):
63- return 1
64-
65-
6648def test_inferencer ():
6749 register_all_modules ()
68- test_dataset = ExampleDataset ()
69- data_loader = DataLoader (
70- test_dataset ,
71- batch_size = 1 ,
72- sampler = None ,
73- num_workers = 0 ,
74- shuffle = False ,
75- )
7650
7751 visualizer = dict (
7852 type = 'SegLocalVisualizer' ,
@@ -87,7 +61,14 @@ def test_inferencer():
8761 decode_head = dict (type = 'InferExampleHead' ),
8862 test_cfg = dict (mode = 'whole' )),
8963 visualizer = visualizer ,
90- test_dataloader = data_loader )
64+ test_dataloader = dict (
65+ dataset = dict (
66+ type = 'ExampleDataset' ,
67+ pipeline = [
68+ dict (type = 'LoadImageFromFile' ),
69+ dict (type = 'LoadAnnotations' ),
70+ dict (type = 'PackSegInputs' )
71+ ]), ))
9172 cfg = ConfigDict (cfg_dict )
9273 model = MODELS .build (cfg .model )
9374
0 commit comments