
[2022 Winter MoGakCo] 2023-01-02 (Mon) - Results

안녕진 2023. 1. 2. 18:09
  • Plan
Get familiar with deep learning frameworks
Build PyTorch and PyTorch Lightning models using the MNIST dataset
Build a basic training model using an MLP (a rough sketch follows right after this list)

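As a rough sketch of what the planned MLP model could look like in plain PyTorch (the class name MLPClassifier and the layer sizes are illustrative assumptions, not decisions made in this post):

import torch
from torch import nn

class MLPClassifier(nn.Module):  # hypothetical name; minimal MLP for 28x28 MNIST images
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),                 # (N, 28, 28) or (N, 1, 28, 28) -> (N, 784)
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)  # raw logits, meant to be paired with nn.CrossEntropyLoss

The same model could later be wrapped in a pytorch_lightning.LightningModule by moving the loss computation into training_step and the optimizer setup into configure_optimizers.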
 

  • Results
Connected to the lab server and set up the environment
Tried out Dataset and DataLoader
Split the data into train_dataset, validation_dataset, and test_dataset

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets

from typing import Callable, Optional, Tuple

# train_dataset = datasets.MNIST(root="data",
#                                train=True,
#                                download=True)
# test_dataset = datasets.MNIST(root="data",
#                               train=False,
#                               download=True)

class MNISTDataset(Dataset):  # practice: writing a custom Dataset
    def __init__(self,
                 train: bool = True,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None) -> None:
        super().__init__()
        # Download MNIST once and keep the raw tensors; transforms are applied per item in __getitem__
        self._dataset = datasets.MNIST(root="data", train=train, download=True)
        self.images = self._dataset.data
        self.targets = self._dataset.targets
        self.transform = transform
        self.target_transform = target_transform
    
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.images[index]
        y = self.targets[index]
        
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
        
        return x, y
    
    def __len__(self) -> int:
        return self.images.shape[0]


train_dataset = MNISTDataset(train=True)
test_dataset = MNISTDataset(train=False)
# print(len(train_dataset))  # __len__
# print(train_dataset[0])  # __getitem__

validation_set_size = int(len(train_dataset) * 0.2)
train_dataset, validation_dataset = random_split(train_dataset, [len(train_dataset)-validation_set_size, validation_set_size])

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=128, shuffle=False)  # evaluation loaders don't need shuffling
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

- Implemented code
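As a quick sanity check on the loaders above (a sketch continuing from the code above; to_float_tensor is a helper added here for illustration), a transform can be passed to the custom dataset and a single batch pulled to confirm the shapes:

def to_float_tensor(img: torch.Tensor) -> torch.Tensor:
    # scale uint8 pixels in [0, 255] to floats in [0, 1] and add a channel dimension
    return img.float().div(255.0).unsqueeze(0)

check_dataset = MNISTDataset(train=True, transform=to_float_tensor)
check_loader = DataLoader(check_dataset, batch_size=128, shuffle=True)

images, labels = next(iter(check_loader))
print(images.shape, images.dtype)  # torch.Size([128, 1, 28, 28]) torch.float32
print(labels.shape, labels.dtype)  # torch.Size([128]) torch.int64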