NLP
[pytorch] DataLoader
안녕진
2022. 7. 27. 02:18
개인 공부 목적의 포스팅입니다. 잘못된 정보가 있다면 댓글 남겨주시면 수정하도록 하겠습니다.
감사합니다.
💡 주제 설명
pytorch에서 정말 중요한 DataLoader에 관해 알아봤다.
📌 배경
pytorch 공식 tutorial만 한 번 읽어본 상황이라, 무턱대고 생각 없이 사용하던 DataLoader에 관해 제대로 알아보고 싶었다.
🔍 과정
- 기본 개념 [공식 문서]
pytorch가 데이터를 로딩하는 중심에는 torch.utils.data.DataLoader가 있는데,
데이터를 iterable 객체로 표현하는 역할을 한다.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
기본 역할
dataset과 sampler를 결합하여 주어진 dataset의 iterable을 제공한다.
찾아본 parameter들
- dataset
load 할 dataset.
두 종류의 dataset이 있다. (map-style, iterable-style)
- shuffle
매 epoch마다 data가 reshuffle 될지 결정한다.
shuffle이 설정되면, sampler를 설정할 수 없다.
source code를 확인해보면
이렇게 sampler를 설정하여 원하는 결과를 얻는 것을 확인할 수 있다.if shuffle: sampler = RandomSampler(dataset, generator=generator) else: sampler = SequentialSampler(dataset)
- sampler
map-style dataset에서 sample을 뽑아내는 방식을 정의한다.
index를 어떤 순서로 배치할지 결정한다고 생각하면 된다.
(random이면 index가 random, sequential이면 range(len(self.data_source)) 이렇게 순서대로)
__len__이 구현된 어떠한 Iterable도 가능하다.
sampler를 명시했다면, shuffle은 사용하면 안 된다. (shuffle의 default는 False)
- batch_sampler
sampler와 비슷하다.
근데 이건 batch의 indices를 반환한다.
batch_sampler를 사용하면, batch_size, shuffle, sampler, drop_last를 사용하면 안 된다.
batch_sampler 없이 batch_size를 명시해주면, 위와 같이 알아서 batch_size로 batch_sampler를 만들어 사용한다.if batch_size is not None and batch_sampler is None: # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last)
[소스 코드]
결국 sampler나 batch_sampler라는 것은, dataset에서 어떻게 뽑아올 거냐 결정하는 것이다.
매 epoch마다 reshuffle을 할지 결정하는 shuffle은, 그 자체로 dataset에서 sample을 어떻게 뽑아올지 결정해버리기 때문에, sampler, batch_sampler와 사용할 수 없다.
또 batch_sampler의 경우 그 자체로 batch 형태로 반환하겠다고 결정한 것이므로, batch에 관한 설정인 batch_size, drop_last도 사용할 수 없다.
- collate_fn
sample들을 실제 Tensor의 mini-batch로 표현해준다. [default collate 소스 코드]
default collate가 Tensor로 잘 변환해서 합쳐주지만, 문제가 생기는 경우 이용해봐야겠다.
- drop_last
dataset의 전체 크기가 batch_size로 나누어 떨어지지 않는 경우, (예를 들면, 300/64 => 4*64 + 44)
마지막의 44 크기의 불완전한 batch를 버릴 것인지 결정하는 것이다.