@@ -400,6 +400,7 @@ def __init__(
400400 self .feature_info = []
401401
402402 self .use_cls_token = use_cls_token
403+ self .global_pool = 'token' if use_cls_token else 'avg'
403404
404405 dpr = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
405406
@@ -448,6 +449,21 @@ def __init__(
448449 self .head = nn .Linear (dims [- 1 ], num_classes ) if num_classes > 0 else nn .Identity ()
449450
450451
452+
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
    """Return the classification head (``nn.Linear`` or ``nn.Identity``)."""
    return self.head
456+
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
    """Re-create the classification head for a new number of classes.

    Args:
        num_classes: New number of output classes; ``0`` replaces the head
            with ``nn.Identity``.
        global_pool: Optional new pooling mode (one of ``''``, ``'avg'``,
            ``'token'``); ``None`` keeps the current setting.

    Raises:
        ValueError: If ``global_pool`` is not a recognized mode, or if
            ``'token'`` is requested on a model built without a class token.
    """
    self.num_classes = num_classes
    if global_pool is not None:
        # Validate with explicit raises: `assert` is stripped under `python -O`,
        # so invalid pool modes would silently slip through.
        if global_pool not in ('', 'avg', 'token'):
            raise ValueError(f'Invalid global pool type: {global_pool}')
        if global_pool == 'token' and not self.use_cls_token:
            raise ValueError('Model not configured to use class token')
        self.global_pool = global_pool
    self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
465+
466+
451467 def _forward_features (self , x : torch .Tensor ) -> torch .Tensor :
452468 # nn.Sequential forward can't accept tuple intermediates
453469 # TODO grad checkpointing
@@ -457,12 +473,13 @@ def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
457473 return x
458474
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
    """Run the backbone stages and return the spatial feature map.

    When the model uses a class token, ``_forward_features`` yields a
    (feature_map, class_token) pair; only the feature-map half is returned
    here. Not always used by callers.
    """
    feats = self._forward_features(x)
    if self.use_cls_token:
        return feats[0]
    return feats
463480
def forward_head(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the final norm, pooling, and classifier head.

    With ``'token'`` pooling, ``x`` is expected to be a
    (features, class_token) pair and the flattened token is classified;
    otherwise the feature map is average-pooled over its spatial dims.
    """
    if self.global_pool != 'token':
        pooled = x.mean(dim=(2, 3))
        return self.head(self.norm(pooled))
    cls_tok = x[1].flatten(1)
    return self.head(self.norm(cls_tok))
0 commit comments