
Commit e218c9f

fix seamless finetuning for timm simplenet
1 parent 26ec25d commit e218c9f

File tree

  • ImageNet/training_scripts/imagenet_training/timm/models/simplenet.py

1 file changed, +12 -9 lines changed
ImageNet/training_scripts/imagenet_training/timm/models/simplenet.py

+12 -9
@@ -338,7 +338,7 @@ def get_classifier(self):
 
     def reset_classifier(self, num_classes: int):
         self.num_classes = num_classes
-        self.classifier = nn.Linear(round(self.cfg[self.networks[self.network_idx]][-1][1] * self.scale), num_classes)
+        self.classifier = nn.Linear(round(self.cfg[self.networks[self.network_idx]][-1][0] * self.scale), num_classes)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         return self.features(x)
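
For context on the one-character index fix above: reset_classifier derives the new head's input width from the last layer entry of the currently selected network config. A minimal sketch of that lookup, using hypothetical config values (the real entries in simplenet.py may carry additional fields); the point the fix relies on is that the channel count sits at index 0 of each entry, not index 1:

    # hypothetical cfg: each entry is assumed to be (out_channels, ...) for one layer
    cfg = {"netA": [(64, 1.0), (128, 1.0), (256, 1.0)]}
    networks = ["netA"]
    network_idx = 0
    scale = 1.0

    # [-1][0] picks the last layer's channel count, which becomes the
    # in_features of the new nn.Linear classifier head
    in_features = round(cfg[networks[network_idx]][-1][0] * scale)  # 256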
@@ -367,15 +367,18 @@ def _gen_simplenet(
 ) -> SimpleNet:
 
     model_args = dict(
-        num_classes=num_classes,
-        in_chans=in_chans,
-        scale=scale,
-        network_idx=network_idx,
-        mode=mode,
-        drop_rates=drop_rates,
-        **kwargs,
+        in_chans=in_chans, scale=scale, network_idx=network_idx, mode=mode, drop_rates=drop_rates, **kwargs,
     )
+    # to allow for seamless finetuning, remove num_classes here and
+    # load the model intact; the change is applied afterward.
+    if "num_classes" in kwargs:
+        kwargs.pop("num_classes")
     model = build_model_with_cfg(SimpleNet, model_variant, pretrained, **model_args)
+    # if num_classes differs from ImageNet's 1000, the model is going
+    # to be finetuned, so only create a new classifier after the whole
+    # pretrained model has been loaded.
+    if num_classes != 1000:
+        model.reset_classifier(num_classes)
     return model
 
 
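
What this buys in practice, as a hedged sketch (the variant name and head size below are illustrative assumptions, not taken from this diff): requesting a pretrained SimpleNet with a non-ImageNet num_classes should now load the full 1000-class checkpoint first and only afterwards swap in a fresh classifier via reset_classifier, rather than trying to build the model with a mismatched head from the start.

    import timm
    import torch

    # assumed variant name for illustration; use whichever SimpleNet
    # variant this timm integration actually registers, and note that
    # pretrained=True requires weights to be available for that variant
    model = timm.create_model("simplenetv1_5m_m2", pretrained=True, num_classes=10)

    # backbone weights are loaded intact; only the final Linear head is new
    x = torch.randn(1, 3, 224, 224)
    print(model(x).shape)  # expected: torch.Size([1, 10])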

@@ -436,7 +439,7 @@ def remove_network_settings(kwargs: Dict[str, Any]) -> Dict[str, Any]:
     Returns:
         Dict[str,Any]: cleaned kwargs
     """
-    model_args = {k: v for k, v in kwargs.items() if k not in ["scale", "network_idx", "mode","drop_rate"]}
+    model_args = {k: v for k, v in kwargs.items() if k not in ["scale", "network_idx", "mode", "drop_rate"]}
     return model_args
 
 