DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification
Yongming Rao¹ Wenliang Zhao¹ Benlin Liu²
Jiwen Lu¹ Jie Zhou¹ Cho-Jui Hsieh²
¹Tsinghua University ²University of California, Los Angeles
[Paper (arXiv)] [Code (GitHub)]
Figure 1: Unlike the structured downsampling in CNNs, dynamic token sparsification better exploits the sparsity of the input data through data-dependent downsampling.
Figure 2: The overall framework of DynamicViT. The proposed prediction module is inserted between transformer blocks to selectively prune less informative tokens, conditioned on the features produced by the previous layer. As a result, fewer tokens are processed in the subsequent layers.
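At inference time, the pruning step in Figure 2 can be pictured as keeping the highest-scoring fraction of tokens at each prediction module. The sketch below is a simplified illustration, not the DynamicViT implementation: it assumes the prediction module has already produced one keep score per patch token, that a class token at index 0 is never pruned, and that the kept count is rounded down. The names `prune_tokens` and `keep_ratio` are hypothetical.

```python
from typing import List, Sequence, Tuple


def prune_tokens(
    tokens: Sequence[Sequence[float]],
    scores: Sequence[float],
    keep_ratio: float,
) -> Tuple[List[List[float]], List[int]]:
    """Keep the class token plus the top-scoring fraction of patch tokens.

    tokens[0] is assumed to be the class token; tokens[1:] are patch tokens.
    scores[i] is a hypothetical keep score for patch token tokens[i + 1].
    Returns the kept tokens and their indices into `tokens`.
    """
    num_patches = len(tokens) - 1
    if len(scores) != num_patches:
        raise ValueError("expected one score per patch token")
    num_keep = int(num_patches * keep_ratio)  # round down (an assumption)
    # Rank patch tokens by score, highest first, and keep the top num_keep.
    ranked = sorted(range(num_patches), key=lambda i: scores[i], reverse=True)
    kept_patches = sorted(ranked[:num_keep])  # restore original order
    kept_indices = [0] + [i + 1 for i in kept_patches]  # shift past class token
    return [list(tokens[i]) for i in kept_indices], kept_indices


tokens = [[0.0], [1.0], [2.0], [3.0], [4.0]]
kept, idx = prune_tokens(tokens, [0.1, 0.9, 0.5, 0.3], keep_ratio=0.5)
print(idx, kept)
```

With five tokens (a class token plus four patches), these scores, and `keep_ratio=0.5`, the call keeps tokens `[0, 2, 3]`. Sorting the kept indices back into their original order is only a bookkeeping convenience; self-attention itself does not depend on token order once positional information has been added.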
Abstract
Attention is sparse in vision transformers. We observe that the final prediction in vision transformers is based only on a subset of the most informative tokens, which is sufficient for accurate image recognition. Based on this observation, we propose a dynamic token sparsification framework that prunes redundant tokens progressively and dynamically based on the input. Specifically, we devise a lightweight prediction module that estimates the importance score of each token given the current features. The module is added at different layers to prune redundant tokens hierarchically. To optimize the prediction module end to end, we propose an attention masking strategy that differentiably prunes a token by blocking its interactions with other tokens. Thanks to the nature of self-attention, the unstructured sparse tokens remain hardware-friendly, which makes it easy for our framework to achieve actual speed-ups. By hierarchically pruning 66% of the input tokens, our method reduces FLOPs by 31% to 37% and improves throughput by over 40%, while the accuracy drop stays within 0.5% for various vision transformers. Equipped with the dynamic token sparsification framework, DynamicViT models achieve very competitive complexity/accuracy trade-offs compared to state-of-the-art CNNs and vision transformers on ImageNet.
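The attention masking idea in the abstract can be illustrated with a masked softmax. The sketch below is one plausible form under our own assumptions: each token j has a keep value `keep[j]` in [0, 1] (during training this would come from a differentiable relaxation such as Gumbel-softmax; here it is a plain number), a dropped token is removed as a key for every other query, and every token always attends to itself so that no softmax row is empty. The function name is hypothetical, and the code is plain Python rather than a batched tensor implementation.

```python
import math
from typing import List, Sequence


def masked_attention_weights(
    logits: Sequence[Sequence[float]], keep: Sequence[float]
) -> List[List[float]]:
    """Row-wise softmax in which dropped tokens are blocked as keys.

    logits[i][j] is the pre-softmax score of query i attending to key j.
    keep[j] scales key j's contribution for every other query (0 blocks it);
    the diagonal is never masked, so each row keeps a nonzero denominator.
    """
    n = len(logits)
    weights = []
    for i in range(n):
        row_max = max(logits[i])  # subtract the max for numerical stability
        numer = [
            math.exp(logits[i][j] - row_max) * (1.0 if i == j else keep[j])
            for j in range(n)
        ]
        total = sum(numer)
        weights.append([v / total for v in numer])
    return weights


# Token 1 is dropped: queries 0 and 2 no longer see it.
w = masked_attention_weights([[0.0] * 3 for _ in range(3)], keep=[1.0, 0.0, 1.0])
print(w[0])  # [0.5, 0.0, 0.5]
```

Because the mask multiplies the exponentiated scores instead of deleting rows and columns, a kept token with a binary mask gets the same attention weights it would get if the dropped token were absent, while the weights remain a smooth function of `keep`.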
Results
We show that DynamicViT models can achieve favorable complexity/accuracy trade-offs on ImageNet.
Visualizing the progressively pruned tokens indicates that DynamicViT offers better interpretability.
Table 1: Main results on ImageNet. We apply our method to three vision transformers: DeiT-S, LV-ViT-S, and LV-ViT-M. We report top-1 classification accuracy, theoretical complexity in FLOPs, and throughput for different keeping ratios. Throughput is measured on a single NVIDIA RTX 3090 GPU with a batch size of 32.
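The relationship between keeping ratios, the 66% token reduction, and the smaller FLOPs reduction stated in the abstract can be checked with a back-of-the-envelope model. The sketch below assumes, for illustration only, a 12-block backbone with pruning stages placed before blocks 4, 7, and 10, a keeping ratio of 0.7 at each stage, and a per-block cost proportional to the number of tokens; it ignores the quadratic attention term and the cost of the prediction modules. The function `relative_cost` is hypothetical.

```python
from typing import Sequence


def relative_cost(num_blocks: int, prune_before: Sequence[int], keep_ratio: float) -> float:
    """Cost of the pruned model relative to the unpruned one.

    Blocks are 1-indexed. Before each block in `prune_before`, the token
    count is multiplied by `keep_ratio`. Each block's cost is assumed to be
    proportional to the fraction of tokens it processes.
    """
    fraction = 1.0  # fraction of the original tokens still present
    total = 0.0
    for block in range(1, num_blocks + 1):
        if block in prune_before:
            fraction *= keep_ratio
        total += fraction
    return total / num_blocks


kept_at_end = 0.7 ** 3
print(f"tokens pruned by the last stage: {1 - kept_at_end:.1%}")
print(f"estimated cost reduction: {1 - relative_cost(12, [4, 7, 10], 0.7):.1%}")
```

Under these assumptions about 34% of the tokens reach the last stage (roughly 66% pruned), yet the estimated cost falls by only about 37%, because the early blocks still process every token. The real reduction also depends on the attention term, which shrinks faster than linearly with the token count, and on the overhead of the prediction modules.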
Figure 3: Trade-offs between model complexity (FLOPs) and top-1 accuracy on ImageNet. Our models achieve better trade-offs than various vision transformers as well as carefully designed CNN models.
Figure 4: Comparison of our dynamic token sparsification method with model width scaling. Dynamic token sparsification is more efficient than the commonly used model width scaling.
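A comparison like Figure 4 needs a width-scaled model whose cost matches the token-sparsified one. The sketch below shows one way to find such a width under a common approximate multiply-accumulate count for a standard transformer block: 4NC² for the query/key/value and output projections, 8NC² for an MLP with a 4× hidden size, and 2N²C for the attention scores and weighted sum, where N is the token count and C the embedding width. The DeiT-S-like shape (197 tokens, width 384) and the helper names are assumptions for illustration; this is not necessarily how the figure was produced.

```python
def block_macs(num_tokens: float, width: float) -> float:
    """Approximate multiply-accumulates of one transformer block."""
    linear = 12 * num_tokens * width**2  # QKV + output projections + 4x MLP
    attention = 2 * num_tokens**2 * width  # QK^T scores and attention-weighted sum
    return linear + attention


def matched_width_multiplier(num_tokens: int, width: int, keep_ratio: float) -> float:
    """Width multiplier whose block cost equals that of keeping `keep_ratio` of the tokens."""
    budget = block_macs(num_tokens * keep_ratio, width)
    lo, hi = 0.0, 1.0
    for _ in range(60):  # bisection; block_macs grows monotonically with width
        mid = (lo + hi) / 2
        if block_macs(num_tokens, width * mid) < budget:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2


w = matched_width_multiplier(197, 384, keep_ratio=0.7)
print(f"width multiplier with the same per-block cost: {w:.3f}")
```

Under this approximation, keeping 70% of the tokens in a block costs about the same as shrinking that block's width to roughly 0.82× the original. The two options spend the same budget differently: token sparsification drops uninformative tokens while keeping full width for the rest, and Figure 4 reports that it is the more efficient of the two.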
Figure 5: Visualization of the progressively sparsified tokens. We show the original input image and the sparsification results after each of the three stages, where masked regions indicate discarded tokens. Our method gradually focuses on the most representative regions of the image, which suggests that DynamicViT has better interpretability.
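Producing masks like those in Figure 5 requires mapping each stage's kept positions back to the original patch grid, because positions after a pruning stage are relative to the shortened token list. The sketch below assumes patch tokens only (no class token), row-major patch order, and per-stage lists of kept positions as a hypothetical input format; a 224×224 input with 16×16 patches would give a 14×14 grid. The function name `stage_masks` is hypothetical.

```python
from typing import List, Sequence


def stage_masks(grid_size: int, stage_keeps: Sequence[Sequence[int]]) -> List[List[List[bool]]]:
    """Return one grid_size x grid_size mask per stage; True marks a discarded patch.

    stage_keeps[s] lists positions, within stage s's input token list, of the
    patch tokens that survive stage s. Patches are assumed to be in row-major order.
    """
    alive = list(range(grid_size * grid_size))  # original ids of surviving patches
    masks = []
    for keep in stage_keeps:
        alive = [alive[p] for p in keep]  # relative positions -> original ids
        alive_set = set(alive)
        masks.append(
            [[(r * grid_size + c) not in alive_set for c in range(grid_size)]
             for r in range(grid_size)]
        )
    return masks


# A 2x2 grid: stage 1 keeps patches 0, 2, 3; stage 2 keeps positions 1 and 2 of those.
for mask in stage_masks(2, [[0, 2, 3], [1, 2]]):
    print(mask)
```

Because each stage only sees the survivors of the previous one, the discarded regions can only grow from stage to stage, which matches the progressive masks described in the caption.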
BibTeX
@inproceedings{rao2021dynamicvit,
title={DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification},
author={Rao, Yongming and Zhao, Wenliang and Liu, Benlin and Lu, Jiwen and Zhou, Jie and Hsieh, Cho-Jui},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2021}
}