Bert_fine_tuning
๊ฐ์ธ study๋ฅผ ์ํ ์๋ฃ์ ๋๋ค. ๊ทธ๋ฌ๋ค๋ณด๋ ๋ด์ฉ์ ์๋ชป๋ ์ ์ด ์์ต๋๋ค.
1 Advantages of Fine-Tuning
pre-trained๋ model์ ์ฌ์ฉํ, Fine-Tuning์ ์๋์ ๊ฐ์ ์ฅ์ ์ด ์์ต๋๋ค.
Time
Less Data
Better Results
์ค์ ๊ตฌํ ์์ ๋ก, huggingface[1]์์ ์ ๊ณตํ๋ pre-trained model์ ๋ฐํ์ผ๋ก, fine-tuning์ ์งํํด๋ณด๋ ค๊ณ ํฉ๋๋ค. ๊ตฌํํ๊ณ ์ ํ๋ model์ ์๋์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ต๋๋ค.[2] ์ฝ๋๋ ์ฃผ๋ก [3]์ ์ฐธ์กฐํ์ต๋๋ค. fine-tuning์ ์ํ code๋ ํฌ๊ฒ ์๋์ ๊ฐ์ด ๋๋์ด์ ธ ์์ต๋๋ค.
bertFineTuningWithConnectionData.py : fine-tuning์ ์งํํฉ๋๋ค.
ConnectionBert.py : pre-trained ๋ model์ load ํฉ๋๋ค.
ConnectionDataset.py : fine-tuning์ ์ํ Dataset, DataLoader๊ฐ ์ ์ ๋์ด ์์ต๋๋ค.
2 ๊ตฌํ
1 bertFineTuningWithConnectionData
transformers
๋ฅผ ์ฌ์ฉํ์ฌ, pre-trained model๋ฅผ ๋ถ๋ฌ์ค๋ ๋ช
๋ น์ด๋ ๋งค์ฐ ๊ฐ๋จํฉ๋๋ค. ํ์ง๋ง, ํน๋ณํ ์ฌ์ ์ผ๋ก ๋ฏธ๋ฆฌ ๋ค์ด์ ๋ฐ์์ ์ฌ์ฉํด์ผ ํ ๊ฒฝ์ฐ, ์๋ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๋ฉด ๋ฉ๋๋ค.
huggingsface ์ ์ -> MODELS(์ฐ์ธก์๋จ) -> ์ํ๋ model ๊ฒ์ ->
List all files in model
ํด๋ฆญ ->config.json
,pythorch_model.bin
,vocab.txt
๋ฅผ ์ํ๋ directory์ ์ ์ฅ -> directory load
์ฌ๊ธฐ์๋ โtransformers\bert\bert-base-uncased"์ ์ 3 ํ์ผ์ ์ ์ฅํด ๋๊ณ ์ฌ์ฉํ์ต๋๋ค.
2 ConnectionBert
pre-trained ๋ data๋ฅผ loadํ์ฌ ์ฌ์ฉํ๋ ๊ฒ์ ๋งค์ฐ ๊ฐ๋จํฉ๋๋ค.
model์ ๋ํ ์์ธํ ์ค๋ช
์ DocsยปTransformers์์ ํ์ธ ํ ์ ์์ต๋๋ค. ๊ฐ์ฅ ๊ธฐ๋ณธ์ด ๋๋ BertModel์ ๊ฒฝ์ฐ, embedding layer + bertEncoder + pooled layer๋ก ๋์ด์์ต๋๋ค. ์์ธํ ๋ด๋ถ weight parameter๋ print(model)
๋ก ํ์ธ ํ ์ ์์ต๋๋ค.
BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(1~11): BertLayer()
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
3. ConnectionDataset
Dataset์ __len__()
๊ณผ __getitem__()
๋ง ๊ตฌํํด์ฃผ๋ฉด, ์ฝ๊ฒ ๊ตฌํํ ์ ์์ต๋๋ค.
Reference
[1] huggingface
Last updated
Was this helpful?