🧩 ViT on CIFAR-10

May 31, 2025 · Xuankun Yang, Junjie Yu · 14 min read
GitHub Repo: ViT-on-CIFAR-10

Exploring the Performance of Vision Transformer on Small Datasets: A Study on CIFAR-10


Preface

I must admit that this was a difficult task with a huge workload, but fortunately I had my partner JunjieYu28. I would like to thank him for his contributions to data augmentation, parameter tuning, and the report. 🎉

For the Chinese version of this report, please see this article.

To see each chart clearly, please switch to light mode in the upper right corner. ↗️

Project Overview

This study focuses on the application of Vision Transformer (ViT) in image classification tasks, particularly its performance on the CIFAR-10 dataset. Image classification is a core task in computer vision, where traditional convolutional neural networks (like ResNet) have excelled. The success of Transformer models in natural language processing has inspired their exploration in visual tasks [Dosovitskiy et al., 2021]. ViT processes images by dividing them into patches and feeding them into a Transformer encoder for feature extraction. However, on small datasets, ViT’s performance can be unstable due to the lack of local inductive biases.

The goals of this project include: reproducing ViT and evaluating its performance on CIFAR-10, designing a hybrid model combining ViT and ResNet, analyzing the impact of hyperparameters, and introducing data augmentation strategies to improve performance.

To date, we have implemented the basic ViT model and hybrid models, conducted hyperparameter tuning, data augmentation experiments, and visualization analyses. The experimental results show that the optimal model achieves a Top-1 Accuracy of 92.54% and a Top-5 Accuracy of 99.57%. Further optimization of hybrid model parameter combinations and training efficiency is still needed.

Background

Image classification is a foundational task in computer vision, with traditional CNNs (such as ResNet) achieving remarkable success. In recent years, the breakthroughs of Transformer models in natural language processing have led to their application in vision tasks [Dosovitskiy et al., 2021]. However, vision models based on Transformer architectures, like ViT, are highly sensitive to the amount of training data and lack inductive biases for local features, resulting in unstable performance on small datasets (e.g., CIFAR-10).

Therefore, the research objectives of this experiment include:

  • Reproducing ViT, training it on CIFAR-10, and evaluating its adaptability to small datasets.
  • Combining ViT with ResNet to design hybrid models for improved classification performance.
  • Conducting ablation experiments to analyze the impact of key parameters (e.g., patch size, embedding dimension, number of layers) and perform hyperparameter tuning.
  • Introducing data augmentation strategies and comparing their effects on model accuracy and robustness.

Basic ViT Model Performance

We reproduced the basic ViT architecture and trained it on CIFAR-10 without any hyperparameter tuning, regularization, or data augmentation. The initial performance is shown in the figure below:

Basic ViT Model Performance

The results indicate that the basic model’s performance is suboptimal, necessitating improvements through model architecture enhancements, hyperparameter tuning, or data augmentation strategies to optimize ViT’s performance.

Model Architecture

Basic ViT Model Architecture

In the original paper, the authors used 224x224 images with a 16x16 patch size. Since CIFAR-10 images are only 32x32, we did not resize them; we kept the original resolution and set the patch size to 4x4.

Basic ViT Model Architecture

We named this model setting ViT-Basic.
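
As a concrete illustration of this setting, the sketch below shows how a 32x32 image becomes 64 tokens (plus a cls_token) of dimension 384, our default hidden size, with learnable position embeddings added. The module and argument names are illustrative, not our exact implementation; the real code is in the repo.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split a 32x32 image into 4x4 patches and project each to the hidden size.
    A strided convolution is equivalent to flattening patches + a linear layer."""
    def __init__(self, img_size=32, patch_size=4, in_chans=3, hidden_size=384):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2   # 64 patches for 32x32 with 4x4
        self.proj = nn.Conv2d(in_chans, hidden_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, hidden_size))  # learnable

    def forward(self, x):                        # x: (B, 3, 32, 32)
        x = self.proj(x)                         # (B, 384, 8, 8)
        x = x.flatten(2).transpose(1, 2)         # (B, 64, 384)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)           # prepend cls_token -> (B, 65, 384)
        return x + self.pos_embed                # add learnable position embeddings

tokens = PatchEmbedding()(torch.randn(2, 3, 32, 32))
print(tokens.shape)  # torch.Size([2, 65, 384])
```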

Hybrid Model Architecture

Following the integration of ViT and ResNet in the original paper, we designed hybrid models tailored to the CIFAR-10 dataset, with two hyperparameter settings:

  • Three downsampling operations, resulting in a feature map size of 4x4 with 256 channels, and setting ViT’s patch size to 1x1.
  • Two downsampling operations, resulting in a feature map size of 8x8 with 256 channels, and setting the patch size to 1x1.

We named these settings ViT-Hybrid-1 and ViT-Hybrid-2, respectively.

Hybrid ViT Model Architecture
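
A hedged sketch of the ViT-Hybrid-2 idea: a small convolutional stem (standing in for the ResNet stages, whose exact layout is in the repo) downsamples the 32x32 input twice to an 8x8, 256-channel feature map, and a 1x1 "patch" projection turns each spatial position into one token. ViT-Hybrid-1 would add a third stride-2 stage, giving a 4x4 map.

```python
import torch
import torch.nn as nn

class HybridStem(nn.Module):
    """Illustrative ResNet-style stem: two stride-2 stages take (B, 3, 32, 32) to (B, 256, 8, 8)."""
    def __init__(self, hidden_size=384):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),   # 32 -> 16
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),  # 16 -> 8
        )
        # 1x1 "patches": every spatial position of the feature map becomes one token
        self.proj = nn.Conv2d(256, hidden_size, kernel_size=1, stride=1)

    def forward(self, x):                      # (B, 3, 32, 32)
        x = self.proj(self.stem(x))            # (B, 384, 8, 8)
        return x.flatten(2).transpose(1, 2)    # (B, 64, 384) tokens for the Transformer encoder

print(HybridStem()(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 64, 384])
```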

Model Architecture Research

In this section, we explored several model settings. Specific parameter configurations and model parameter counts can be found in Appendix A.

Impact of Number of Self-Attention Heads in Transformer

Using ViT-Hybrid-2, we explored the impact of the number of self-attention heads on model performance, yielding the results in the figure below.

Impact of Number of Heads

Focusing on Top-1 Accuracy, with other parameters at default settings, the model’s performance improves as the number of heads increases. This can be understood as more heads allowing for finer understanding of input images. The performance gain from 12 to 16 heads is smaller than from 8 to 12. Thus, 12 heads is the optimal choice, balancing performance and efficiency, while 16 heads offer marginal additional benefits.

Impact of Number of Transformer Blocks

Still using ViT-Hybrid-2, we explored the impact of the number of Transformer blocks on performance, as shown below.

Impact of Number of Blocks

Observing Top-1 Accuracy, with default settings for other parameters, performance first increases and then decreases as layers increase. With 4 layers, the model may be too simple; with 12 layers, it may be too complex, leading to optimization difficulties under fixed iterations. Thus, 8 layers is optimal under the same iteration count.

Impact of Patch Size

Using ViT-Basic, we explored the impact of patch size on performance.

Impact of Patch Size

For Top-1 Accuracy, a patch size of 4 performs best, indicating that a moderate size effectively captures features. Sizes of 2 and 8 show significant drops, possibly due to insufficient information (too small) or loss of details (too large).

Comparison of Hybrid and Original Models

We compared the performance of different hybrid models and the original model.

Hybrid vs Original

Top-1 Accuracy shows that both ViT-Hybrid-1 and ViT-Hybrid-2 outperform ViT-Basic, indicating that hybrid structures enhance feature extraction. The performance difference between 4x4 and 8x8 feature maps is minimal, suggesting limited impact from feature map size, possibly due to consistent channels or small original image size.

Impact of Hidden Size and MLP Dimension

In ViT, hidden size and MLP dimension are crucial parameters. We kept MLP dim >= hidden size and conducted 5 experiments.

Impact of Hidden Size and MLP Dim

The optimal performance is at hidden size 288 and MLP dim 768. Larger values increase capacity but may cause optimization issues or overfitting on small datasets like ours.

Regularization Exploration

Using ViT-Hybrid-2, we explored three basic regularization methods and the more advanced stochastic depth.

Basic Regularization Methods

We examined Weight Decay, Attention Dropout, and Dropout with varying parameters.

WD DP ADP

For Weight Decay, performance improves as λ decreases and stabilizes below 5e-4, matching the no-decay baseline. A large λ may over-penalize the weights, leading to underfitting. The model shows strong robustness across the attention-dropout and dropout rates we tried.

Stochastic Depth Method

Our stochastic depth method skips each block independently with probability p during training, as shown below. The expected number of blocks updated per step is:

$$ \mathbb{E}[N_{update}] = N_{block} \times (1 - p) $$

We experimented with different p values.

Stochastic Depth Illustration

Stochastic Depth Results

Large p slightly degrades performance due to fewer and unstable updates. Moderate p (e.g., 1e-2) improves performance, possibly aiding generalization by randomly skipping blocks.
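
A minimal sketch of the block-skipping rule described above: during training each wrapped Transformer block is bypassed with probability p, so on average $N_{block} \times (1 - p)$ blocks are updated per step. The wrapper name is illustrative, not our exact implementation.

```python
import torch
import torch.nn as nn

class StochasticBlock(nn.Module):
    """Wrap a Transformer block so that, during training, it is skipped with probability p."""
    def __init__(self, block: nn.Module, p: float = 1e-3):
        super().__init__()
        self.block, self.p = block, p

    def forward(self, x):
        if self.training and torch.rand(1).item() < self.p:
            return x                # skip the block entirely: identity pass-through
        return self.block(x)        # otherwise apply the block as usual
```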

Data Augmentation

Data Augmentation Methods

We combined official augmentation libraries with custom methods and used ablation studies to find optimal combinations:

  • AutoAugment: Proposed by Google Brain in 2019, using 25 optimal sub-policies for CIFAR-10 [Cubuk et al., 2019].
  • RandAugment: Proposed in 2020, using two hyperparameters $N$ and $M$ for unified augmentation [Cubuk et al., 2020].
  • Custom Augmentations: CutMix, MixUp, RandomCropPaste.
    • CutMix: Randomly selects two images, uses $\beta$ distribution for $\lambda$, replaces a region, and weights labels.
    • MixUp: Mixes two images pixel-wise with $\lambda$ drawn from a $\beta$ distribution, weighting the labels accordingly (see the sketch after this list).
    • RandomCropPaste: Crops a region from the image, flips with probability, and pastes back with linear fusion.
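
A minimal MixUp sketch consistent with the description above; the function name and the alpha default are illustrative (our actual MixUp parameters are listed in Appendix A). CutMix differs only in replacing a rectangular region instead of blending whole images.

```python
import torch

def mixup_batch(images, labels, alpha=0.2, num_classes=10):
    """MixUp on one batch: blend each image with a shuffled partner and mix the one-hot labels.
    lambda is drawn from Beta(alpha, alpha), as described above."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(images.size(0))
    mixed = lam * images + (1.0 - lam) * images[perm]            # pixel-wise blend
    one_hot = torch.nn.functional.one_hot(labels, num_classes).float()
    mixed_labels = lam * one_hot + (1.0 - lam) * one_hot[perm]   # weighted labels
    return mixed, mixed_labels
```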

We chose RandAugment for its ease of tuning and combined it with [‘None’, ‘RandAugment’] + [‘None’, ‘CutMix’, ‘MixUp’, ‘RandomCropPaste’, ‘Batch_Random’].
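
For reference, a training pipeline along these lines can be assembled with torchvision's built-in RandAugment. The $(N, M)$ values and the normalization statistics below are the commonly used CIFAR-10 defaults, not necessarily our exact choices; the combinations we actually swept are in Appendix B.

```python
import torchvision.transforms as T

# RandAugment with N ops of magnitude M, followed by standard CIFAR-10 preprocessing.
train_transform = T.Compose([
    T.RandAugment(num_ops=2, magnitude=9),        # the (N, M) pair is tuned in Appendix B
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
```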

Here are some images before and after augmentation:

Aug1
Aug2

Aug3
Aug4

Origin1
Origin2

Mixup1
Mixup2

Cutmix1
Cutmix2

RandAugment Effects

Ablation experiments controlling custom augmentations:

RandAugment Ablation

RandAugment shows good properties: it achieves solid results without complex augmentations and uses less memory (~7000 MiB vs. ~14000 MiB for CutMix/Batch Random).

Custom Augmentation Effects

Ablation without RandAugment:

Custom Aug Types

CutMix or batch-random custom augmentations yield the best training effects.

Visualization

Following the original ViT paper, we conducted the following visualizations.

Attention Maps

Using our optimal model, we visualized attention following [Abnar et al., 2020].

Attention Map

Left: layers 1-4 and 5-8 from top to bottom; the attention shifts with depth but stays focused on class-relevant areas. Shallow layers attend locally, deeper ones globally, akin to CNN receptive fields growing with depth. Right: the gradient map shows that the model outlines objects well, indicating strong discriminative ability. For a more intuitive view, please see the following animation.

Attention Map Gif
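
The maps above follow the attention-rollout idea of [Abnar et al., 2020]. A minimal sketch, assuming the per-layer attention matrices have already been collected as one (heads, tokens, tokens) tensor per layer with the cls_token at index 0:

```python
import torch

def attention_rollout(attn_maps, grid=8):
    """Attention rollout: average heads, add the identity for the residual connection,
    renormalize rows, and multiply the per-layer matrices from the first layer upward."""
    rollout = torch.eye(attn_maps[0].size(-1))
    for attn in attn_maps:
        a = attn.mean(dim=0)                    # (tokens, tokens), averaged over heads
        a = a + torch.eye(a.size(-1))           # account for the residual connection
        a = a / a.sum(dim=-1, keepdim=True)     # renormalize each row
        rollout = a @ rollout                   # compose with the earlier layers
    return rollout[0, 1:].reshape(grid, grid)   # cls_token attention over the patch grid
```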

Feature Maps

To understand the contribution of the ResNet stem, we extracted the feature maps it produces in ViT-Hybrid-2 for an original CIFAR-10 image.

Feature Maps

The 256 channels show some redundancy, but several maps clearly capture contours and depth-like cues. For 32x32 images, 256 channels may be excessive, which warrants further study.
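
A hedged sketch of how such feature maps can be pulled out with a forward hook; the two-layer stem here is only a stand-in for the ResNet part of ViT-Hybrid-2 (the real model lives in the repo).

```python
import torch
import torch.nn as nn

# Stand-in for the ResNet stem of ViT-Hybrid-2.
stem = nn.Sequential(
    nn.Conv2d(3, 128, 3, stride=2, padding=1), nn.ReLU(),    # 32 -> 16
    nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),  # 16 -> 8
)

features = {}

def save_output(module, inputs, output):
    features["stem"] = output.detach()          # keep the (B, 256, 8, 8) feature map

handle = stem.register_forward_hook(save_output)
with torch.no_grad():
    stem(torch.randn(1, 3, 32, 32))             # one CIFAR-10-sized image
handle.remove()
print(features["stem"].shape)                   # torch.Size([1, 256, 8, 8]): 256 maps to plot
```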

Patch Embedding Visualization

Using ViT-Basic, we visualized the first 28 principal components of the patch-embedding convolution filters.

Patch Embedding

Compared with the original paper's visualization shown below, our components are less interpretable, mainly because using the original 32x32 images forces a small (4x4) patch size.

Original Patch Embedding
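
A sketch of how such a visualization can be produced: flatten the patch-embedding convolution kernels and keep the leading principal components. The weight tensor below is a random stand-in for the trained patch-embedding weights; with 4x4x3 kernels each filter is only 48-dimensional, which is part of why our components look less structured than the original 16x16 ones.

```python
import torch

# Patch-embedding convolution weight: (hidden_size, 3, 4, 4) = (384, 3, 4, 4) for ViT-Basic.
weight = torch.randn(384, 3, 4, 4)        # stand-in for the trained proj.weight
flat = weight.flatten(1)                  # (384, 48): one 48-dimensional filter per output channel
_, _, v = torch.pca_lowrank(flat, q=28)   # top 28 principal directions (centering is the default)
filters = v.T.reshape(28, 3, 4, 4)        # reshape back to 4x4 RGB kernels for plotting
```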

Attention Distance

We observed attention distances across depths and heads.

Attention Distance

The average attention distance increases with depth and then saturates, showing that shallow layers capture local features while deeper layers attend more globally. The distribution across heads also concentrates with depth, possibly reflecting progressive feature refinement.
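
A minimal sketch of the metric, assuming one layer's attention tensor with the cls_token already removed: for each head, the attention-weighted average pixel distance between query and key patch centers, averaged over all queries.

```python
import torch

def mean_attention_distance(attn, grid=8, patch_px=4):
    """Mean attention distance for one layer.
    `attn` has shape (num_heads, num_patches, num_patches), rows summing to 1."""
    coords = torch.stack(torch.meshgrid(torch.arange(grid), torch.arange(grid), indexing="ij"), -1)
    coords = coords.reshape(-1, 2).float() * patch_px   # patch centers in pixels, (64, 2)
    dist = torch.cdist(coords, coords)                  # (64, 64) pairwise pixel distances
    return (attn * dist).sum(dim=-1).mean(dim=-1)       # one distance value per head
```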

Position Embedding Similarity

The original paper compared 1-D, 2-D, relative, and learnable (default) position encodings; we used learnable encodings and examined their pairwise similarities.

Position Embedding Similarity

Using the optimal ViT-Basic and ViT-Hybrid-2 models, we found off-diagonal maxima running parallel to the main diagonal, suggesting that the embeddings capture the 2-D layout of the 8x8 patch grid. The cls_token (patch 0) correlates only weakly with the others. In the hybrid models the pattern is weaker, likely because the ResNet stem has already compressed the spatial information.
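
A sketch of how the similarity maps are computed, using a random stand-in for the learnable position embeddings of ViT-Basic (1 cls_token + 64 patch positions, hidden size 384):

```python
import torch
import torch.nn.functional as F

# Learnable position embeddings of ViT-Basic: (1, 65, 384) = cls_token + 8x8 patch grid.
pos_embed = torch.randn(1, 65, 384)     # random stand-in for the trained pos_embed
patches = pos_embed[0, 1:]              # drop the cls_token (patch 0)
normed = F.normalize(patches, dim=-1)   # unit-norm rows so the dot product is cosine similarity
sim = normed @ normed.T                 # (64, 64) pairwise cosine similarities
grid = sim.reshape(64, 8, 8)            # row i, reshaped to 8x8, is patch i's similarity map
```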

Experimental Results

Summary of Ablation Experiments

We tuned hyperparameters via ablations on architecture, regularization, and augmentation. Details in Appendix A.

Optimal Model

Our best from-scratch model (Cifar_No_3) achieves 92.54% Top-1 and 99.57% Top-5 on CIFAR-10.

Best Top-1
Best Top-5

Best Loss
Best Confusion

Parameters:

| LR | WD | DP | ADP | SD | RAUG | AUG | MU | CM | RCP |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | cutmix | 0.2 | 0.8 | (1.0, 0.5) |

| Res | #B | #H | HS | MD | PS | Area | Top-1 | Top-5 | PRM |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 2 | 8 | 12 | 384 | 384 | xx | AUG | 92.54 | 99.57 | 9.18 |

Comparison with Traditional CNNs

To evaluate our ViT on CIFAR-10, we compared it with classic CNNs like ResNet [He et al., 2016].

ResNet Comparison 1

ResNet Comparison 2

  1. Accuracy Comparison:

    • Top-1: ViT outperforms ResNet-20 (+2.44 pp) and ResNet-56 (+1.24 pp) but lags behind ResNet-110 (-0.85 pp) and ResNet-164 (-1.66 pp).
    • Top-5: ViT reaches 99.57%, close to 100%; ResNet's Top-5 can be inferred to be around 99.7%, so the two are roughly equal.
  2. Parameter Efficiency: ViT uses about 9.2M parameters (roughly 3.7x ResNet-164) yet achieves a lower Top-1 accuracy, indicating lower parameter efficiency; reducing the embedding dimensions or sharing parameters could help.

  3. Convergence and Generalization: ResNet converges after roughly 4e4 iterations with learning-rate decay, whereas our model converges after about 15,000 steps, suggesting higher training efficiency.

Overall, our ViT approaches or surpasses some classic CNNs on small datasets like CIFAR-10.

References

  • Dosovitskiy et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
  • Cubuk et al. (2019). AutoAugment: Learning Augmentation Strategies from Data.
  • Cubuk et al. (2020). RandAugment: Practical Automated Data Augmentation with a Reduced Search Space.
  • Abnar et al. (2020). Quantifying Attention Flow in Transformers.
  • He et al. (2016). Deep Residual Learning for Image Recognition.

Appendix

A. Hyperparameters and Ablation Experiment Table

| No. | Learning Rate | Weight Decay | Dropout Rate | Attention DO | Prob_pass | RAUG | AUG | MixUp | CutMix | RandomCropPaste | Res | #Block | #Head | Hidden_size | MLP_dim | Patch_size | Area | Top-1 | Top-5 | Parameter (MB) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | None | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | AUG | 89.23 | 99.56 | 9.18 |
| 1 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | mixup | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | AUG | 90.89 | 99.31 | 9.18 |
| 2 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | random_crop_paste | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | AUG | 89.39 | 99.48 | 9.18 |
| 3 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | cutmix | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | AUG | 92.54 | 99.57 | 9.18 |
| 4 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | AUG | 91.86 | 99.77 | 9.18 |
| 5 | 1e-3 | 5e-5 | 0.0 | 0.0 | 0.0 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | SD | 91.65 | 99.64 | 9.18 |
| 6 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-1 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | SD | 91.49 | 99.56 | 9.18 |
| 7 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-2 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | SD | 92.04 | 99.69 | 9.18 |
| 8 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | SD | 91.81 | 99.59 | 9.18 |
| 9 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-4 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | SD | 91.64 | 99.59 | 9.18 |
| 10 | 1e-3 | 0.0 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | WD | 90.35 | 99.59 | 9.18 |
| 11 | 1e-3 | 5e-1 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | WD | 31.62 | 86.45 | 9.18 |
| 12 | 1e-3 | 5e-2 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | WD | 39.28 | 89.76 | 9.18 |
| 13 | 1e-3 | 5e-3 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | WD | 85.93 | 99.27 | 9.18 |
| 14 | 1e-3 | 5e-4 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | WD | 91.64 | 99.61 | 9.18 |
| 15 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | WD | 91.81 | 99.57 | 9.18 |
| 16 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | DP | 91.62 | 99.65 | 9.18 |
| 17 | 1e-3 | 5e-5 | 1e-1 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | DP | 91.97 | 99.65 | 9.18 |
| 18 | 1e-3 | 5e-5 | 1e-2 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | DP | 91.80 | 99.60 | 9.18 |
| 19 | 1e-3 | 5e-5 | 1e-3 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | DP | 91.70 | 99.59 | 9.18 |
| 20 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | ADP | 91.96 | 99.61 | 9.18 |
| 21 | 1e-3 | 5e-5 | 0.0 | 1e-1 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | ADP | 91.74 | 99.65 | 9.18 |
| 22 | 1e-3 | 5e-5 | 0.0 | 1e-2 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | ADP | 91.97 | 99.66 | 9.18 |
| 23 | 1e-3 | 5e-5 | 0.0 | 1e-3 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | ADP | 91.92 | 99.63 | 9.18 |
| 24 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | (4, 4) | Res | 83.77 | 99.13 | 7.16 |
| 25 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 1 | 8 | 12 | 384 | 384 | xx | Res | 91.75 | 99.54 | 9.85 |
| 26 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 2 | 8 | 12 | 384 | 384 | xx | Res | 91.83 | 99.51 | 9.18 |
| 27 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | (2, 2) | PS | 55.32 | 94.10 | 7.22 |
| 28 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | (4, 4) | PS | 84.15 | 99.00 | 7.16 |
| 29 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | (8, 8) | PS | 78.72 | 98.61 | 7.19 |
| 30 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 4 | 12 | 384 | 384 | xx | NB | 91.51 | 99.58 | 5.63 |
| 31 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 12 | 12 | 384 | 384 | xx | NB | 91.66 | 99.58 | 12.74 |
| 32 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 8 | 384 | 384 | xx | NH | 91.33 | 99.60 | 9.18 |
| 33 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 16 | 384 | 384 | xx | NH | 91.92 | 99.61 | 9.18 |
| 34 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 288 | 288 | xx | HS&MLP_dim | 91.58 | 99.62 | 6.05 |
| 35 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 288 | 384 | xx | HS&MLP_dim | 91.71 | 99.57 | 6.49 |
| 36 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 288 | 768 | xx | HS&MLP_dim | 92.04 | 99.67 | 8.26 |
| 37 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | batch_random | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 384 | 768 | xx | HS&MLP_dim | 91.58 | 99.64 | 11.55 |
| 38 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | cutmix | 0.2 | 2.0 (weak) | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | xx | CM | 92.21 | 99.74 | 9.18 |
| 39 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | cutmix | 0.2 | 0.1 (strong) | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | xx | CM | 92.35 | 99.69 | 9.18 |
| 40 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | mixup | 0.8 (random) | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | xx | MU | 91.55 | 99.33 | 9.18 |
| 41 | 1e-3 | 5e-5 | 0.0 | 0.0 | 1e-3 | False | mixup | 2.5 (strong) | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | xx | MU | 91.43 | 99.59 | 9.18 |
| Best_0 | 1e-3 | 5e-5 | 1e-2 | 1e-2 | 1e-2 | False | cutmix | 0.2 | 0.8 | (1.0, 0.5) | 0 | 8 | 12 | 384 | 384 | xx | Find_Best | | | |

B. Parameter Tuning Table

| NO. | aug_type | cutmix | mixup | random_crop | rand_aug | (HL, MLP) | top_1 | top_5 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Test_1 | Cutmix | 0.8 | xx | xx | (2,9) | xx | 0.9142 | 0.9974 |
| Test_2 | None | xx | xx | xx | (2,9) | xx | 0.8918 | 0.9941 |
| Test_3 | batch_random | 0.8 | 2.5 | (1.0,0.8) | (4,15) | xx | 0.9193 | 0.9966 |
| Test_4 | None | xx | xx | xx | (4,15) | xx | 0.9113 | 0.9968 |
| Test_5 | None | xx | xx | xx | (4,15) | xx | 0.9117 | 0.9968 |
| Test_6 | Cutmix | 0.8 | xx | xx | (4,15) | xx | 0.9225 | 0.9971 |
| Test_7 | Mixup | xx | 2.5 | xx | (4,15) | xx | 0.9130 | 0.9966 |
| Test_8 | Cutmix | 0.8 | xx | xx | (3,15) | xx | 0.9232 | 0.9980 |
| Test_9 | Cutmix | 0.8 | xx | xx | (4,15) | (288,768) | 0.9244 | 0.9978 |
| Test_10 | Cutmix | 0.8 | xx | xx | (2,15) | xx | 0.9228 | 0.9974 |
| Test_11 | Cutmix | 0.8 | xx | xx | False | (288,768) | 0.9229 | 0.9968 |

C. Symbols and Abbreviations

| Symbol/Abbrev | Meaning |
| --- | --- |
| LR | Learning Rate |
| WD | Weight Decay |
| DP | Dropout Rate |
| ADP | Attention Dropout Rate |
| SD | Stochastic Depth |
| RAUG | RandAugment Enabled |
| AUG | Augmentation Strategy |
| MU | MixUp Parameter |
| CM | CutMix Parameter |
| RCP | Random Crop Paste Params |
| Res | ResNet Variant |
| #B | Number of Blocks |
| #H | Number of Heads |
| HS | Hidden Size |
| MD | MLP Dimension |
| PS | Patch Size |
| Area | Ablation Study Area |
| Top-1 | Top-1 Accuracy |
| Top-5 | Top-5 Accuracy |
| PRM | Parameter Count (MiB) |

D. Experimental Hardware and Software Environment

  • GPU: NVIDIA RTX 3090 × 4 (24 GB × 4)
  • CUDA: 12.4
  • Python: 3.10.16
  • PyTorch: 2.5.1
  • TorchVision: 0.20.1