โฑ ํด๋น ๋ชจ๋ธ์์ LlaMA3.1์ Foundation ๋ชจ๋ธ๋ก ํ๋ ํ๊ตญ์ด ๋ฐ ํ๊ตญ์ ๋ค์ํ
๋ฌธํ์ ์ ์ฉํ ์ ์๋๋ก ํ๊ธฐ ์ํด
๊ฐ๋ฐ ๋์์ผ๋ฉฐ ์์ฒด ์ ์ํ 53๊ฐ ์์ญ์ ํ๊ตญ์ด ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ์ฌ ํ๊ตญ ์ฌํ ๊ฐ์น์
๋ฌธํ๋ฅผ ์ดํดํ๋ ๋ชจ๋ธ ์
๋๋ค. โ
DPO Train=100K
โถ ๋ชจ๋ธ ์ค๋ช
- ๋ชจ๋ธ๋ช
๋ฐ ์ฃผ์๊ธฐ๋ฅ:
ํด๋น ๋ชจ๋ธ์์ LlaMA3.1 ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก SFT ๋ฐฉ์์ผ๋ก ํ์ธํ๋๋ ๋ชจ๋ธ์
๋๋ค.
ํ๊ตญ์ด์ ํ๊ตญ์ ๋ค์ํ ๋ฌธํ์ ๋งฅ๋ฝ์ ์ดํดํ๋๋ก ์ค๊ณ๋์์ผ๋ฉฐ โจโจ, ์์ฒด ์ ์ํ 53๊ฐ ์์ญ์ ํ๊ตญ์ด
๋ฐ์ดํฐ๋ฅผ ํ์ฉํด ํ๊ตญ ์ฌํ์ ๊ฐ์น์ ๋ฌธํ๋ฅผ ๋ฐ์ํฉ๋๋ค.
์ฃผ์ ๊ธฐ๋ฅ์ผ๋ก๋ ํ
์คํธ ์์ฑ, ๋ํ ์ถ๋ก , ๋ฌธ์ ์์ฝ, ์ง์์๋ต, ๊ฐ์ ๋ถ์ ๋ฐ ์์ฐ์ด ์ฒ๋ฆฌ ๊ด๋ จ ๋ค์ํ ์์
์ ์ง์ํ๋ฉฐ,
ํ์ฉ ๋ถ์ผ๋ ๋ฒ๋ฅ , ์ฌ๋ฌด, ๊ณผํ, ๊ต์ก, ๋น์ฆ๋์ค, ๋ฌธํ ์ฐ๊ตฌ ๋ฑ ๋ค์ํ ๋ถ์ผ์์ ์์ฉ๋ ์ ์์ต๋๋ค.
- ๋ชจ๋ธ ์ํคํ
์ฒ:
ํด๋น ๋ชจ๋ธ์ LlaMA3.0 8B ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก, ํ๋ผ๋ฏธํฐ ์๋ 80์ต ๊ฐ(8B)๋ก ๊ตฌ์ฑ๋ ๊ณ ์ฑ๋ฅ ์ธ์ด ๋ชจ๋ธ์
๋๋ค.
์ด ๋ชจ๋ธ์LlaMA3.0 8B๋ฅผ ํ์ด๋ฐ์ด์
๋ชจ๋ธ๋ก ์ผ์, SFT(์ง๋ ๋ฏธ์ธ ์กฐ์ ) ๋ฐฉ์์ ํตํด ํ๊ตญ์ด์ ํ๊ตญ ๋ฌธํ์ ํนํ๋ ์ฑ๋ฅ์ ๋ฐํํ๋๋ก ํ๋ จ๋์์ต๋๋ค.
LlaMA3.0 8B์ ๊ฒฝ๋ํ๋ ๊ตฌ์กฐ๋ ๋น ๋ฅธ ์ถ๋ก ์๋์ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ๋ณด์ฅํ๋ฉฐ, ๋ค์ํ ์์ฐ์ด ์ฒ๋ฆฌ ์์
์ ์ ํฉํ๊ฒ ์ต์ ํ๋์ด ์์ต๋๋ค.
์ด ์ํคํ
์ฒ๋ ํ
์คํธ ์์ฑ, ์ง์์๋ต, ๋ฌธ์ ์์ฝ, ๊ฐ์ ๋ถ์๊ณผ ๊ฐ์ ๋ค์ํ ์์
์์ ํ์ํ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋๋ค.
โท ํ์ต ๋ฐ์ดํฐ
โธ ์ฌ์ฉ ์ฌ๋ก
ํด๋น ๋ชจ๋ธ์ ๋ค์ํ ์์ฉ ๋ถ์ผ์์ ์ฌ์ฉ๋ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด:
- ๊ต์ก ๋ถ์ผ: ์ญ์ฌ, ์ํ, ๊ณผํ ๋ฑ ๋ค์ํ ํ์ต ์๋ฃ์ ๋ํ ์ง์์๋ต ๋ฐ ์ค๋ช
์์ฑ.
- ๋น์ฆ๋์ค: ๋ฒ๋ฅ , ์ฌ๋ฌด, ์ธ๋ฌด ๊ด๋ จ ์ง์์ ๋ํ ๋ต๋ณ ์ ๊ณต ๋ฐ ๋ฌธ์ ์์ฝ.
- ์ฐ๊ตฌ ๋ฐ ๋ฌธํ: ํ๊ตญ ์ฌํ์ ๋ฌธํ์ ๋ง์ถ ์์ฐ์ด ์ฒ๋ฆฌ ์์
, ๊ฐ์ ๋ถ์, ๋ฌธ์ ์์ฑ ๋ฐ ๋ฒ์ญ.
- ๊ณ ๊ฐ ์๋น์ค: ์ฌ์ฉ์์์ ๋ํ ์์ฑ ๋ฐ ๋ง์ถคํ ์๋ต ์ ๊ณต.
- ์ด ๋ชจ๋ธ์ ๋ค์ํ ์์ฐ์ด ์ฒ๋ฆฌ ์์
์์ ๋์ ํ์ฉ๋๋ฅผ ๊ฐ์ง๋๋ค.
โน ํ๊ณ โโ
- ํด๋น ๋ชจ๋ธ์ ํ๊ตญ์ด์ ํ๊ตญ ๋ฌธํ์ ํนํ๋์ด ์์ผ๋,
ํน์ ์์ญ(์: ์ต์ ๊ตญ์ ์๋ฃ, ์ ๋ฌธ ๋ถ์ผ)์ ๋ฐ์ดํฐ ๋ถ์กฑ์ผ๋ก ์ธํด ๋ค๋ฅธ ์ธ์ด ๋๋
๋ฌธํ์ ๋ํ ์๋ต์ ์ ํ์ฑ์ด ๋จ์ด์ง ์ ์์ต๋๋ค.
๋ํ, ๋ณต์กํ ๋
ผ๋ฆฌ์ ์ฌ๊ณ ๋ฅผ ์๊ตฌํ๋ ๋ฌธ์ ์ ๋ํด ์ ํ๋ ์ถ๋ก ๋ฅ๋ ฅ์ ๋ณด์ผ ์ ์์ผ๋ฉฐ,
ํธํฅ๋ ๋ฐ์ดํฐ๊ฐ ํฌํจ๋ ๊ฒฝ์ฐ ํธํฅ๋ ์๋ต์ด ์์ฑ๋ ๊ฐ๋ฅ์ฑ๋ ์กด์ฌํฉ๋๋ค.
โบ ์ฌ์ฉ ๋ฐฉ๋ฒ
from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("AIDX-ktds/ktdsbaseLM-v0.12-onbased-llama3.1")
model = AutoModel.from_pretrained("AIDX-ktds/ktdsbaseLM-v0.12-onbased-llama3.1")
input_text = """ ใ๊ตญ๋ฏผ๊ฑด๊ฐ๋ณดํ๋ฒใ์ 44์กฐ, ใ๊ตญ๋ฏผ๊ฑด๊ฐ๋ณดํ๋ฒ ์ํ๋ นใ์ 19์กฐ,ใ์ฝ๊ด์ ๊ท์ ์ ๊ดํ ๋ฒ๋ฅ ใ์ 5์กฐ, ใ์๋ฒใ์ 54์กฐ ์ฐธ์กฐ ํ๋จ ํด์ค""" + " ๋ต๋ณ:"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
outputs = model.generate(**inputs, max_length=1024, temperature=0.5, do_sample=True, repetition_penalty=1.15)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)