@@ -338,7 +338,7 @@ def get_classifier(self):
338
338
339
339
def reset_classifier (self , num_classes : int ):
340
340
self .num_classes = num_classes
341
- self .classifier = nn .Linear (round (self .cfg [self .networks [self .network_idx ]][- 1 ][1 ] * self .scale ), num_classes )
341
+ self .classifier = nn .Linear (round (self .cfg [self .networks [self .network_idx ]][- 1 ][0 ] * self .scale ), num_classes )
342
342
343
343
def forward_features (self , x : torch .Tensor ) -> torch .Tensor :
344
344
return self .features (x )
@@ -367,15 +367,18 @@ def _gen_simplenet(
367
367
) -> SimpleNet :
368
368
369
369
model_args = dict (
370
- num_classes = num_classes ,
371
- in_chans = in_chans ,
372
- scale = scale ,
373
- network_idx = network_idx ,
374
- mode = mode ,
375
- drop_rates = drop_rates ,
376
- ** kwargs ,
370
+ in_chans = in_chans , scale = scale , network_idx = network_idx , mode = mode , drop_rates = drop_rates , ** kwargs ,
377
371
)
372
+ # to allow for seamless finetuning, remove the num_classes
373
+ # and load the model intact, we apply the changes afterward!
374
+ if "num_classes" in kwargs :
375
+ kwargs .pop ("num_classes" )
378
376
model = build_model_with_cfg (SimpleNet , model_variant , pretrained , ** model_args )
377
+ # if the num_classes is different from ImageNet's, it
378
+ # means it's going to be finetuned, so only create a
379
+ # new classifier after the whole model is loaded!
380
+ if num_classes != 1000 :
381
+ model .reset_classifier (num_classes )
379
382
return model
380
383
381
384
@@ -436,7 +439,7 @@ def remove_network_settings(kwargs: Dict[str, Any]) -> Dict[str, Any]:
436
439
Returns:
437
440
Dict[str,Any]: cleaned kwargs
438
441
"""
439
- model_args = {k : v for k , v in kwargs .items () if k not in ["scale" , "network_idx" , "mode" ,"drop_rate" ]}
442
+ model_args = {k : v for k , v in kwargs .items () if k not in ["scale" , "network_idx" , "mode" , "drop_rate" ]}
440
443
return model_args
441
444
442
445
0 commit comments