Code for paper "Semi-Supervised Vision Transformers"

Abstract: We study the training of Vision Transformers for semi-supervised image classification. Transformers have recently demonstrated impressive performance on a multitude of supervised learning tasks. Surprisingly, we show Vision Transformers perform significantly worse than Convolutional Neural Networks when only a small set of labeled data is available. Inspired by this observation, we introduce a joint semi-supervised learning framework, Semiformer, which contains a transformer stream, a convolutional stream, and a carefully designed fusion module for knowledge sharing between these streams. The convolutional stream is trained on limited labeled data and further used to generate pseudo labels to supervise the training of the transformer stream on unlabeled data. Extensive experiments on ImageNet demonstrate that Semiformer achieves 75.5% top-1 accuracy, outperforming the state-of-the-art by a clear margin. In addition, we show, among other things, that Semiformer is a general framework compatible with most modern transformer and convolutional neural architectures. Code is available in this repository.

Semiformer: Semi-Supervised Vision Transformers

This repository contains the official PyTorch implementation of our paper: "Semiformer: Semi-Supervised Vision Transformers (ECCV 2022, Poster)"

Zejia Weng, Xitong Yang, Ang Li, Zuxuan Wu, Yu-Gang Jiang

[Paper] [Supp]

Introduction

We introduce a joint semi-supervised learning framework, Semiformer, which contains a transformer stream, a convolutional stream and a carefully designed fusion module for knowledge sharing between these streams. The convolutional stream is trained on limited labeled data and further used to generate pseudo labels to supervise the training of the transformer stream on unlabeled data.

The main framework of Semiformer is illustrated in the figure below.
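
To make the knowledge-sharing idea concrete, below is a minimal PyTorch sketch of the cross-stream pseudo-labeling objective. The function name, the confidence threshold, and the hard confidence mask are illustrative assumptions, not the paper's exact recipe.

import torch
import torch.nn.functional as F

def cross_stream_pseudo_label_loss(conv_logits, vit_logits, threshold=0.7):
    # Convolutional-stream predictions on unlabeled images become pseudo
    # labels for the transformer stream; only confident samples contribute.
    with torch.no_grad():
        probs = F.softmax(conv_logits, dim=-1)      # conv-stream class probabilities
        conf, pseudo = probs.max(dim=-1)            # confidence and hard pseudo labels
        mask = (conf >= threshold).float()          # keep confident predictions only
    per_sample = F.cross_entropy(vit_logits, pseudo, reduction="none")
    return (per_sample * mask).mean()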

Dependencies

PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models (timm) 0.3.2.
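
To verify your environment matches these versions, here is a convenience check (an assumed snippet, not part of the repo; packaging ships with most Python installs):

from packaging import version
import torch, torchvision, timm

# Strip local build suffixes such as "+cu110" before comparing versions.
assert version.parse(torch.__version__.split("+")[0]) >= version.parse("1.7.0")
assert version.parse(torchvision.__version__.split("+")[0]) >= version.parse("0.8.1")
assert timm.__version__ == "0.3.2"  # pytorch-image-models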

Training Log and Checkpoint

Semiformer reaches its best top-1 accuracy of 75.5% in the 10%-labeled ImageNet semi-supervised setting. We provide the training log and the corresponding checkpoint, which can be downloaded through the following links:
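
Once downloaded, the checkpoint can be sanity-checked before evaluation. The key layout is not documented in this repo; a top-level 'model' entry is a common convention in DeiT-style code, so inspect the actual output:

import torch

ckpt = torch.load("semiformer.pth", map_location="cpu")
print(list(ckpt.keys()))  # e.g. a 'model' entry holding the state dict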

Data Preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure of the ImageNet data is expected to be as follows:

/path/to/imagenet
  train/
    class1/
    class2/
    ...
    class1000/
  val/
    class1/
    class2/
    ...
    class1000/
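
To confirm the layout matches what the data loaders expect, here is a quick sanity check with torchvision (the path is a placeholder):

from torchvision import datasets

# ImageFolder should discover exactly 1000 class subdirectories per split.
train_set = datasets.ImageFolder("/path/to/imagenet/train")
val_set = datasets.ImageFolder("/path/to/imagenet/val")
assert len(train_set.classes) == len(val_set.classes) == 1000
print(len(train_set), "train images;", len(val_set), "val images")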

Evaluation

Download the checkpoint and evaluate the model with the script /script/eval.sh; you should get 75.5% top-1 accuracy. The training process is recorded here.

# set $ROOT as the project root path
# set $DATA_ROOT as the ImageNet path
# set $RESUME_PATH as the downloaded checkpoint path
ROOT=/share/home/jia/workspace/semiformer-codeclean/Semiformer
DATA_ROOT=/share/common/ImageDatasets/imagenet_2012
RESUME_PATH=$ROOT/semiformer.pth

cd $ROOT

export CUDA_VISIBLE_DEVICES=0

python -m torch.distributed.launch --master_port 50131 --nproc_per_node=1 --use_env semi_main_concat_evalVersion.py \
                                   --model Semiformer_small_patch16 \
                                   --data-set SEMI-IMNET \
                                   --batch-size 256 \
                                   --num_workers 4 \
                                   --data-path $DATA_ROOT \
                                   --data-split-file $ROOT/data_splits/files2shards_train_size128116_split1.txt \
                                   --eval \
                                   --resume $RESUME_PATH

Train

Training scripts for Semiformer are provided in /script/submitit_Semiformer_*.sh, which use submitit for cluster job submission.

You can also train the model directly with DDP, without submitit; see the example script provided in /script/run_ddp_example.sh.

Acknowledgement

This repository is built upon DeiT, Conformer, and timm. Thanks to the authors of those well-organized codebases.

Citation

@inproceedings{weng2022semi,
  title={Semi-supervised vision transformers},
  author={Weng, Zejia and Yang, Xitong and Li, Ang and Wu, Zuxuan and Jiang, Yu-Gang},
  booktitle={ECCV},
  year={2022}
}
