import fnmatch

import timm
-from timm import list_models, create_model, set_scriptable
+from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
+    get_model_default_value

if hasattr(torch._C, '_jit_set_profiling_executor'):
    # legacy executor is too slow to compile large models for unit tests
@@ -60,9 +61,15 @@ def test_model_backward(model_name, batch_size):
    model.eval()

    input_size = model.default_cfg['input_size']
-    if any([x > MAX_BWD_SIZE for x in input_size]):
-        # cap backward test at 128 * 128 to keep resource usage down
-        input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
+    if not is_model_default_key(model_name, 'fixed_input_size'):
+        min_input_size = get_model_default_value(model_name, 'min_input_size')
+        if min_input_size is not None:
+            input_size = min_input_size
+        else:
+            if any([x > MAX_BWD_SIZE for x in input_size]):
+                # cap backward test at 128 * 128 to keep resource usage down
+                input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
+
    inputs = torch.randn((batch_size, *input_size))
    outputs = model(inputs)
    outputs.mean().backward()
@@ -155,7 +162,14 @@ def test_model_forward_torchscript(model_name, batch_size):
    with set_scriptable(True):
        model = create_model(model_name, pretrained=False)
    model.eval()
-    input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+
+    if has_model_default_key(model_name, 'fixed_input_size'):
+        input_size = get_model_default_value(model_name, 'input_size')
+    elif has_model_default_key(model_name, 'min_input_size'):
+        input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
+
    model = torch.jit.script(model)
    outputs = model(torch.randn((batch_size, *input_size)))

@@ -180,7 +194,14 @@ def test_model_forward_features(model_name, batch_size):
    model.eval()
    expected_channels = model.feature_info.channels()
    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
-    input_size = (3, 96, 96)  # jit compile is already a bit slow and we've tested normal res already...
+
+    if has_model_default_key(model_name, 'fixed_input_size'):
+        input_size = get_model_default_value(model_name, 'input_size')
+    elif has_model_default_key(model_name, 'min_input_size'):
+        input_size = get_model_default_value(model_name, 'min_input_size')
+    else:
+        input_size = (3, 96, 96)  # jit compile is already a bit slow and we've tested normal res already...
+
    outputs = model(torch.randn((batch_size, *input_size)))
    assert len(expected_channels) == len(outputs)
    for e, o in zip(expected_channels, outputs):
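
The changes above pick a per-model test resolution from the registry's default_cfg keys instead of a hard-coded size. As a rough illustration of that selection logic (a minimal sketch, assuming a timm version that exports `has_model_default_key` and `get_model_default_value`; the helper function and example model below are illustrative, not part of the diff):

```python
from timm import has_model_default_key, get_model_default_value

def pick_test_input_size(model_name, fallback=(3, 128, 128)):
    # Models flagged with fixed_input_size must be tested at their native resolution.
    if has_model_default_key(model_name, 'fixed_input_size'):
        return get_model_default_value(model_name, 'input_size')
    # Otherwise use the smallest resolution the model declares it supports, if any.
    if has_model_default_key(model_name, 'min_input_size'):
        return get_model_default_value(model_name, 'min_input_size')
    # No size constraints in default_cfg: keep the test small and fast.
    return fallback

print(pick_test_input_size('resnet18'))  # no size constraints, so e.g. (3, 128, 128)
```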