반응형

CNN-LSTM은 CNN으로 공간적 정보를 추출하고 LSTM으로 시계열 정보를 추출하여 동영상을 입력하는 테스크(행동인식 등)에서 좋은 성능을 얻을 수 있습니다.

그러나 I3D가 나오면서 CNN-LSTM모델의 단점을 거론하였는데요 바로 CNN에서 추출된 특징이 하이레벨 정보만 포함하고 있고 로우레벨은 LSTM으로 전달이 되지않아 더 높은 정확도를 얻을 수 없다는 점입니다.

우리는 I3D는 제쳐두고 CNN-LSTM 모델을 파이토치로 구현하는 방법을 알아봅시다.

 

import torch
import torch.nn as nn
from torch.autograd import Variable
from MobileNetV2 import MobileNetV2
from efficientnet_pytorch import EfficientNet

class EventDetector(nn.Module):
    def __init__(self, lstm_layers, lstm_hidden, bidirectional=True, dropout=True):
        super(EventDetector, self).__init__()
        self.lstm_layers = lstm_layers
        self.lstm_hidden = lstm_hidden
        self.bidirectional = bidirectional
        self.dropout = dropout

        net = EfficientNet.from_pretrained('efficientnet-b0',include_top=False )        
        self.cnn = net        
        self.rnn = nn.LSTM(int(1280),
                           self.lstm_hidden, self.lstm_layers,
                           batch_first=True, bidirectional=bidirectional)
        if self.bidirectional:
            self.lin = nn.Linear(2*self.lstm_hidden, 9)
        else:
            self.lin = nn.Linear(self.lstm_hidden, 9)
        if self.dropout:
            self.drop = nn.Dropout(0.5)

    def init_hidden(self, batch_size):
        if self.bidirectional:
            return (Variable(torch.zeros(2*self.lstm_layers, batch_size, self.lstm_hidden).cuda(), requires_grad=True),
                    Variable(torch.zeros(2*self.lstm_layers, batch_size, self.lstm_hidden).cuda(), requires_grad=True))
        else:
            return (Variable(torch.zeros(self.lstm_layers, batch_size, self.lstm_hidden).cuda(), requires_grad=True),
                    Variable(torch.zeros(self.lstm_layers, batch_size, self.lstm_hidden).cuda(), requires_grad=True))

    def forward(self, x, lengths=None):
        batch_size, timesteps, C, H, W = x.size()
        self.hidden = self.init_hidden(batch_size)

        # CNN forward
        c_in = x.view(batch_size * timesteps, C, H, W)
        c_out = self.cnn(c_in)
        c_out = c_out.mean(3).mean(2)
        if self.dropout:
            c_out = self.drop(c_out)

        # LSTM forward
        r_in = c_out.view(batch_size, timesteps, -1)
        r_out, states = self.rnn(r_in, self.hidden)
        out = self.lin(r_out)
        out = out.view(batch_size*timesteps,9)

        return out



위 코드는 CNN-LSTM에서 CNN부분에 EfficientNet을 적용할 때의 코드입니다. 9가지 클래스의 분류문제입니다.

Pytorch Efficientnet코드는 위의 from efficientnet_pytorch import EfficientNet 에서 가져올 수 있습니다.

classifier는 사용하지 않기때문에 include_top=false로 설정해줍니다.

이후 cnn을 efficientNet으로 설정해두고 rnn = LSTM 설정을 해줍니다. 

EfficientNet의 output은 avgpool이기 때문에 1280을 인풋으로 넣어줍니다. 

bidirectional을 적용할 경우에 LSTM layer 개수를 2배 해줍니다. 또한 다음 레이어인 linear에 히든노드를 2배 적용해줍니다.

init에서 모델 파라미터를 설정해주고 forward에서 모델 구조를 설계해줍니다.

x는 인풋 데이터이고, 인풋 데이터는 동영상 데이터입니다. 따라서 batch_first일 경우 위와 같이 분리할 수 있습니다.

각각 CNN과 LSTM에 인풋 데이터를 설정해주고 CNN의 아웃풋을 LSTM에 입력해줍니다.

 

 

반응형

'인공지능' 카테고리의 다른 글

(Tensorflow) valueerror:unknown layer : functional  (0) 2022.04.14
[Keras] 자주쓰는 callback 모음  (0) 2021.03.30
pytorch class weight 주는법  (0) 2021.03.16
pytorch model summary  (0) 2021.03.16
GRU VS LSTM  (0) 2021.03.15
반응형

pytorch에서 모델을 학습할 때 특정 클래스에 가중치를 주어 더 잘 학습하게 하는 방법이 있습니다.

weights = torch.FloatTensor([1/20, 1/20, 1/20, 7/20, 1/20, 7/20, 1/20, 1/20, 1/100]).cuda()
criterion = torch.nn.CrossEntropyLoss(weight=weights)

위의 상황은 클래스가 9개일 때 인덱스 3과 5번은 더 잘 탐지하고 싶고, 마지막 인덱스는 불필요한 상황입니다.

이후 아래와 같이 설정해주면 됩니다.

logits = model(inputs)
labels = labels.view(bs*seq_length)

loss = criterion(logits, labels)
반응형

'인공지능' 카테고리의 다른 글

[Keras] 자주쓰는 callback 모음  (0) 2021.03.30
Pytorch CNN-LSTM 모델 설계  (0) 2021.03.17
pytorch model summary  (0) 2021.03.16
GRU VS LSTM  (0) 2021.03.15
YOLO V5 사용하기  (0) 2021.03.15
반응형

파이 토치 모델 summary에는 두가지 방법이 있습니다.

GRU 모델로 알아보겠습니다.

첫째는 단순히 model을 print하는 방법입니다.

print(model)

두번째는 torchsummary 모듈을 이용하는 것입니다.

from torchsummary import summary
summary(model, (1,1,26))

summary의 파라미터는 model과 인풋 사이즈입니다. 이때 배치와 시퀀스 길이는 1로 설정해야합니다. 이미지는 (채널 수, 이미지 사이즈)입니다.

반응형

'인공지능' 카테고리의 다른 글

Pytorch CNN-LSTM 모델 설계  (0) 2021.03.17
pytorch class weight 주는법  (0) 2021.03.16
GRU VS LSTM  (0) 2021.03.15
YOLO V5 사용하기  (0) 2021.03.15
YOLO v3 커스텀 데이터 셋 학습  (0) 2021.01.05

+ Recent posts