Abstract: Self-attention based transformer models have dominated many computer vision tasks in the past few years. Their strong performance, however, depends heavily on excessively large labeled image datasets. To reduce the reliance on large labeled datasets, reconstruction-based masked autoencoders are gaining popularity; they learn high-quality transferable representations from unlabeled images. For the same purpose, recent weakly supervised image pretraining methods explore language supervision from the text captions accompanying the images. In this work, we propose masked image pretraining on language assisted representation, dubbed MILAN. Instead of predicting raw pixels or low-level features, our pretraining objective is to reconstruct image features that carry substantial semantic signals, obtained using caption supervision. Moreover, to accommodate our reconstruction target, we propose a more efficient prompting decoder architecture and a semantic aware mask sampling mechanism, which further improve the transfer performance of the pretrained model. Experimental results demonstrate that MILAN delivers higher accuracy than previous works. When the masked autoencoder is pretrained and finetuned on ImageNet-1K with an input resolution of 224x224, MILAN achieves a top-1 accuracy of 85.4% with ViT-B/16, surpassing the previous state-of-the-art by 1%. In the downstream semantic segmentation task, MILAN achieves 52.7 mIoU with a ViT-B/16 backbone on the ADE20K dataset, outperforming previous masked pretraining results by 4 points.
# MILAN: Masked Image Pretraining on Language Assisted Representation
This repository contains the PyTorch implementation of the paper MILAN: Masked Image Pretraining on Language Assisted Representation.
- This repo was built upon the MAE repo. Installation and preparation follow that repo.
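As a quick orientation, the pretraining objective described in the abstract (regressing decoder outputs onto CLIP image features for the masked patches) can be sketched as follows. This is a minimal illustration, not the code in this repo: `clip_encoder`, `mae_encoder`, and `prompting_decoder` are placeholder callables, and the choice of an MSE loss over masked positions is an assumption.

```python
import torch
import torch.nn.functional as F

def milan_style_pretrain_loss(images, mae_encoder, prompting_decoder, clip_encoder, mask):
    """Minimal sketch of a MILAN-style objective (not the repo's exact code).

    images: (B, 3, 224, 224) batch of unlabeled images
    mask:   (B, N) boolean, True for masked (removed) patches
    """
    with torch.no_grad():
        # A frozen CLIP image encoder provides semantically rich per-patch targets.
        target = clip_encoder(images)          # (B, N, D)

    # The encoder sees only visible patches; the decoder reconstructs
    # features for all patch positions from the visible-token latents.
    latent = mae_encoder(images, mask)         # (B, N_visible, D)
    pred = prompting_decoder(latent, mask)     # (B, N, D)

    # Regress predicted features onto the CLIP features at masked positions.
    return F.mse_loss(pred[mask], target[mask])
```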
## Prepare the dataset
- We use ImageNet-1K. Download the images.
- Extract the training data:

```bash
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ..
```
- Extract the validation data and move the images to subfolders:

```bash
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
```
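Once both splits are extracted, a quick sanity check can confirm the expected layout: 1,000 class subfolders under each of `train/` and `val/`, with 50,000 validation images in total. The snippet below is an optional helper, not part of the repo; the dataset root is whatever path you later pass as `--data_path`.

```python
import os

def check_imagenet_layout(root):
    """Rough sanity check of the extracted ImageNet-1K folder structure."""
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        classes = [d for d in os.listdir(split_dir)
                   if os.path.isdir(os.path.join(split_dir, d))]
        n_images = sum(len(files) for _, _, files in os.walk(split_dir))
        print(f"{split}: {len(classes)} class folders, {n_images} images")

# Expect 1000 class folders per split and 50000 validation images.
check_imagenet_layout("/dataset/imagenet")
```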
## Pretraining
Example of applying MILAN to pretrain ViT-B/16 on ImageNet-1K using 8 GPUs:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_pretrain.py \
    --model mae_vit_base_patch16 \
    --batch_size 256 \
    --accum_iter 2 \
    --mask_ratio 0.75 \
    --epochs 400 \
    --warmup_epochs 40 \
    --blr 1.5e-4 \
    --weight_decay 0.05 \
    --data_path /dataset/imagenet \
    --output_dir ./milan_vit_base_pretrain_400epoch_useclip_changedecoder_attnmask \
    --log_dir ./milan_vit_base_pretrain_400epoch_useclip_changedecoder_attnmask \
    --use_clip \
    --change_decoder \
    --attn_mask
```
- The available `--model` choices are listed in `models_milan.py`.
- Effective batch size is 256 (`--batch_size` per GPU) * 2 (`--accum_iter` gradient accumulation iterations) * 8 (GPUs) = 4096.
- Effective learning rate is 1.5e-4 (`--blr` base learning rate) * 4096 (effective batch size) / 256 = 2.4e-3 (see the arithmetic sketch after this list).
- `--mask_ratio`: percentage of patches to remove.
- `--epochs`: total pretraining epochs; `--warmup_epochs`: learning rate warmup epochs.
- We apply a `--weight_decay` of 0.05 during pretraining.
- We use the ViT-B/16 CLIP image encoder obtained from here to produce the reconstruction targets during pretraining.
- `--change_decoder`: switch to the prompting decoder.
- `--attn_mask`: switch to the semantic aware masking strategy.
- Training time is ~39h using 8 40GB A100 GPUs for 400 epochs.
## Finetuning and linear probing
- Example of finetuning the pretrained ViT-B/16 on ImageNet-1K:
```bash
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --accum_iter 1 \
    --batch_size 128 \
    --model vit_base_patch16 \
    --finetune ./milan_vit_base_pretrain_400epoch_useclip_changedecoder_attnmask/checkpoint-399.pth \
    --epochs 100 \
    --blr 1e-4 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path /data/imagenet \
    --output_dir ./milan_vit_base_finetune_pretrain400epochuseclipchangedecoderattnmask \
    --log_dir ./milan_vit_base_finetune_pretrain400epochuseclipchangedecoderattnmask \
    --global_pool
```
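In the finetuning command, `--layer_decay 0.65` applies layer-wise learning rate decay: parameter groups closer to the input get geometrically smaller learning rates than the head. A minimal sketch of that scaling rule (the exact parameter grouping is handled inside the training code; the indexing convention here is an assumption):

```python
# Layer-wise lr decay sketch: group i out of L gets lr scaled by decay**(L - i),
# so early blocks are updated gently while the head uses the full learning rate.
def layer_lr_scales(num_layers=12, layer_decay=0.65):
    # Index 0 ~ patch embedding, index num_layers ~ final block / head.
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

scales = layer_lr_scales()
print(scales[0], scales[-1])  # ~0.0057 for the earliest group, 1.0 for the head
```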
- Example of performing linear probing on the pretrained ViT-B/16:
```bash
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_linprobe.py \
    --accum_iter 1 \
    --batch_size 2048 \
    --model vit_base_patch16 \
    --cls_token \
    --finetune ./milan_vit_base_pretrain_400epoch_useclip_changedecoder_attnmask/checkpoint-399.pth \
    --epochs 100 \
    --blr 0.05 \
    --weight_decay 0.0 \
    --dist_eval --data_path /data/imagenet \
    --output_dir ./milan_vit_base_linearprobe_pretrain400epochuseclipchangedecoderattnmask \
    --log_dir ./milan_vit_base_linearprobe_pretrain400epochuseclipchangedecoderattnmask
```
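Linear probing keeps the pretrained encoder frozen and trains only a classifier on top of it (with `--cls_token`, the classifier reads the class token representation). A minimal sketch of that setup; `vit_encoder` is a placeholder for the pretrained backbone, and the BatchNorm-before-linear head follows the common MAE-style probing recipe rather than being copied from this repo.

```python
import torch.nn as nn

def build_linear_probe_head(vit_encoder, embed_dim=768, num_classes=1000):
    """Freeze the backbone and return a trainable linear classification head (sketch)."""
    for p in vit_encoder.parameters():
        p.requires_grad = False  # only the head below is optimized

    head = nn.Sequential(
        nn.BatchNorm1d(embed_dim, affine=False, eps=1e-6),  # feature whitening, MAE-style
        nn.Linear(embed_dim, num_classes),
    )
    return head
```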
## Checkpoints
We provide the pretrained ViT-B/16 and ViT-L/16 checkpoints.
|  | ViT-Base | ViT-Large |
|---|---|---|
| Pretrained checkpoint | download | download |
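A downloaded checkpoint is consumed via the `--finetune` flag in the commands above. If you want to inspect one first, a small sketch (the `'model'` key follows the MAE checkpoint convention and is an assumption here, as is the file name):

```python
import torch

# Peek into a downloaded MILAN checkpoint before finetuning or linear probing.
ckpt = torch.load("milan_vit_base.pth", map_location="cpu")
state_dict = ckpt.get("model", ckpt)   # fall back to a raw state dict
print(f"{len(state_dict)} tensors, e.g. {next(iter(state_dict))}")
```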
## Citation
If you find this repository helpful, please consider citing:
```bibtex
@article{MILAN2022,
  title   = {MILAN: Masked Image Pretraining on Language Assisted Representation},
  author  = {Hou, Zejiang and Sun, Fei and Chen, Yen-Kuang and Xie, Yuan and Kung, Sun-Yuan},
  journal = {arXiv preprint arXiv:2208.06049},
  year    = {2022},
}
```