[kakao x goorm] 생성 AI 응용 서비스 개발자 양성 과정/회고록

[kakao x goorm] RNN, LSTM, GRU의 구조와 차이

Hoonia 2025. 5. 21. 13:38

 

 

 

 

오늘은 자연어 처리에서 시퀀스 데이터를 다루는 핵심 기술인 RNN 계열 모델들에 대해 학습했다.
RNN의 기본 구조부터 LSTM, GRU, 그리고 양방향 RNN까지 차근차근 정리해보았다.

RNN(Recurrent Neural Network) 이해하기

RNN은 연속적인 시퀀스를 처리하는 데 적합한 신경망 구조로, 이전 입력 정보를 내부적으로 기억해 다음 결과에 반영하는 특징을 가진다.
입력의 길이만큼 동일한 구조가 시간축으로 펼쳐지는(unrolled) 형태로 구성되며, 이때의 각 시점을 time step이라 부른다.

RNN의 응용 유형

유형  설명  예시
One to Many 입력 1개 → 출력 여러 개 이미지 캡셔닝
Many to One 여러 입력 → 출력 1개 감성 분석, 문서 분류
Many to Many 여러 입력 → 여러 출력 개체명 인식, 번역, POS 태깅

핵심 구조

  • RNN의 은닉층에서 반복적으로 동작하는 최소 단위를 셀(cell)이라 하며, 셀의 출력은 은닉 상태(hidden state)로 표현된다.
  • 현재 시점의 은닉 상태는
  • 현재 시점의 은닉 상태 $h_t$는 이전 시점의 은닉 상태 $h_{t-1}$과 현재 입력 $x_t $를 기반으로 계산된다.
  • 이 구조 덕분에 RNN은 과거 정보를 기억하는 순환 구조를 형성할 수 있다.
import torch
import torch.nn as nn

class BasicRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BasicRNN, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)  # hidden state는 여기선 사용 안함
        out = self.fc(out[:, -1, :])  # 마지막 시점의 hidden state 사용
        return out

# 예시 입력
x = torch.randn(32, 10, 100)  # (batch_size, seq_len, input_size)
model = BasicRNN(input_size=100, hidden_size=64, output_size=2)
output = model(x)
print(output.shape)  # torch.Size([32, 2])

RNN의 한계: Long-Term Dependency Problem

기존 RNN은 시점이 길어질수록 초반 정보가 뒤로 갈수록 점점 소실되는 장기 의존성 문제(Long-Term Dependency)에 취약하다.
시퀀스의 앞쪽 정보가 뒤로 갈수록 영향력이 희미해지며, 이는 정보의 전파가 어려워진다는 의미이다.

LSTM (Long Short-Term Memory)

LSTM은 RNN의 장기 의존성 문제를 해결하기 위해 등장했다.
핵심은 cell state(셀 상태)라는 새로운 경로를 도입한 것이다.
셀 상태는 정보가 흐르는 통로이며, 이를 다양한 게이트(gate)가 조절한다.

주요 게이트 구성

  1. 입력 게이트 (Input Gate)
    • 현재 정보를 얼마나 기억할지 결정
    • 시그모이드→하이퍼볼릭탄젠트시그모이드 → 하이퍼볼릭 탄젠트 두 단계 연산을 통해 입력 정보를 필터링
  2. 삭제 게이트 (Forget Gate)
    • 이전의 기억 중 버릴 정보를 결정
    • 출력값이 0에 가까울수록 해당 기억을 삭제한다
  3. 셀 상태 업데이트
    • $c_t = f_t \cdot c_{t-1} + i_t \cdot g_t$
    • 삭제 게이트 $f_t$가 0이면 과거 정보를 완전히 삭제한다.
  4. 출력 게이트 (Output Gate)
    • 최종적으로 은닉 상태 $h_t$를 결정하는 데 사용된다.
    • 셀 상태를 기반으로 출력값을 조절
class BasicLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BasicLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])
        return out

x = torch.randn(32, 10, 100)
model = BasicLSTM(input_size=100, hidden_size=64, output_size=2)
output = model(x)
print(output.shape)

GRU (Gated Recurrent Unit)

GRU는 LSTM의 간소화된 버전으로, 게이트 수를 3개에서 2개로 축소한 구조이다.

  • 업데이트 게이트 (Update Gate): 장기 기억과 현재 기억을 어떻게 조합할지 결정
  • 리셋 게이트 (Reset Gate): 과거 정보를 얼마나 무시할지 결정

GRU는 계산량이 더 적어 빠른 학습이 가능하며, 데이터에 따라 LSTM보다 성능이 뛰어난 경우도 많다.

class BasicGRU(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BasicGRU, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.gru(x)
        out = self.fc(out[:, -1, :])
        return out

x = torch.randn(32, 10, 100)
model = BasicGRU(input_size=100, hidden_size=64, output_size=2)
output = model(x)
print(output.shape)

Bidirectional RNN (양방향 RNN)

Bidirectional RNN은 순방향뿐 아니라 역방향으로도 입력을 처리하는 구조다.
이렇게 하면 문맥의 과거와 미래를 모두 참고하여 더 풍부한 정보로 예측이 가능하다.

  • 예: 문장의 마지막 단어를 예측할 때 앞의 단어뿐 아니라 뒤 단어들도 참고할 수 있음
  • 구조적으로는 정방향 RNN과 역방향 RNN을 나란히 배치해 출력값을 결합한다.
import torch
import torch.nn as nn

class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BiRNN, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, output_size)  # 양방향이므로 hidden_size * 2

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out[:, -1, :])  # 마지막 시점의 양방향 hidden state 사용
        return out

# 입력 예시 (배치 크기 32, 시퀀스 길이 10, 입력 차원 100)
x = torch.randn(32, 10, 100)
model = BiRNN(input_size=100, hidden_size=64, output_size=2)
output = model(x)
print(output.shape)  # torch.Size([32, 2])

오늘의 회고

기존 RNN의 구조를 이해하고, 이를 보완한 LSTM과 GRU의 동작 원리를 학습하면서 시퀀스 모델이 왜 중요한지를 다시금 느낄 수 있었다.
단순한 반복 구조가 아니라, 정보를 어떻게 기억하고 버릴지를 스스로 학습하도록 설계된 게이트 구조는 매우 인상적이었다.
앞으로는 실제 텍스트 분류나 개체명 인식 작업에 RNN 기반 모델들을 어떻게 적용하는지도 실습해볼 계획이다.
다음에는 Transformer 구조와 Attention 메커니즘을 배워보며, RNN을 넘어선 새로운 시퀀스 모델을 이해하는 데 도전할 예정이다.