Bert_fine_tuning

๊ฐœ์ธ study๋ฅผ ์œ„ํ•œ ์ž๋ฃŒ์ž…๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‹ค๋ณด๋‹ˆ ๋‚ด์šฉ์— ์ž˜๋ชป๋œ ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

1 Advantages of Fine-Tuning

pre-trained๋œ model์„ ์‚ฌ์šฉํ•œ, Fine-Tuning์€ ์•„๋ž˜์™€ ๊ฐ™์€ ์žฅ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

  1. Time

  2. Less Data

  3. Better Results

์‹ค์ œ ๊ตฌํ˜„ ์˜ˆ์ œ๋กœ, huggingface[1]์—์„œ ์ œ๊ณตํ•˜๋Š” pre-trained model์„ ๋ฐ”ํƒ•์œผ๋กœ, fine-tuning์„ ์ง„ํ–‰ํ•ด๋ณด๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค. ๊ตฌํ˜„ํ•˜๊ณ ์ž ํ•˜๋Š” model์€ ์•„๋ž˜์™€ ๊ทธ๋ฆผ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.[2] SBERT architecture with consine-smiliarity ์ฝ”๋“œ๋Š” ์ฃผ๋กœ [3]์„ ์ฐธ์กฐํ–ˆ์Šต๋‹ˆ๋‹ค. fine-tuning์„ ์œ„ํ•œ code๋Š” ํฌ๊ฒŒ ์•„๋ž˜์™€ ๊ฐ™์ด ๋‚˜๋ˆ„์–ด์ ธ ์žˆ์Šต๋‹ˆ๋‹ค.

  1. bertFineTuningWithConnectionData.py : runs the fine-tuning.

  2. ConnectionBert.py : loads the pre-trained model.

  3. ConnectionDataset.py : defines the Dataset and DataLoader used for fine-tuning.
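The SBERT-with-cosine-similarity architecture from [2] can be sketched as below. This is a minimal illustration, not the actual project code: the token embeddings are assumed to come from a BERT forward pass, and mean pooling over non-padding tokens is used to get one vector per sentence, as in SBERT.

```python
import torch
import torch.nn.functional as F

def mean_pool(token_embeddings, attention_mask):
    # Average the token embeddings, ignoring padding positions.
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(dim=1)    # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)         # (batch, 1)
    return summed / counts

def sbert_cosine_score(emb_a, mask_a, emb_b, mask_b):
    # Pool each sentence's token embeddings, then compare the
    # two sentence vectors with cosine similarity.
    u = mean_pool(emb_a, mask_a)
    v = mean_pool(emb_b, mask_b)
    return F.cosine_similarity(u, v, dim=-1)         # (batch,)
```

During fine-tuning, this score is typically regressed against a gold similarity label with an MSE loss.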

2 ๊ตฌํ˜„

1 bertFineTuningWithConnectionData

transformers๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ, pre-trained model๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋Š” ๋ช…๋ น์–ด๋Š” ๋งค์šฐ ๊ฐ„๋‹จํ•ฉ๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ, ํŠน๋ณ„ํ•œ ์‚ฌ์ •์œผ๋กœ ๋ฏธ๋ฆฌ ๋‹ค์šด์„ ๋ฐ›์•„์„œ ์‚ฌ์šฉํ•ด์•ผ ํ•  ๊ฒฝ์šฐ, ์•„๋ž˜ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

huggingsface ์ ‘์† -> MODELS(์šฐ์ธก์ƒ๋‹จ) -> ์›ํ•˜๋Š” model ๊ฒ€์ƒ‰ -> List all files in model ํด๋ฆญ -> config.json, pythorch_model.bin, vocab.txt๋ฅผ ์›ํ•˜๋Š” directory์— ์ €์žฅ -> directory load

์—ฌ๊ธฐ์„œ๋Š” โ€œtransformers\bert\bert-base-uncased"์— ์œ„ 3 ํŒŒ์ผ์„ ์ €์žฅํ•ด ๋†“๊ณ  ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค.

2 ConnectionBert

pre-trained ๋œ data๋ฅผ loadํ•˜์—ฌ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์€ ๋งค์šฐ ๊ฐ„๋‹จํ•ฉ๋‹ˆ๋‹ค.

model์— ๋Œ€ํ•œ ์ž์„ธํ•œ ์„ค๋ช…์€ DocsยปTransformers์—์„œ ํ™•์ธ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ€์žฅ ๊ธฐ๋ณธ์ด ๋˜๋Š” BertModel์˜ ๊ฒฝ์šฐ, embedding layer + bertEncoder + pooled layer๋กœ ๋˜์–ด์žˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด๋ถ€ weight parameter๋Š” print(model)๋กœ ํ™•์ธ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1~11): BertLayer()
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)      

3 ConnectionDataset

Dataset์€ __len__()๊ณผ __getitem__()๋งŒ ๊ตฌํ˜„ํ•ด์ฃผ๋ฉด, ์‰ฝ๊ฒŒ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Reference
