
[2022 Winter Mogakco] 2023-02-05 (Sun) - Results

안녕진 2023. 2. 5. 16:30
  • Concept
Learn how to fine-tune a pretrained model using Hugging Face transformers and datasets (a minimal sketch of the core pieces is shown right below).
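A minimal sketch of those core pieces (my own toy example, not part of the final script): load the tokenizer and a classification head from the same checkpoint the script uses, tokenize one made-up headline, and run a forward pass with a label to get the loss that fine-tuning minimizes.

import torch
from transformers import AutoTokenizer, BertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("kykim/bert-kor-base")
model = BertForSequenceClassification.from_pretrained("kykim/bert-kor-base", num_labels=7)

# made-up news title and dummy label, just to show the shapes involved
inputs = tokenizer("올해 상반기 수출 증가세 둔화", return_tensors="pt")
outputs = model(**inputs, labels=torch.tensor([1]))

print(outputs.loss)          # scalar classification loss
print(outputs.logits.shape)  # torch.Size([1, 7]): one score per topic class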

 

  • Results
I trained a Topic Classification model using the ynat subset of the KLUE dataset.
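For reference, this is roughly what the data looks like when loaded with datasets (a quick inspection sketch; the "title" and "label" columns are the ones the training code consumes):

from datasets import load_dataset

ds = load_dataset("klue", "ynat")
print(ds)                                   # available splits (train / validation)
print(ds["train"][0])                       # one example: a news title plus an integer label
print(ds["train"].features["label"].names)  # names of the 7 topic classes

The full training script follows.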

import argparse

import numpy as np

import torch
import torchmetrics

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from transformers import AutoTokenizer, BertForSequenceClassification
from datasets import load_dataset


PRETRAINED_MODEL = "kykim/bert-kor-base"
DATASET = "klue"
SUBSET_NAME_FOR_TC = "ynat"


# Map-style dataset: tokenizes one news title per __getitem__ call.
class Dataset(torch.utils.data.Dataset):
    def __init__(self,
                 texts,
                 labels,
                 tokenizer,
                 max_seq_length,
                 **kwargs):
        super().__init__()
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
    
    
    def __len__(self):
        return len(self.texts)
    
    
    def __getitem__(self, index):
        text, label = self.texts[index], self.labels[index]
        inputs = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_seq_length, return_tensors="pt")

        input_ids = inputs["input_ids"]
        token_type_ids = inputs["token_type_ids"]
        attention_mask = inputs["attention_mask"]
        
        return input_ids.squeeze(0), token_type_ids.squeeze(0), attention_mask.squeeze(0), label


# LightningDataModule that downloads klue/ynat and builds the train/val/test datasets.
class Datamodule(pl.LightningDataModule):
    def __init__(self,
                 max_seq_length,
                 batch_size,
                 num_workers,
                 **kwargs) -> None:
        super().__init__()
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        
    
    def prepare_data(self):
        AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
        load_dataset(DATASET, SUBSET_NAME_FOR_TC)
    
    
    def setup(self, stage):
        self.tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
        all_dataset = load_dataset(DATASET, SUBSET_NAME_FOR_TC)
        
        train_test_data = all_dataset["train"].train_test_split(train_size=0.8, test_size=0.2, seed=42)  # fixed seed (arbitrary value) so the split is identical across stages and processes
        train_data = train_test_data["train"]
        test_data = train_test_data["test"]
        val_data = all_dataset["validation"]
        
        if stage == "fit":
            self.train_dataset = Dataset(train_data["title"], train_data["label"], self.tokenizer, self.max_seq_length)
            self.val_dataset = Dataset(val_data["title"], val_data["label"], self.tokenizer, self.max_seq_length)
        if stage == "test":
            self.test_dataset = Dataset(test_data["title"], test_data["label"], self.tokenizer, self.max_seq_length)
    
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=True)


    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers)


    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers)


# LightningModule wrapping BertForSequenceClassification with a 7-class head.
class TopicClassifier(pl.LightningModule):
    def __init__(self,
                 lr,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.model = BertForSequenceClassification.from_pretrained(PRETRAINED_MODEL, num_labels=7)
        self.lr = lr
    
    
    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        outputs = self.model(input_ids=input_ids,
                             token_type_ids=token_type_ids,
                             attention_mask=attention_mask,
                             labels=labels)
        return outputs.loss, outputs.logits
    
    
    def training_step(self, batch, batch_idx):
        loss, logits = self(*batch)
        self.log_dict({
            "train_loss": loss.item(),
        })
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        loss, logits = self(*batch)
        labels = batch[3].detach().cpu()
        predicts = torch.argmax(logits, dim=1).detach().cpu()
        
        return {
            "val_predicts": predicts,
            "val_labels": labels,
            "val_loss": loss.item(),
        }
        
    
    def validation_epoch_end(self, outputs):
        predicts = []
        labels = []
        loss = []
        for output in outputs:
            predicts.append(output["val_predicts"])
            labels.append(output["val_labels"])
            loss.append(output["val_loss"])
        
        self.log_dict({
            "val_loss": np.mean(loss),
            "val_f1": torchmetrics.functional.f1_score(torch.cat(predicts), torch.cat(labels), num_classes=7, task="multiclass"),
        })
    
    
    def test_step(self, batch, batch_idx):
        loss, logits = self(*batch)
        labels = batch[3].detach().cpu()
        predicts = torch.argmax(logits, dim=1).detach().cpu()
        
        return {
            "test_predicts": predicts,
            "test_labels": labels,
            "test_loss": loss.item(),
        }
        
    
    def test_epoch_end(self, outputs):
        predicts = []
        labels = []
        loss = []
        for output in outputs:
            predicts.append(output["test_predicts"])
            labels.append(output["test_labels"])
            loss.append(output["test_loss"])
        
        self.log_dict({
            "test_loss": np.mean(loss),
            "test_f1": torchmetrics.functional.f1_score(torch.cat(predicts), torch.cat(labels), num_classes=7, task="multiclass"),
        })
    
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.lr)
    
    
    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser):
        parser = parent_parser.add_argument_group("TopicClassifier")
        parser.add_argument("--lr", type=float, default=0.0001)
        return parent_parser


def main():
    ######
    # args
    ######
    parser = argparse.ArgumentParser()
    parser = TopicClassifier.add_model_specific_args(parser)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=16)
    parser.add_argument("--max_seq_length", type=int, default=64)
    args = vars(parser.parse_args())
    
    ######
    # trainer
    ######
    trainer = pl.Trainer(
        callbacks=[
            EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=5),
            ModelCheckpoint(monitor="val_loss", mode="min", dirpath="./checkpoints", filename="{val_f1:.2f}-{epoch}", save_top_k=3),
        ],
        accelerator="gpu",
        devices=4,
    )
    
    ######
    # datamodule
    ######
    datamodule = Datamodule(**args)

    ######
    # model
    ######
    model = TopicClassifier(**args)
    
    ######
    # train
    ######
    trainer.test(model, datamodule)  # score of the pretrained model before fine-tuning
    trainer.fit(model, datamodule)
    trainer.test(model, datamodule)  # score after fine-tuning


if __name__ == "__main__":
    main()


f1_score: 0.86
I wanted to shorten the training time, so I set the run up to use multiple processes (4 GPUs), but the logs ended up saved as several separate pieces.
I'll need to look into this part further and improve it.
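One idea to try next (an untested sketch, not a confirmed fix): with accelerator="gpu" and devices=4, Lightning launches one DDP process per GPU and each process computes its own metric values, which may be why the logs came out split. Passing sync_dist=True to log_dict should reduce the values across processes before they are logged, e.g. in validation_epoch_end:

    def validation_epoch_end(self, outputs):
        predicts = [output["val_predicts"] for output in outputs]
        labels = [output["val_labels"] for output in outputs]
        loss = [output["val_loss"] for output in outputs]

        self.log_dict({
            "val_loss": np.mean(loss),
            "val_f1": torchmetrics.functional.f1_score(torch.cat(predicts), torch.cat(labels), num_classes=7, task="multiclass"),
        }, sync_dist=True)  # all-reduce across the 4 processes before logging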