ImageBERT

This post reviews the paper ImageBERT: Cross-modal Pre-training with Large-scale Weak-supervised Image-Text Data. The review follows the same order as the paper. Key words:

  • Large Data Set

  • One Transformer

  • Multi-stage Pre-training + 4 tasks

1 Introduction

Interest is growing in tasks that handle both images and text, such as Text-Image Retrieval, Visual Question Answering (VQA), and Visual Commonsense Reasoning (VCR). There have been efforts to apply the pre-training approach, so successful in NLP, to language + vision cross-modal tasks as well. This paper compares these cross-modal pre-training methods and introduces the authors' own approach.

Since the Transformer was proposed, many papers have tried to solve cross-modal problems with it. The related work is analyzed around these approaches.

  • Model Architecture

    • Apply separate transformers to language and vision, then add a cross-modal transformer on top: ViLBERT, LXMERT, etc.

    • Concatenate the image and the sentence into one sequence and process it with a single transformer: VisualBERT, B2T2, Unicoder-VL, VL-BERT, Unified VLP, UNITER, etc.

    • Neither approach can be called strictly better, but this paper uses a single transformer.

  • Image visual tokens

    • ๋งŽ์€ ๋ฐฉ๋ฒ•์—์„œ regions of interest(RoI)๋ฅผ word์˜ token์ฒ˜๋Ÿผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

    • VL-BERT์—์„œ๋Š” RoI๋ฅผ ๊ตฌํ•˜๋Š” detection model๊นŒ์ง€ ๊ฐ™์ด ํ•™์Šต์„ ํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๋ณดํ†ต ํ•™์Šต๋œ model์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ global image feature๋„ token์œผ๋กœ ์ถ”๊ฐ€ํ•˜์˜€์Šต๋‹ˆ๋‹ค.

  • Pre-train data

    • The commonly used datasets are Conceptual Captions (3M) and SBU Captions (1M), but compared to the massive corpora used in NLP, it is hard to obtain a satisfactory dataset.

3 Large-Scale Weak-supervised Image-Text Data Collection

NLP์˜ ๊ฒฝ์šฐ, wiki์™€ ๊ฐ™์ด ๋งค์šฐ ๋ฐฉ๋Œ€ํ•œ data-set(wiki, book, โ€ฆ)์„ ์–ป์„ ์ˆ˜ ์žˆ์ง€๋งŒ, cross-modal์—์„œ๋Š” ์ด๋Ÿฌํ•œ ๋ฐฉ๋Œ€ํ•œ data-set์„ ์–ป๊ธฐ ์–ด๋ ต์Šต๋‹ˆ๋‹ค. ๋งŽ์ด ์‚ฌ์šฉํ•˜๋Š” data-set์œผ๋กœ CC(3M), SBU(1M)๊ฐ€ ์žˆ์ง€๋งŒ, ์ถฉ๋ถ„ํ•˜๋‹ค๊ณ  ๋ณผ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ์ด๋ฅผ ๋ณด์™„ํ•˜๊ธฐ ์œ„ํ•ด์„œ, ๋…ผ๋ฌธ์—์„œ๋Š” web-site์—์„œ LAIT(Large-scale weAk-supervised Image-Text)(10M)๋ฅผ ๋งŒ๋“ค์—ˆ์Šต๋‹ˆ๋‹ค. LAIT๋ฅผ ๋งŒ๋“œ๋Š” ๋ฐฉ๋ฒ•์€ ์•„๋ž˜ ๊ทธ๋ฆผ๊ณผ ๊ฐ™์ด 5๋‹จ๊ณ„๋กœ ์ด๋ฃจ์–ด ์ง‘๋‹ˆ๋‹ค.

4 ImageBERT Model

4.1 Embedding Modeling

  • Textual tokens are produced with WordPiece [2], exactly as in BERT.

  • Visual tokens are generated by extracting RoIs with Faster R-CNN [3]. The features and location information of each RoI are used to embed the visual token, and each embedding layer projects its vector to the Transformer's hidden size (a PyTorch sketch follows this list). $ c^{(i)} = (\frac{x_{tl}}{W},\frac{y_{tl}}{H},\frac{x_{br}}{W},\frac{y_{br}}{H},\frac{(x_{br}-x_{tl})(y_{br}-y_{tl})}{WH})\\ r^{(i)} = \text{extracted features of the } i\text{-th RoI}\\ v^{(i)} = \text{ImageEmbed}(r^{(i)})\\ s^{(i)} = \text{SegmentEmbed}(i)\\ p^{(i)}_{img} = \text{PositionEmbed}(c^{(i)})\\ p^{(i)}_{seq} = \text{PositionEmbed}(i)\\ e^{(i)} = \text{LN}(v^{(i)} + s^{(i)} + p^{(i)}_{img} + p^{(i)}_{seq}) $

  • For the sequence position embedding, every visual token receives the same dummy position vector (RoIs have no natural order), while textual tokens receive position vectors according to their order.

  • The segment embedding is used to distinguish the two modalities.
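To make the embedding concrete, below is a minimal PyTorch sketch of the visual-token branch. It assumes a 2048-d Faster R-CNN feature per RoI and segment id 1 for the image modality; the class and parameter names are mine, not the paper's.

```python
import torch
import torch.nn as nn

class VisualTokenEmbedding(nn.Module):
    """Computes e^(i) = LN(v + s + p_img + p_seq) for each RoI (Section 4.1)."""

    def __init__(self, roi_dim=2048, hidden=768, max_rois=100, num_segments=2):
        super().__init__()
        self.image_embed = nn.Linear(roi_dim, hidden)        # v: project RoI feature
        self.pos_embed_img = nn.Linear(5, hidden)            # p_img: project 5-d box vector c
        self.pos_embed_seq = nn.Embedding(max_rois, hidden)  # p_seq: sequence position
        self.segment_embed = nn.Embedding(num_segments, hidden)  # s: modality id
        self.layer_norm = nn.LayerNorm(hidden)

    def forward(self, roi_feats, boxes, img_w, img_h):
        # roi_feats: (M, roi_dim) detector features; boxes: (M, 4) = (x_tl, y_tl, x_br, y_br)
        x_tl, y_tl, x_br, y_br = boxes.unbind(dim=-1)
        area = (x_br - x_tl) * (y_br - y_tl) / (img_w * img_h)
        c = torch.stack([x_tl / img_w, y_tl / img_h,
                         x_br / img_w, y_br / img_h, area], dim=-1)
        idx = torch.arange(roi_feats.size(0))
        seg = torch.ones_like(idx)  # segment id 1 = image modality (assumption)
        e = (self.image_embed(roi_feats) + self.segment_embed(seg)
             + self.pos_embed_img(c) + self.pos_embed_seq(idx))
        return self.layer_norm(e)

# 100 RoIs from a 224x224 image -> (100, 768) visual token embeddings
emb = VisualTokenEmbedding()
print(emb(torch.randn(100, 2048), torch.rand(100, 4) * 224, 224.0, 224.0).shape)
```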

4.2 Multi-stage Pre-training

data-set์˜ ์ถœ์ฒ˜๊ฐ€ ๋‹ค๋ฅด๊ณ , quality & noise distribution์ด ์„œ๋กœ ๋‹ค๋ฅด๊ธฐ ๋•Œ๋ฌธ์—, ์•„๋ž˜์™€ ๊ฐ™์€ multi-stage pre-training์„ ์‚ฌ์šฉํ•˜์˜€์Šต๋‹ˆ๋‹ค. ํ•™์Šต์˜ ์ˆœ์„œ๋Š” large-scale out-of-domain data-set์œผ๋กœ ๋จผ์ € ํ•™์Šตํ•˜๊ณ , ์ดํ›„์— ์ ์ฐจ small scale in-domain data-set์œผ๋กœ ํ•™์Šต์„ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.

4.3 Pre-training tasks

To learn both linguistic information and visual content, four pre-training tasks are used (a combined sketch follows the list below).

  • Masked Language Modeling (MLM): the same task as the MLM used in BERT

    • masked randomly with a probability of 15%

    • replaced with the special token [MASK] (80%) or a random token (10%), or left unchanged (10%)

    • using the negative log-likelihood $$ L_{MLM}(\theta) = -E_{(v,w)\sim D} \log P_\theta(w_{m_T} \mid w_{\setminus m_T}, v) $$

  • Masked Object Classification (MOC)

    • masked randomly with a probability of 15%

    • the masked token is zeroed out (90%) or kept unchanged (10%)

    • add a fully-connected layer to predict the correct label from K object classes

    • using the cross-entropy(CE) loss

      • Faster R-CNN model as ground truth label: $l_{ฮธ}(v^{(i)}_{mI})$

      • the output vector corresponding to the masked token: $f_{\theta}(v^{(i)}_{mI})$ $$ L_{MOC}(\theta) = E_{(v,w)\sim D}\sum^{M-1}_{i=0}CE(l_{\theta}(v^{(i)}_{mI}), f_{\theta}(v^{(i)}_{mI})) $$

  • Masked Region Feature Regression (MRFR)

    • This task aims to regress the embedding feature of each masked object

    • add a fully-connected layer on top of the output feature vector to project it to the same dimension as the ground-truth RoI feature

    • using the L2 loss

      • Faster R-CNN model as ground truth feature: $ r_{ฮธ}(v^{(i)}_{mI}) $

      • the output feature corresponding to the masked token: $h_{\theta}(v^{(i)}_{mI})$ $$ L_{MRFR}(\theta) = E_{(v,w)\sim D}\sum^{M-1}_{i=0}\lVert h_{\theta}(v^{(i)}_{mI}) - r_{\theta}(v^{(i)}_{mI}) \rVert_2^2 $$

  • Image Text Matching (ITM)

    • This task aims to learn the image-text alignment

    • Negative training data

      • randomly sample negative sentences for each image

      • randomly sample negative images for each sentence

    • add a fully-connected layer on top to obtain the image-text similarity score: $s_\theta(v, w)$

    • using binary classification loss

      • the ground truth label: $y \in \{0, 1\}$ $$ L_{ITM}(\theta) = -E_{(v,w)\sim D}[y \log s_\theta(v, w) + (1 - y) \log (1 - s_\theta(v, w))] $$
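Below is a minimal sketch of the four pre-training losses in PyTorch, assuming logits and feature tensors produced by the Transformer heads. The masking helper implements the 15% / 80-10-10 scheme above, and all function names are mine.

```python
import torch
import torch.nn.functional as F

def mask_text_tokens(token_ids, mask_id, vocab_size, p=0.15):
    """BERT-style corruption for MLM: select each token with prob p, then
    replace it with [MASK] (80%), a random token (10%), or keep it (10%)."""
    labels = token_ids.clone()
    selected = torch.rand(token_ids.shape) < p
    labels[~selected] = -100                      # ignore_index for cross_entropy
    roll = torch.rand(token_ids.shape)
    corrupted = token_ids.clone()
    corrupted[selected & (roll < 0.8)] = mask_id
    use_rand = selected & (roll >= 0.8) & (roll < 0.9)
    corrupted[use_rand] = torch.randint(vocab_size, token_ids.shape)[use_rand]
    return corrupted, labels                      # remaining 10% stay unchanged

def mlm_loss(logits, labels):                     # L_MLM: negative log-likelihood
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1),
                           ignore_index=-100)

def moc_loss(pred_logits, detector_labels):       # L_MOC: CE vs. Faster R-CNN labels
    return F.cross_entropy(pred_logits, detector_labels)

def mrfr_loss(pred_feats, detector_feats):        # L_MRFR: L2 to the RoI feature
    return F.mse_loss(pred_feats, detector_feats)

def itm_loss(scores, match_labels):               # L_ITM: binary matching loss
    return F.binary_cross_entropy_with_logits(scores, match_labels.float())
```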

4.4 Fine-tuning tasks

Fine-tuning์€ MSCOCO and Flickr30k data-set์œผ๋กœ ์ง„ํ–‰ํ•˜์˜€์œผ๋ฉฐ, input sequence๋Š” pre-training๊ณผ ๋™์ผํ•ฉ๋‹ˆ๋‹ค. Fine-tuning๊ณผ์ •์—์„œ๋Š” mask๋ฅผ ์‚ฌ์šฉํ•œ task๋“ค์€ ์‚ฌ์šฉํ•˜์ง€ ์•Š๊ณ , ITM๋งŒ ์‚ฌ์šฉํ•˜์˜€์Šต๋‹ˆ๋‹ค. 3๊ฐ€์ง€ loss๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์‹คํ—˜์„ ์ง„ํ–‰ํ•˜์˜€์Šต๋‹ˆ๋‹ค.

  • Binary Classification Loss $ L_{BCE}(\theta) = -E_{(v,w)\sim D}[y \log c_\theta(t_{(v,w)}) + (1 - y) \log (1 - c_\theta(t_{(v,w)}))] $

  • Multi-class Classification Loss (the positive pair is scored against $P-1$ negative pairs) $ L_{CE}(\theta) = E_{(v,w)\sim D}\sum^{P-1}_{j=0}CE(s(t^{(j)}_{(v,w)}), l^{(j)}_{(v,w)}) $

  • Triplet Loss (a margin $m$ hinge against the hardest negative $n^-_h$; sketched below) $ L_{Triplet}(\theta) = E_{(v,w)\sim D}\sum_{n^- \in N}\max[0,\ m - s(t^+_{(v,w)}) + s(n^-_h)] $
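For the triplet loss, here is a sketch assuming the standard margin hinge over the hardest negative; the margin value 0.2 is arbitrary, not from the paper.

```python
import torch

def triplet_loss(pos_score, neg_scores, margin=0.2):
    """Hinge on the hardest negative: max(0, m - s(positive) + s(hardest negative))."""
    hardest = neg_scores.max()                    # n_h^-: highest-scoring negative
    return torch.clamp(margin - pos_score + hardest, min=0.0)

# Example: positive pair scored 0.9, all negatives lower by > margin -> zero loss.
print(triplet_loss(torch.tensor(0.9), torch.tensor([0.1, 0.4, 0.6])))
```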

5 Experiments

  • Transformer: 12 layers with 768 hidden units, 3072 intermediate units, and 12 attention heads

  • Dropout probability of 0.1

  • Use GELU as the activation function

  • The max input sequence length is fixed to 144: 100 visual tokens plus the linguistic and special tokens

  • Use a Faster R-CNN model pre-trained on the Visual Genome dataset with 1600 categories

  • Pre-training:

    • data-set

      • Stage 1: use LAIT (10M), with parameters initialized from the BERT-base model

      • Stage 2: continue pre-training on the public datasets CC (3M) and SBU (1M)

    • hyperparameters

      • batch size = 48

      • learning rate = 1e-4 with Adam optimizer

      • 17 epochs using 4 V100 GPUs

    • tasks

      • Use a conditional mask in the MLM, MOC, and MRFR tasks

      • Only calculate the masked loss when the input pair is a positive sample

  • Fine-tuning

    • data-set

      • Use Flickr30k and MSCOCO

    • hyperparameter

      • batch size = 24

      • learning rate = 5e-5

      • 130 epochs using 4 V100 GPUs

    • tasks

      • Only use ITM
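For reference, the hyperparameters listed above collected into one config; the dataclass itself is merely my convenience wrapper, not something from the paper.

```python
from dataclasses import dataclass

@dataclass
class ImageBertConfig:
    # Transformer (BERT-base shape)
    num_layers: int = 12
    hidden_size: int = 768
    intermediate_size: int = 3072
    num_heads: int = 12
    dropout: float = 0.1
    activation: str = "gelu"
    # Input: 100 visual tokens + linguistic/special tokens, 144 total
    max_seq_len: int = 144
    num_rois: int = 100
    # Pre-training (Adam, 17 epochs on 4 V100 GPUs)
    pretrain_batch_size: int = 48
    pretrain_lr: float = 1e-4
    pretrain_epochs: int = 17
    # Fine-tuning (ITM only, 130 epochs on 4 V100 GPUs)
    finetune_batch_size: int = 24
    finetune_lr: float = 5e-5
    finetune_epochs: int = 130
```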

5.1 Evaluation for the Pre-trained Model

Zero-shot results of the pre-trained model: the model achieves results comparable to methods that use single-stage pre-training. Moreover, since it performs better after fine-tuning, the paper argues that multi-stage pre-training learns more useful knowledge than single-stage pre-training:

> our multi-stage pre-training strategy learns more useful knowledge during pre-training, and can consequently contribute to the fine-tuning stage on the downstream tasks.

5.2 Evaluation for the Fine-tuned Model

The fine-tuned model achieves state-of-the-art results on both Flickr30k and MSCOCO. This shows that when pre-training data differs in quality and noise distribution, training on the datasets in separate stages works better than training on them all at once.

5.3 Ablation Studies

Pre-train dataset: among the tested combinations of LAIT, CC, and SBU, the proposed multi-stage pre-training performs best.

Global image features: since RoIs may fail to cover the information of the whole image, the authors experimented with adding a global image feature as a token. As Table 4.1 shows, it did not yield a meaningful improvement.

Pre-train loss: Table 4.2 shows the significance of the MRFR task. MRFR is a harder task than the others, which suggests that adding a harder task helps obtain a better model.

Number of objects (RoIs) from image: all experiments so far used 100 RoIs. Table 4.3 shows how the number of RoIs affects performance: more objects help produce better results.

Fine-tune loss: Table 4.4 compares the losses used for fine-tuning. I personally expected the triplet loss to work best, but in the end the binary loss alone gives the best results.

6 Conclusion

์ด ๋…ผ๋ฌธ์˜ ํŠน์ง•์„ ์ •๋ฆฌํ•˜๋ฉด, ์•„๋ž˜ 3๊ฐ€์ง€๋กœ ์š”์•ฝํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  1. A Transformer-based vision-language joint embedding architecture

  2. The LAIT dataset, larger than existing datasets

  3. Multi-stage pre-training with four tasks (MLM, MOC, MRFR, ITM)

Reference

[2] Y. Wu et al., "Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation", 2016 (WordPiece).

[3] S. Ren, K. He, R. Girshick, and J. Sun, "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks", NeurIPS 2015.
