Add T2T_ViT #2426
Conversation
@brianhou0208 thanks for the work, and it looks like a good job getting it in shape. I took a closer look using your code, but I have some doubts about this model.

For speed comparisons I disabled F.sdpa in the existing vit to be fair. Simpler vits with higher accuracy (imagenet-1k pretrain, also to be fair) are often 30-40% faster. So I'm not convinced this is worth the add. Was there a particular reason you had interest in the model?
```python
    def single_attn(self, x: torch.Tensor) -> torch.Tensor:
        k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
        if not torch.jit.is_scripting():
            with torch.autocast(device_type=v.device.type, enabled=False):
                y = self._attn_impl(k, q, v)
        else:
            y = self._attn_impl(k, q, v)
        # skip connection
        y = v + self.dp(self.proj(y))  # same as token_transformer in T2T layer, use v as skip connection
        return y

    def _attn_impl(self, k, q, v):
        kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
        D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
        kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
        y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb)/Diag
        return y
```
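For context, `prm_exp` isn't shown in the snippet above; it is the Performer-style positive random-feature map. A minimal sketch of what it computes, based on the reference T2T-ViT code but written here as a free function, with `w` the frozen Gaussian projection matrix and `m` the number of random features (both assumptions of this sketch):

```python
import math
import torch

def prm_exp(x: torch.Tensor, w: torch.Tensor, m: int) -> torch.Tensor:
    # Positive random-feature map phi(x) = exp(w.x - |x|^2 / 2) / sqrt(m),
    # so that phi(q) @ phi(k).T approximates the softmax kernel exp(q.k).
    # x: (B, T, emb), w: (m, emb) fixed Gaussian projection -> (B, T, m)
    xd = (x * x).sum(dim=-1, keepdim=True) / 2       # |x|^2 / 2, shape (B, T, 1)
    wtx = torch.einsum('bti,mi->btm', x.float(), w)  # random projections, (B, T, m)
    return torch.exp(wtx - xd) / math.sqrt(m)
```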
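On the speed comparison itself: a hedged sketch of how one might time models head-to-head with the fused attention path disabled, assuming a recent timm where `set_fused_attn` is exported from `timm.layers`; the model name, batch size, and step counts are illustrative only:

```python
import torch
import timm
from timm.layers import set_fused_attn

# Disable F.scaled_dot_product_attention globally *before* model creation,
# so both models run the plain matmul-softmax attention path.
set_fused_attn(False)

def bench(name: str, steps: int = 50) -> float:
    model = timm.create_model(name, pretrained=False).cuda().eval()
    x = torch.randn(32, 3, 224, 224, device='cuda')
    with torch.no_grad():
        for _ in range(10):  # warmup
            model(x)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(steps):
            model(x)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / steps  # ms per batch

print(bench('vit_small_patch16_224'))
```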
Hi @rwightman, I agree with your observation. The T2T-ViT model does not have advantages over other models; the only advantage might be that it does not use any convolution layers.

Another issue occurs when using pre-trained weights and testing whether the structure of `first_conv` is adaptive to the input shape (C, H, W). If the model has no `first_conv`, like T2T-ViT without a Conv stem, the check fails (see `tests/test_models.py`, lines 371 to 376 at d81da93).

In `test_model_load_pretrained`, if `first_conv` is like T2T-ViT without Conv, passing this parameter to `nn.Linear` instead of `nn.Conv2d` will also report an error (see `timm/models/_builder.py`, lines 225 to 239 at d81da93).

Since this involves modifying shared test and builder code, I'd like your opinion on how to proceed.
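To make the failure mode concrete, here is a sketch of how `adapt_input_conv` (used by `load_pretrained` in `_builder.py` to remap `first_conv` weights for a different `in_chans`) behaves with Conv2d versus Linear patch-embed weights. The private import path is an assumption and may vary across timm versions:

```python
import torch
from timm.models._manipulate import adapt_input_conv  # private module path, may vary

# A 4-D Conv2d patch-embed weight adapts cleanly to a new channel count ...
conv_w = torch.randn(768, 3, 16, 16)      # (out, in, kH, kW)
print(adapt_input_conv(1, conv_w).shape)  # torch.Size([768, 1, 16, 16])

# ... but a 2-D nn.Linear patch-embed weight, as in T2T-ViT, cannot be
# unpacked as (O, I, J, K), so pretrained loading raises an error.
linear_w = torch.randn(768, 3 * 16 * 16)  # (out_features, in_features)
adapt_input_conv(1, linear_w)             # ValueError: not enough values to unpack
```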
@brianhou0208 I don't know if not having the input conv is a 'feature'; my very first vit impl here, before the official JAX code that used the Conv2D trick was released, was this: `timm/models/vision_transformer.py`, lines 139 to 169 at 7613094.

The conv approach was faster since it was an optimized kernel rather than a chain of API calls. I suppose torch.compile would rectify most of that, but I still don't see the downside to the conv. Also, the packed vit I started working on (have yet to pick it back up) has to push patchification further into the data pipeline: https://github.com/huggingface/pytorch-image-models/blob/379780bb6ca3304d63bf8ca789d5bbce5949d0b5/timm/models/vision_transformer_packed.py
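For anyone comparing the two patchification styles, a small self-contained check that the strided Conv2d "trick" and unfold-plus-linear compute the same embedding (shapes and sizes here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
x = torch.randn(B, C, H, W)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)  # the Conv2D patch-embed trick

# unfold (im2col) -> (B, num_patches, C*P*P), then apply the same weights as a Linear
patches = F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)
linear_out = patches @ conv.weight.view(D, -1).t() + conv.bias

conv_out = conv(x).flatten(2).transpose(1, 2)    # (B, num_patches, D)
print(torch.allclose(conv_out, linear_out, atol=1e-4))  # True
```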
Hi @rwightman, this PR resolves #2364, please check.
Result

- Tested the T2T-ViT models and weights on the ImageNet val dataset: test code, output log
- FLOPs/MACs/Params calculated with calflops: report from calflops
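For reference, a minimal sketch of the calflops measurement; `calculate_flops` is the documented calflops entry point, while the model name assumes the T2T-ViT variant registered by this PR:

```python
import timm
from calflops import calculate_flops

model = timm.create_model('t2t_vit_14', pretrained=False)  # name assumed from this PR
flops, macs, params = calculate_flops(
    model=model,
    input_shape=(1, 3, 224, 224),
    output_as_string=True,
)
print(f'FLOPs: {flops}  MACs: {macs}  Params: {params}')
```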
Reference

- paper: Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, https://arxiv.org/pdf/2101.11986
- code: https://github.com/yitu-opensource/T2T-ViT