Gemma2-ko-9B 모델을 법률 QA 셋에 대해 Fine-tuning 진행 해 보았습니다.
허깅페이스의 trl, peft 라이브러리를 사용하여 LoRA, SFT Tuning을 진행하였습니다.
LoRA에 대한 개념을 모르시는 분은 아래 게시물을 참고해주세요🙌
LLM fine-tuning 방법에 대해 잘 모르거나, 처음 fine-tuning을 진행해보시는 분들이 참고하면 좋겠습니다.
혹여 게시물을 보고 궁금한점이 있다면 언제든 댓글 남겨주시길 바랍니다.😊
튜닝 모델은 아래에서 확인해보실 수 있습니다.
https://huggingface.co/architectyou/law-gemma-2-ko-9b-it
https://github.com/architectyou/gemma-2-ko-QA-Instruct.git
데이터셋 준비 & 태스크 선정
https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&dataSetSn=71723
AI hub에서 본 데이터(Law-QA-Dataset)를 다운받아 준비하였습니다.
저는 QA 셋에 대한 fine-tuning을 진행할 것이기 때문에 이에 맞는 데이터 전처리가 필요합니다.
사실 튜닝에 있어 이 부분이 가장 중요하다고 해도 과언이 아닐 것 같습니다.
데이터 전처리
def load_json_files(directory):
data = []
for filename in os.listdir(directory):
if filename.endswith('.json'):
with open(os.path.join(directory, filename), 'r', encoding='utf-8') as f:
data.append(json.load(f))
return data
def create_dataset(data):
dataset_dict = {
"id": [],
"question": [],
"answer": [],
"context": []
}
for item in data:
dataset_dict["id"].append(item["id"])
dataset_dict["question"].append(item["question"])
dataset_dict["answer"].append(item["answer"])
dataset_dict["context"].append(f"{item['title']}\n{item['commentary']}")
return Dataset.from_dict(dataset_dict)
# dataset load & preprocessing
data_directory = "./dataset/law_QA_dataset/"
all_data = load_json_files(data_directory)
train_data, val_data = train_test_split(all_data, test_size=0.1, random_state=42)
train_dataset = create_dataset(train_data)
val_dataset = create_dataset(val_data)
다음과 같이 다운 받은 데이터셋(json)의 형식에 맞춰
json 파일들을 로드 후, QA 태스크에 맞게 id, question, answer, context 별로 데이터를 분류해 재구성하였습니다.
저는 RAG를 기반으로 한 QA셋 태스크에 맞춘 학습이기 때문에 context 항목을 따로 추가 해 주었습니다.
필요 라이브러리 import & Base 모델 load
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
)
from datasets import Dataset
import os, torch, json, wandb, subprocess
from sklearn.model_selection import train_test_split
from peft import (
get_peft_model,
LoraConfig,
TaskType,
)
import torch.nn as nn
from trl import SFTTrainer
다음과 같이 학습에 필요한 라이브러리들을 import 해 줍니다.
base_model = "/data/gguf_models/ko-gemma-2-9b-it/"
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code = True)
model = AutoModelForCausalLM.from_pretrained(
base_model,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation='eager',
)
다음과 같이 학습하고자 하는 base 모델을 로드해주었습니다.
https://huggingface.co/rtzr/ko-gemma-2-9b-it
저는 위 모델을 base 모델로 사용하였습니다. (로컬에 모델을 다운받아 사용했기 때문에 경로가 다릅니다.)
LoRA (PEFT) Config 설정
제가 학습한 환경은 H100 80GB 입니다. A100에서도 테스트 해 봤는데 시간 차이가 상당하더군요.
LoRA 학습을 진행해주기 위해선 LoRA Config를 설정해주어야 합니다.
R값을 얼마로 지정할 것인지, dropout은 어떻게 설정할 것인지, 어떤 가중치들을 학습할 것인지 등이 중요합니다.
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
target_modules=modules,
)
# model, tokenizer = setup_chat_format(model, tokenizer)
# LoRA 모델 생성
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# 모델을 명시적으로 훈련 모드로 설정
model.train()
LoRA Config는 다음과 같이 설정하였습니다.
하드웨어가 넉넉하다고 생각해 r=18으로 설정해 두었는데, 기본은 8에서 하드웨어 특성에 따라 r값을 낮춰 학습을 진행할 수 있습니다.
def find_all_linear_names(model):
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if 'lm_head' in lora_module_names:
lora_module_names.remove('lm_head')
return list(lora_module_names)
modules = find_all_linear_names(model)
학습하고자 하는 target module은 다음과 같이 설정하였습니다.
target module을 보통 q_val, v_val, 등 attention output에 대한 값을 target module로 학습시키는데
명확히 어떤 값들을 훈련해야 할 지 몰라 lm_head 를 제외한 부분을 모두 학습시켜주기로 했습니다.
target module값을 변화시킴에 따라도 훈련결과가 많이 달라집니다.
model.print_trainable_parameters()
를 사용하면 훈련하는 parameter 개수를 확인할 수 있습니다.
Prompt
데이터셋 처리만큼 llm tuning에 있어 중요한 부분입니다.
어떤 프롬프트로 튜닝하느냐에 따라 task가 달라지기 때문입니다.
def generate_prompts(examples):
prompt_list=[]
for context, question, answer in zip(examples["context"], examples["question"], examples["answer"]):
prompt_list.append(
f"""<bos><start_of_turn>user
다음 문서를 참고하여 질문에 답변해주세요:
Context: {context}
Question: {question}
<end_of_turn>
<start_of_turn>model
{answer}<end_of_turn><eos>"""
)
return prompt_list
주의사항!
위에서 전처리한 데이터를 바탕으로 프롬프트를 다음과 같이 작성해주었습니다.
여기서 주의할 점은, gemma 모델에 쓰이는 chat prompt를 사용했다는 것인데,
모델별로 동작하는 chat prompt가 다릅니다. 따라서 base모델로 어떤 모델을 선정했는지에 따라 적합한 chat template을 적용해주어야 합니다.
https://huggingface.co/docs/transformers/v4.35.0/chat_templating
huggingface 라이브러리에는 Chat 모델별 다양한 탬플릿들이 소개되어 있습니다.
gemma 모델의 경우
user / model로 나뉘어져
<start_of_turn><end_of_turn> 사이에 대화 내용이 반복되는 것을 확인할 수 있습니다.
모델 Chat template을 맞춰주지 않으면
훈련 후의 모델 성능에 큰 타격이 생기기 때문에 꼭 어떤 모델을 base model로 하고 튜닝하는지 확인하고 맞춰주어야 합니다.
SFT(Supervised Fine-tuning) Training Config 설정
다음으론 SFT Config를 설정해줍니다.
# 훈련 인자 설정
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=1,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=8,
optim="paged_adamw_32bit",
eval_strategy="steps",
eval_steps=0.1,
logging_dir="./logs",
logging_steps=11,
warmup_steps=10,
logging_strategy="steps",
learning_rate=2e-4,
group_by_length=True,
bf16=True,
report_to="wandb",
run_name="gemma-2-9b-lora-bf16-0926",
)
# Trainer 초기화 및 훈련
trainer = SFTTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=val_dataset,
peft_config=peft_config,
max_seq_length = 512,
args=training_args,
formatting_func=generate_prompts,
)
여기서 주의할 점은 LLM은 기존의 AI 모델과 달리 Learning rate를 너무 낮게 설정해두면
학습이 거의 진행되지 않습니다.
기본적으로 2e-4 부터 설정 후 테스트 결과를 바탕으로 하이퍼 파라미터 값을 조정해보시길 바랍니다.
학습
model.config.use_cache = False
trainer.train()
위와 같이 훈련시켜주면 마무리 됩니다.
LoRA 어댑터 병합
LoRA를 이용해 PEFT를 진행해 주었기 때문에, 학습한 어댑터를
원본 모델에 병합시켜주어야 합니다.
model = PeftModel.from_pretrained(base_model, model_name)
merged_model = model.merge_and_unload()
위와 같이 작성하면 merged_model 은 학습한 LoRA 어댑터가 적용된 모델이 됩니다.
이 후, 모델에 대해 테스트를 해 보셔도 좋고, 로컬에 저장해도 좋고, huggingface에 업로드 해도 좋습니다.
결과
튜닝 결과 eval loss가 잘 떨어진 것을 확인할 수 있습니다.
Reference
https://devocean.sk.com/blog/techBoardDetail.do?ID=165703&boardType=techBlog
https://www.datacamp.com/tutorial/fine-tuning-gemma-2
'LLM > PlayGround' 카테고리의 다른 글
LangChain Arxiv 문서 불러오기 (0) | 2024.04.11 |
---|---|
LangChain, Chromadb 이용 문서 기반 RAG 구현하기 (0) | 2024.04.04 |
OpenAI API Key 발급 및 환경변수에 관리하기 + colab 에서 사용하기 (0) | 2024.03.06 |
[LangChain for LLM Application Development] 랭체인 Agent (0) | 2024.02.03 |
[LangChain for LLM Application Development] 랭체인 Evaluation (0) | 2024.02.03 |