
SKT FLY AI: Day 34

by hsloth 2023. 8. 10.

DataLoader

  • DataLoader shapes a Dataset into mini-batches so a batch-based deep-learning model can consume it directly during training.
  • Loading variable-length samples as a mini-batch with the default settings raises an error, because the default collate function cannot stack tensors of different lengths (see the padding sketch after the snippet below).
# raises an error: the default collate_fn cannot stack sequences of different lengths
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=2)
sample = next(iter(loader))
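
A standard fix is a custom collate_fn that pads each mini-batch to its longest sequence. The sketch below is a minimal illustration, assuming every dataset item is a (sequence, label) pair of tensors; pad_collate is a hypothetical name of my own, not the course's collate_sequence, and padding labels with -100 simply matches the default ignore_index of CrossEntropyLoss and the one passed to the accuracy metric later on.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch: list of (sequence, label) pairs with different lengths
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in sequences])
    x = pad_sequence(sequences, batch_first=True)                   # (B, max_len, ...)
    y = pad_sequence(labels, batch_first=True, padding_value=-100)  # pad labels with ignore_index
    return {'x': x, 'lengths': lengths, 'labels': y}

loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)
sample = next(iter(loader))  # no collate error now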

Practice


import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import lightning.pytorch as pl
import torchmetrics


class BaselineNetwork(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lstm = torch.nn.LSTM(input_size=66, hidden_size=256, num_layers=1, batch_first=True, bidirectional=True)
        self.fc = torch.nn.Linear(in_features=512, out_features=9)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=9, ignore_index=-100)
        self.lr = 0.001

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.9, patience=5, threshold=1e-4, min_lr=1e-5,
            cooldown=10)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "train_loss",
                "strict": False,
                "interval": "epoch",
                "frequency": 1
            }
        }

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        # 1. flatten the 4-D input (B x N x 33 x 4) into a 3-D input (B x N x 33*4)
        x = torch.flatten(x, 2)

        # 2. run lstm
        packed_input = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        padded_packed_output, _ = pad_packed_sequence(packed_output, batch_first=True)

        # 3. classification head
        output = self.fc(padded_packed_output)

        return output

    def training_step(self, batch, batch_idx):
        # batch : Dict {'x': torch.Tensor, 'lengths': torch.Tensor, 'labels': torch.Tensor}

        output = self(batch['x'], batch['lengths'])

        # calculate loss
        logits = output.transpose(1, 2)  # [B x N x C] -> [B x C x N]
        loss = self.criterion(logits, batch['labels'])

        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # batch : Dict {'x': torch.Tensor, 'lengths': torch.Tensor, 'labels': torch.Tensor}

        output = self(batch['x'], batch['lengths'])

        # calculate loss
        logits = output.transpose(1, 2)  # [B x N x C] -> [B x C x N]
        loss = self.criterion(logits, batch['labels'])

        # metric
        accuracy = self.accuracy(logits, batch['labels'])

        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_accuracy", accuracy, on_epoch=True, prog_bar=True, logger=True)
        return loss
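
As a quick sanity check, the module can be exercised with random data. Note that the flattened per-frame features must match the LSTM's input_size of 66, so this dummy batch assumes 33 joints with 2 values each; the shapes here are my own illustration, not the actual GolfDB tensors.

net = BaselineNetwork()
x = torch.randn(2, 10, 33, 2)    # B=2 clips, N=10 frames; flattens to 66 features per frame
lengths = torch.tensor([10, 7])  # true (unpadded) length of each clip
logits = net(x, lengths)
print(logits.shape)              # torch.Size([2, 10, 9]) -> per-frame scores over 9 classes
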
  • Define the DataLoaders used for training and validation.
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from utils.data import GolfDBPoseDataset, collate_sequence, SequenceBatchSampler
import utils.transform as T


train_dataset = GolfDBPoseDataset(
    split='train',
    transform=T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomRotation(),
        T.RandomPerspective(),
        # T.RandomNoise(),
        T.Normalize()])
)
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=collate_sequence,
    batch_sampler=SequenceBatchSampler(
        RandomSampler(train_dataset), batch_size=5000, drop_last=True)
)

val_dataset = GolfDBPoseDataset(
    split='val',
    transform=T.Normalize()
)
val_dataloader = DataLoader(
    val_dataset,
    collate_fn=collate_sequence,
    batch_sampler=SequenceBatchSampler(
        SequentialSampler(val_dataset), batch_size=5000, drop_last=False)
)
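
The utils.data helpers are course-provided and not shown here. Given batch_size=5000 with clips of very different lengths, a plausible reading is that SequenceBatchSampler batches by a total-frame budget rather than a fixed clip count. The class below is a hypothetical stand-in of my own (its signature even differs, taking the dataset as well), purely to illustrate the idea:

from torch.utils.data import Sampler

class FrameBudgetBatchSampler(Sampler):
    # Hypothetical sketch: group indices so the summed clip lengths per batch stay under a frame budget.
    def __init__(self, sampler, dataset, batch_size, drop_last):
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size  # budget in frames, e.g. 5000
        self.drop_last = drop_last

    def __iter__(self):
        batch, used = [], 0
        for idx in self.sampler:
            length = len(self.dataset[idx][0])  # frame count of this clip
            if batch and used + length > self.batch_size:
                yield batch
                batch, used = [], 0
            batch.append(idx)
            used += length
        if batch and not self.drop_last:
            yield batch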

model = BaselineNetwork()
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')

trainer = pl.Trainer(
    devices=1, check_val_every_n_epoch=5,
    max_epochs=300, callbacks=[lr_monitor])
trainer.fit(
    model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

To see how the dataset works, first inspect the raw pickle file:

import os
import pickle

train_data_path = os.path.join('data', 'train.pkl')
with open(train_data_path, 'rb') as file:
    train_data = pickle.load(file)

type(train_data)   # the pickle holds a dict

train_data.keys()  # at least 'poses' and 'events'

poses = train_data['poses']
events = train_data['events']
# total number of video clips
len(poses)  # == len(events)

x = poses[0]
event = events[0]

import numpy as np

clip_length = len(x)
label = np.zeros(clip_length, dtype=int)

for i in range(1, len(event) + 1):
    start = event[i - 1]  # segment start
    to = event[i] if i < len(event) else None  # segment end; the last segment runs from the final event to the end
    label[start:to] = i
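
A tiny worked example with made-up event frames shows what this produces: frames before the first event stay 0 (background), and each event opens the next phase.

event = [2, 5, 8]  # hypothetical event frames for a 10-frame clip
label = np.zeros(10, dtype=int)
for i in range(1, len(event) + 1):
    start = event[i - 1]
    to = event[i] if i < len(event) else None
    label[start:to] = i
print(label)  # [0 0 1 1 1 2 2 2 3 3]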

import os
import pickle
import numpy as np
from torch.utils.data import Dataset


class GolfDBPoseDataset(Dataset):
    def __init__(self, split: str = 'train'):
        # load the entire split from disk
        with open(os.path.join('data', f'{split}.pkl'), 'rb') as file:
            self.data = pickle.load(file)

    def __len__(self):
        # return the total number of clips in the dataset
        return len(self.data['poses'])

    def __getitem__(self, index):
        # access one clip by index and return (x, label)
        x = self.data['poses'][index]
        event = self.data['events'][index]

        label = np.zeros(len(x), dtype=int)
        for i in range(1, len(event) + 1):
            start = event[i - 1]
            to = event[i] if i < len(event) else None
            label[start:to] = i  # label the frames of this segment with phase i

        return x, label
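
A quick smoke test, assuming data/train.pkl exists as loaded above. Note that this rebuilt class omits the transform argument that the imported utils.data version accepts.

dataset = GolfDBPoseDataset(split='train')
x, label = dataset[0]
print(len(dataset))        # number of clips in the split
print(len(x), len(label))  # frames in clip 0 == labels for clip 0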