ImageBERT
๋ฆฌ๋ทฐํ paper๋ ImageBERT: Cross-modal Pre-training with Large-scale Weak-supervised Image-Text Data ์ ๋๋ค. ์์๋ ๋ ผ๋ฌธ๊ณผ ๋์ผํ ์์๋ก ๋์ด ์์ต๋๋ค. Key word
Large Data Set
One Transformer
Multi-stage Pre-training + 4 tasks
1 Introduction
Text-Image Retrieval, Visual Question Answering(VQA), Visual Commonsense Reasoning(VCR)๋ฑ image์ text๋ฅผ ๋๋ค ์ฒ๋ฆฌํ๋ task์ ๊ด์ฌ์ด ๋ง์ ์ง๊ณ ์๋๋ฐ์. NLP์์ ์ฑ๊ณต์ ๋ณด์ธ, pre-training์ ํ์ฉํ ๋ฐฉ๋ฒ์ language + vision cross modal task์๋ ์ ์ฉํ๊ณ ํ๋ ๋ ธ๋ ฅ์ด ์์์ต๋๋ค. ์ด ๋ ผ๋ฌธ์์ ์ด๋ฌํ cross-modal pre-training ๋ฐฉ๋ฒ๋ค์ ๋น๊ตํ๊ณ , ์์ ๋ค์ด ์ ์ํ ๋ฐฉ๋ฒ๋ค์ ์๊ฐํฉ๋๋ค.
Transformer๊ฐ ์ ์๋ ํ, ๋ง์ ๋ ผ๋ฌธ๋ค์ด ์ด๋ฅผ ํ์ฉํ์ฌ cross-modal ๋ฌธ์ ๋ฅผ ํ๊ณ ์ ํ์์ต๋๋ค. ์ด๋ค์ ์ค์ฌ์ผ๋ก related work์ ๋ถ์ํ์์ต๋๋ค.
Model Architecture
language์ vision์ ๋ณ๋์ transformer๋ฅผ ์ ์ฉํ, ์ด์ cross-modal transformer๋ฅผ ์ถ๊ฐํ๋ ๋ฐฉ๋ฒ: ViLBERT, LXMERT๋ฑ
image์ sentence๋ฅผ ํ๋๋ก concat ์ํจ ํ, ํ๋์ transfomer์ ๋ฃ์ด์ ์ฒ๋ฆฌํ๋ ๋ฐฉ๋ฒ: VisualBERT, B2T2, Unicoder-VL, VL-BERT, Unified VLP, UNITER๋ฑ
๋ ๋ฐฉ๋ฒ ์ค ์ด๋ ๋ฐฉ๋ฒ์ด ์ข๋ค๊ณ ๋ง ํ ์ ์์ง๋ง, ๋ณธ ๋ ผ๋ฌธ์์๋ ํ๋์ transfomer๋ฅผ ์ฌ์ฉํฉ๋๋ค.
Image visual tokens
๋ง์ ๋ฐฉ๋ฒ์์ regions of interest(RoI)๋ฅผ word์ token์ฒ๋ผ ์ฌ์ฉํฉ๋๋ค.
VL-BERT์์๋ RoI๋ฅผ ๊ตฌํ๋ detection model๊น์ง ๊ฐ์ด ํ์ต์ ํ์์ต๋๋ค. ๋ณดํต ํ์ต๋ model์ ์ฌ์ฉํฉ๋๋ค. ๋ํ global image feature๋ token์ผ๋ก ์ถ๊ฐํ์์ต๋๋ค.
Pre-train data
๊ธฐ์กด์ ๋ง์ด ์์ฉ๋๋ data-set์ Conceptual Captions(3M)๊ณผ SBU Captions(1M)์ด ์์ง๋ง, NLP์์ ์ฌ์ฉํ๋ ๋ฐฉ๋ํ data-set์ ๋นํ๋ฉด, ๋ง์กฑํ ๋งํ data-set์ ์ป๊ธฐ ์ด๋ ต์ต๋๋ค.
3 Large-Scale Week-supervised Image-Text Data Collection
NLP์ ๊ฒฝ์ฐ, wiki์ ๊ฐ์ด ๋งค์ฐ ๋ฐฉ๋ํ data-set(wiki, book, โฆ)์ ์ป์ ์ ์์ง๋ง, cross-modal์์๋ ์ด๋ฌํ ๋ฐฉ๋ํ data-set์ ์ป๊ธฐ ์ด๋ ต์ต๋๋ค. ๋ง์ด ์ฌ์ฉํ๋ data-set์ผ๋ก CC(3M), SBU(1M)๊ฐ ์์ง๋ง, ์ถฉ๋ถํ๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค. ์ด๋ฅผ ๋ณด์ํ๊ธฐ ์ํด์, ๋ ผ๋ฌธ์์๋ web-site์์ LAIT(Large-scale weAk-supervised Image-Text)(10M)๋ฅผ ๋ง๋ค์์ต๋๋ค. LAIT๋ฅผ ๋ง๋๋ ๋ฐฉ๋ฒ์ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด 5๋จ๊ณ๋ก ์ด๋ฃจ์ด ์ง๋๋ค.
4 ImageBERT Model
4.1 Embedding Modelins
textual token์ ๋ง๋ค๊ธฐ ์ํด์, BERT์ ๋์ผํ WorlPiece[2]๋ฅผ ์ฌ์ฉ ํฉ๋๋ค.
visual token์ ์์ฑํ๊ธฐ ์ํด์, Faster-RCNN[3]๋ฅผ ์ฌ์ฉํ์ฌ RoI ์ถ์ถํฉ๋๋ค. ์ด๋ ์์ฑ๋๋ feature์ location ์ ๋ณด๋ฅผ ์ฌ์ฉํ์ฌ visual token์ embeddingํฉ๋๋ค. ์ด๋, ๊ฐ๊ฐ์ embedding layer๋ Transformer์ hidden size์ ๋์ผํ size๋ก ๊ฐ vector๋ฅผ projectํฉ๋๋ค.
$ c^{(i)} = (\frac{x_{tl}}{W},\frac{y_{tl}}{H},\frac{x_{br}}{W},\frac{y_{br}}{H},\frac{(x_{br}-x_{tl})(y_{br}-y_{tl})}{WH})\\ r^{(i)} = extracted\ features\ of\ the\ i_{th}\ RoI\\ v^{(i)} = ImageEmbed(r^{(i)})\\ s^{(i)} = SegmentEmbed(i)\\ p^{(i)}_{img} = PositionEmbed(c^{(i)})\\ p^{(i)}_{seq} = PositionEmbed(i)\\ e^{(i)} = LN(v^{(i)} + s^{(i)} + p^{(i)}_{img} + p^{(i)}_{seq}) $
sequence position embedding์ ๊ฒฝ์ฐ, visual token์๋ dummy vector๊ฐ ์ ์ฉ๋๋ฉด, textual token์๋ ์์์ ๋ฐ๋ผ vector๊ฐ ๋ฐฐ์น๋ฉ๋๋ค.
segment embedding์ ๊ฒฝ์ฐ, ์๋ก ๋ค๋ฅธ modality๋ฅผ ํํํ๋ ์ฉ๋๋ก ์ฌ์ฉ๋ฉ๋๋ค.
4.2 Multi-stage Pre-training
data-set์ ์ถ์ฒ๊ฐ ๋ค๋ฅด๊ณ , quality & noise distribution์ด ์๋ก ๋ค๋ฅด๊ธฐ ๋๋ฌธ์, ์๋์ ๊ฐ์ multi-stage pre-training์ ์ฌ์ฉํ์์ต๋๋ค. ํ์ต์ ์์๋ large-scale out-of-domain data-set์ผ๋ก ๋จผ์ ํ์ตํ๊ณ , ์ดํ์ ์ ์ฐจ small scale in-domain data-set์ผ๋ก ํ์ต์ ์งํํฉ๋๋ค.
4.3 Pre-training tasks
linguistic information๊ณผ visual content๋ฅผ ํ์ตํ๊ธฐ ์ํ์ฌ, 4๊ฐ์ง task๋ฅผ ์ฌ์ฉํ์์ต๋๋ค.
Masked Language Modeling (MLM): Bert์์ ์ฌ์ฉํ MLM์ ๋์ผํ task
masked randomly with a probability of 15%
replaced with a special token [MASK](80%), a random token(10%), remains unchanged(10%)
using the negative log-likelihood
$$ L_{MLM}(ฮธ) = โE_{(v,w)โผD} log Pฮธ(w_{mT}|w_{\mT}, v) $$
Masked Object Classification (MOC)
masked randomly with a probability of 15%
replaced with a zero out the masked token(90%), keep the original token(10%)
add a fully-connected layer to predict the correct label from K object classes
using the cross-entropy(CE) loss
Faster R-CNN model as ground truth label:
$l_{ฮธ}(v^{(i)}_{mI})$
the output vector corresponding to the masked token:
$f_{ฮธ}(v^{(i)}_{mI})$
$$ L_{MOC} (ฮธ) = โE_{(v,w)โผD}{\sum^{Mโ1}_{i=0}CE(l_{ฮธ}(v^{(i)}_{mI}), f_{ฮธ}(v^{(i)}_{mI}))} $$
Masked Region Feature Regression (MRFR)
This task aims to regress the embedding feature of each masked object
add a fully-connected layer on top of the output feature vector to make same dimension
using the L2 loss
Faster R-CNN model as ground truth feature:
$ r_{ฮธ}(v^{(i)}_{mI}) $
the output feature corresponding to the masked token:
$ h_{ฮธ}(v^{(i)}_{mI} $
$$ L_{MRFR} (ฮธ) = โE_{(v,w)โผD}{\sum^{Mโ1}_{i=0}\lVert h_{ฮธ}(v^{(i)}_{mI}) - r_{ฮธ}(v^{(i)}_{mI})} \rVert_2^2 $$
Image Text Matching (ITM)
This task aims to learn the image-text alignment
Negative training data
randomly sample negative sentences for each image
randomly sample negative images for each sentence
addy a fully-connected layer on top to obtain the image-text similarity score:
$s_ฮธ(v, w)$
using binary classification loss
the ground truth label:
$y โ {0, 1}$
$$ L_{ITM}(ฮธ) = โE_{(v,w)โผD}[y \log s_ฮธ(v, w) + (1 โ y) \log (1 โ s_ฮธ(v, w))] $$
4.4 Fine-tuning tasks
Fine-tuning์ MSCOCO and Flickr30k data-set์ผ๋ก ์งํํ์์ผ๋ฉฐ, input sequence๋ pre-training๊ณผ ๋์ผํฉ๋๋ค. Fine-tuning๊ณผ์ ์์๋ mask๋ฅผ ์ฌ์ฉํ task๋ค์ ์ฌ์ฉํ์ง ์๊ณ , ITM๋ง ์ฌ์ฉํ์์ต๋๋ค. 3๊ฐ์ง loss๋ฅผ ์ฌ์ฉํ์ฌ ์คํ์ ์งํํ์์ต๋๋ค.
Binary classification Loss
$ L_{BCE}(ฮธ) = โE_{(v,w)}[y \log c_ฮธ(t_{(v,w)}) + (1 โ y) \log (1 โ c_ฮธ(t_{(v,w)}))] $
Multi-class Classification Loss
$ L_{CE}(ฮธ) = โE^{(j)}_{(v,w)}{\sum^{P-1}_{j=0}}CE(s(t^{(j)}_{(v,w)}), l^{(j)}_{(v,w)}) $
Triplet Loss
$ L_{Triplet}(ฮธ) = โE^{(j)}_{(v,w)}{\sum_(n^-โN)} \max [ 0, s(t_{(v,w)^+} ), s(n^-_h)] $
5 Experiments
Transformer: a 12-layer with 768 hidden units, 3072 intermediate units, and 12 attention heads
Dropout probability to 0.1
Use GELU as activation function
The max length of our input sequence is fixed to 144, 100 visual tokens + other linguistic tokens and special tokens
Use a Faster RCNN model pre-trained on Visual Genome dataset with 1600 categories
Pre-training:
data-set
1-stage: Use the LAIT(10M), with parameter initialized from the BERT-base model
2-stage: Use pre-training on public datasets: CC(3M), SBU(1M)
hyperparamter
batch size = 48
learning rate = 1e-4 with Adam optimizer
17 epochs using 4 V100 GPUs
tasks
Use conditional mask in MLM, MOC and MRFR tasks
Only calculate the masked loss when the input pair is a positive sample
Fine-tuning
data-set
Use Flickr30k and MSCOCO
hyperparameter
batch size = 24
learning rate = 5e-5
130 epochs using 4 V100 GPUs
tasks
Only use ITM
5.1 Evaluation for the Pre-trained Model
Zero-shot result of pre-train model: 1-stage pre-training์ ์ฌ์ฉํ ๋ฐฉ๋ฒ๊ณผ ๋น๊ตํ์ฌ, comparable results๋ฅผ ๋ณด์ ๋๋ค. ์ถ๊ฐ์ ์ผ๋ก, fine-tuning์์๋ ๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ฃผ๊ธฐ ๋๋ฌธ์, multi-stage pre-training์ด single-stage pre-training ๋ณด๋ค usefulํ ์ง์์ ํ์ต ํ๋ค๊ณ ๋์ด ์์ต๋๋ค.
our multi-stage pre-training strategy learns more useful knowledge during pre-training, and can consequently contribute to the fine-tuning stage on the downstream tasks.
5.2 Evaluation for the Fine-tuned Model
Flickr30k์ MSCOCO์์ ๋ชจ๋ state-of-the-art๋ฅผ ๋ฌ์ฑํ์์ต๋๋ค. pre-training์์ quality & noise distribution๊ฐ ๋ค๋ฅธ data-set์ ์ฌ์ฉํ ๊ฒฝ์ฐ, data-set์ ๋๋์ด ํ์ตํ๋๊ฒ ํ๋๋ก ํ์ตํ๋ ๊ฒ๋ณด๋ค ์ข๋ค๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
5.3 Ablation Studies
Pre-train dataset: LAIT, CC, SBU๋ฅผ ์กฐํฉํ test๋ค์์, ์ ์ํ multi-stage pre-trining์ด ๊ฐ์ฅ ์ข์ ์ฑ๋ฅ์ ๋ณด์์ต๋๋ค.
Global image features: RoIs๊ฐ ์ ์ฒด image์ ์ ๋ณด๋ฅผ ๋ด์ง ๋ชป ํ ๋๋ฅผ ๋๋นํ์ฌ, Global image feature๋ฅผ ์ถ๊ฐํ๋ ๋ฐฉ๋ฒ์ ์คํํ์์ต๋๋ค. Table4.1์ ๊ฒฐ๊ณผ์ ๊ฐ์ด ์ ์๋ฏธํ ์ฑ๋ฅ ๊ฐ์ ์ ๋ณด์ด์ง ์์์ต๋๋ค.
Pre-train loss: MRFR task๋ฅผ ์ ์๋ฏธ์ฑ์ Table4.2๋ฅผ ํตํด ๋ณด์ฌ์ค๋๋ค. MRFR task๋ ๋ค๋ฅธ task๋ณด๋ค ์ด๋ ค์ด task์ ์ํ๋๋ฐ, ์ด๋ฅผ ํตํด์ ์ด๋ ค์ด task๋ฅผ ์ถ๊ฐํ๋ ๊ฒ์ด ๋ ์ข์ model์ ์ป๋๋ฐ ๋์์ด ๋๋ค๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
Number of objects (RoIs) from image: ์ง๊ธ๊น์ง ํ ์คํ๋ค์ ๋ชจ๋ 100๊ฐ์ RoI๋ฅผ ์ฌ์ฉํ์์ต๋๋ค. RoI์ ์๊ฐ ์ฑ๋ฅ์ ๋ฏธ์น๋ ์ํฅ์ Table4.3์ ํตํด ๋ณด์ฌ์ค๋๋ค. ๋ง์ objects๊ฐ ๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ด๋๋ฐ ๋์์ด ๋๋ ๊ฑธ ์ ์ ์์ต๋๋ค.
Fine-tune loss: Fine-tuning์ ์ฌ์ฉํ ์ฌ๋ฌ loss์ ๊ฒฐ๊ณผ๋ฅผ Table4.4๋ฅผ ํตํด ๋ณด์ฌ์ค๋๋ค. ๊ฐ์ธ์ ์ผ๋ก๋ Triplet loss๊ฐ ๊ฐ์ฅ ์ข์๊ฑฐ๋ผ๊ณ ์๊ฐํ์ต๋๋ค. ํ์ง๋ง ๊ฒฐ๊ณผ์ ์ผ๋ก Binary only๊ฐ ๊ฐ์ฅ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค.
6 Conclusion
์ด ๋ ผ๋ฌธ์ ํน์ง์ ์ ๋ฆฌํ๋ฉด, ์๋ 3๊ฐ์ง๋ก ์์ฝํ ์ ์์ต๋๋ค.
Transformer๋ฅผ ๊ธฐ๋ฐ, vision-language joint embedding architecture
๊ธฐ์กด data-set๋ณด๋ค ํฐ LAIT data-set
Multi-stage pre-training with 4 tasks(MLM, MOC, MRFR, ITM)
Reference
Last updated