[Deep Learning] Recurrent Neural Network (2) - LSTM
๐ง๐ป๐ป์ฉ์ด ์ ๋ฆฌ
Neural Networks
Feed-forward
Backpropagation
Convolutional Neural Network
Recurrent Neural Network
LSTM
Attention
Cell state
forget gate
input gate
output gate
์ด์ ์๊ฐ๊น์ง RNN์ ์ฌ๋ฌ ๊ฐ์ง ์ข ๋ฅ์ ๋ํด ์์๋ณด์์ต๋๋ค.
์กฐ๊ธ ๋ ์ดํด๋ด ์๋ค.
RNN์ ๋จ์
RNN์ ์ญ ์ฐ๊ฒฐ์์ผ ๋๊ณ ๋ณด๋,
ํนํ Encoder, Decoder ๋ถ๋ถ์์ ํ์ฐํ ๋ค์ด๋๋ ์ด ํน์ง์,
๋๋ฌด ์ ๋ ฅ๋๋ ๋ฌธ์ฅ์ด ๊ธธ๋ค๋ณด๋,
Long-Term Dependency ๋ฌธ์ ๊ฐ ๋ฐ์ํฉ๋๋ค.
input signal์ ์ ๋ณด๋ฅผ h t์ ๊ณผ๊ฑฐ์์๋ถํฐ ๊ณ์ ๋ชจ์์ต๋๋ค.
๊ทธ๋ฐ๋ฐ ๊ณผ๊ฑฐ์ ๋ ์ด์ ์์ ์ ๋ฌด์ธ๊ฐ๊ฐ ํ์ํด์ง ๊ฒฝ์ฐ,
์ด chain์ด ๋๋ฌด ๊ธธ์ด์ ๊น๋จน๋ ๊ฒฝ์ฐ๊ฐ ๋ฐ์ํฉ๋๋ค.
์ฆ, short term memory๋ง ๋จ์์๋ ๊ฒ์ ๋๋ค.
๊ทธ๋์ ์ผ๋ฐ์ ์ธ RNN์ long-term dependency๋ฅผ ๊ณ ๋ คํ์ง ๋ชปํ์ฌ long-term memory๋ฅผ ๋ง๋ค์ด์ผํ๋ค๊ณ ์๊ฐํฉ๋๋ค.
LSTM(Long-Short Term Memory)
๊ทธ๋์ ๋ณด์๋ ๋ชจ๋ธ์ด LSTM์ ๋๋ค.
long๊ณผ short๋ฅผ ๋ค ๊ณ ๋ คํ๊ฒ ๋ค๋ ๋ป์ ๋๋ค.
์ฌ์ค์ RNN์ short term ์์ฃผ๋ก ๋ณด๋, ์ด๊ฒ์ผ๋ก long term๋ ๋ณด๊ฒ ๋ค๋ ๊ฒ์ ๋๋ค.
์๋ RNN์ด๊ณ ,
์๋๋ LSTM์ ๋๋ค.
์ RNN ๊ตฌ์กฐ๋ฅผ ๊ทธ๋ฆผ์ผ๋ก ๊ทธ๋ฆฌ๋ฉด ์ด๋ ์ต๋๋ค.
์ง๊ธ๊น์ง ์ฐ๋ฆฌ๊ฐ ๋ด์จ ๊ฒ๊ณผ ๊ฐ์ฃ ?
์ ์ด๋ ๊ฒ ๋ณต์กํ๊ฒ ๊ทธ๋ ธ์๊น์?
LSTM์ ์ ๋ณต์กํ ๊ตฌ์กฐ๋ฅผ ์ค๋ช ํ๊ธฐ ์ํจ์ด๋ฉฐ,
์ง๊ธ ์๋ก ๊ทธ๋ฆฐ RNN ๊ตฌ์กฐ์ ๋ค๋ฅผ ๊ฒ ์๋ค๋ ๊ฒ์ ์๊ณ ๊ฐ๋ด ์๋ค.
๊ทธ๋์ ์ฐ๋ฆฌ๊ฐ ์์์ ์ธ๊ธํ pattern 1,2,3, Seq2Seq, one2many, many2one ์ด๋ฌํ ๊ตฌ์กฐ๋ค์ด ๋๊ฐ์ด ๋ค ์ ์ฉ๋ฉ๋๋ค.
๊ทธ์ hidden node ๋ด๋ถ์์ ๋ฒ์ด์ง๋ ์ผ๋ง ๋ณต์กํ๊ฒ ๋ณ๊ฒฝ๋์๋ค๊ณ ๋ณด์๋ฉด ๋ฉ๋๋ค.
์๋ก์ด ๊ตฌ์กฐ๋ ์๋๋ฉฐ, hidden node ๋ด๋ถ์์ ์ ๋ณด๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐฉ์๋ง ๋ฐ๊ฟ์ค ๊ฒ์ ๋๋ค.
๊ทธ๋ผ ๊ทธ ๋ฐฉ์์ ํ๋์ฉ ๋ฏ์ด๋ด ์๋ค.
Cell state
LSTM์์ ๊ฐ์ฅ ์ค์ํ ๊ฒ์ด Cell state์ ๋๋ค.
์ด๊ฒ์, context๋ผ๊ณ ๋ ๋ถ๋ฅด๋ฉฐ long-term memory๋ผ๊ณ ๋ ๋ถ๋ฆ ๋๋ค.
์ด๊ฒ์ self-connection์ ํตํด์ ์๋ฅผ ์ญ ์ง๋๊ฐ๋๋ค.
์ญ ๊ดํตํด์ ์ง๋๊ฐ๋๋ค.
์ํ ์์ด ์ง๋๊ฐ๋๋ค.
์ด๊ฒ์ ๊ณผ๊ฑฐ๋ก๋ถํฐ ์ค๊ณ , ๋ญ๊ฐ ๋ ๋ฒ์ operation์ ๊ฑฐ์ณ์ ํ์ฌ์ context๋ก ๋ฐ๋๋๋ค.
LSTM์๋ ์ด 3๊ฐ์ง์ ๋ฌธ์ด ์กด์ฌํฉ๋๋ค.
์ด forget gate์์๋ ๊ณผ๊ฑฐ๋ก๋ถํฐ ์ค๋ ์ ๋ณด๋ค์ ์ผ๋ง๋ ์์ด๋ฒ๋ฆด์ง๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
์ฆ, ๊ณผ๊ฑฐ๋ก๋ถํฐ ์ค๋ context๋ฅผ ์ผ๋ง๋ ํต๊ณผ์ํฌ ๊ฒ์ธ์ง๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
1์ ๊ฐ๊น์ ์ง๋ค๋ฉด, ๊ณผ๊ฑฐ์ ์ ๋ณด๋ค์ ์ ๋ถ๋ค ์ ์ง์ํจ๋ค๊ณ ๋ณด๋ฉด ๋๊ณ , long-term memory์ ๊ณ์ ๊ธฐ์ตํ๋ค๊ณ ๋ด ๋๋ค.
0์ ๊ฐ๊น์ ์ง๋ค๋ฉด, ์ง๊ธ๊น์ง ๊ณผ๊ฑฐ๋ก๋ถํฐ ๊ตฌ์ฑ๋์ด ์จ ์ด ์ ๋ณด๋ ํ์ฌ ์์ ์์ ์ธ๋ชจ๊ฐ ์์ผ๋ ๋ค ๊ฐ๋ค ๋ฒ๋ฆฌ๋ผ๋ ๊ฒ์ ๋๋ค.
์ด gate๋ 0 ~ 1 ์ฌ์ด์์ bound๊ฐ ๋์ด์ผ ํฉ๋๋ค.
๊ทธ๋์ sigmoid๋ฅผ ์ฌ์ฉํฉ๋๋ค.
๊ทธ๋์ ๊ทธ๋ฆผ์ ์ฐ์ธก ์์๊ณผ ๊ฐ์ ์์์ด ๋ฉ๋๋ค.
์ด๊ฒ์ด hidden node ๊ณ์ฐ๊ณผ ์์ ๊ฐ์ง๋ง, W f์ b f๋ trainable parameter์ ๋๋ค.
์ฐ์ , forget gate๋ผ๋ ๋์์ธ์ ํด๋๊ณ , ์ด forget gate์ ์ญํ ์ ํ์ต ๋ฐ์ดํฐ์ ๋ฐ๋ผ ํ์ต์ด ๋๋๋ฐ,
W์ b๋ฅผ ์ฐ๋ฆฌ๊ฐ ์ธ ์ ์๋ ์ ๋ณด์ ์ ์กฐ์ํด์ ์ด gate์ ๋ํด ์ผ๋ง๋ ๋ฐ์ํ ์ง๋ฅผ ์ ๊ฒฐ์ ํ๋ผ๋ ๊ฒ์ด ๋ฉ๋๋ค.
๊ฒฐ๊ตญ ๊ตฌ์กฐ๋ง ์ง๋๊ณ Backpropagation์ ์ฑ ์์ ์ ๊ฐ์ํค๋ ๋๋์ด ๋๊ฒ ์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ ๋ ๋ฒ์งธ ๋ฌธ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
๋ ๋ฒ์งธ ๋ฌธ์ Input gate์ ๋๋ค.
๊ธ๋ฐฉ C t์ ๊ฐ์ด ๊ฒฐ์ ๋๋ ๊ฒ์ ๋ํด์ C t-1์ ๊ณผ๊ฑฐ์ ์ ๋ณด๊ฐ ์ผ๋ง๋ ๊ฒฐ์ ๋ ์ง, forget gate์ ๊ณฑํ๊ธฐ๋ฅผ ํตํด์ ๊ฒฐ์ ๋๊ณ ,
๊ทธ๋ฆฌ๊ณ ๋ํ๊ธฐ์ ๊ณผ์ ์ด ํ๋ ๋จ์์ต๋๋ค.
์ฌ๊ธฐ์๋ ํ์ฌ์ ์ ๋ณด๋ฅผ ๋ํด์ค๋๋ค.
ํ์ฌ ์ ๋ณด๋ ์์ ๊ฐ์ด C t์ ์์์ผ๋ก hidden node์์ activation ์์ผ์ฃผ๋ ๊ฒ๊ณผ ๊ฐ์ด ์ฐ์ฐ๋ฉ๋๋ค.
์ด C t๋ ํ์ฌ ์ ๋ณด์ ํ๋ณด๊ฐ ๋ฉ๋๋ค.
์ด ๋ํ, ๊ทธ๋ฅ ๋ํ๋ ๊ฒ์ด ์๋ ์ผ๋ง๋ ๋ฐ์์ํฌ์ง๋ฅผ ๊ณฑํ์ฌ ๋ฐ์์ํต๋๋ค.
์ด๊ฒ์ C t-1์ด ๋ค์ด์ค๋ ๊ฒ๊ณผ ๋์นญ์ด ๋๋, ๊ฐ์ ์ฐ์ฐ์ผ๋ก ์ด๋ฃจ์ด์ง๋๋ค.
์ฌ๊ธฐ์์ ๊ณฑํ๊ธฐ๋ i์ธ input gate์ด๊ณ , ๋๊ฐ์ sigmoid์ ๊ฐ์ input data๋ฅผ ์ฌ์ฉํ๋๋ฐ,
๋ ๋ค๋ฅธ trainable parameter๋ฅผ ์ฌ์ฉํ์ฌ ๋ฌธ์ ๋ง๋ค์ด ์ค ๊ฒ์ ๋๋ค.
๊ทธ๋์ ์์ ๊ฐ์ด input ํ๋ณด๊ฐ ๊ฒฐ์ ๋๊ณ , ์ด input ํ๋ณด์ ๋ํด์ ์ผ๋ง๋ ๋ฐ์๋ค์ผ์ง ๊ฒฐ์ ํ๋ input gate๊ฐ ์กด์ฌํฉ๋๋ค.
๊ฒฐ๊ตญ, ์๋ก์ด ์ ๋ณด๋ฅผ ์ฐ๋ฆฌ๊ฐ context์ ์ผ๋ง๋ ๋ฐ์ํ ๊ฒ์ธ์ง ๋ณด๋ ๊ฒ์ ๋๋ค.
input gate, forget gate๋ ๋ชจ๋ context์ ์ผ๋ง๋ ๋ฐ์ํ ์ง๋ฅผ ๋ณด๊ธฐ ์ํจ์ ๋๋ค.
๊ทธ๋์ ์๋ก์ด context๋ฅผ ๊ณ์ฐํฉ๋๋ค.
์ด ๊ฐ์ forget gate * ๊ณผ๊ฑฐ + input gate * ํ์ฌ input์ ํ๋ณด๋ก ๊ณ์ฐ์ด ๋ฉ๋๋ค.
์๋ฅผ ๋ค์ด, forget gate, input gate๊ฐ ๊ฐ๊ฐ 0.9, 0.1์ด๋ผ๋ฉด ๊ณผ๊ฑฐ๋ ๋ง์ด, ํ์ฌ๋ ์ ๊ฒ ๋ฐ์ํด๋ผ ์ ๋๋ค. ์ฆ, long-term memory๊ฐ ๋์ฃ .
0.1, 0.9์ ํํ๋ผ๋ฉด ๊ณผ๊ฑฐ์ ์๋ ๊ฑฐ ์กฐ๊ธ๋ง ๋จ๊ธฐ๊ณ ํ์ฌ์ ๊ฒ์ ๋ฐ์ํด๋ผ๊ฐ ๋ฉ๋๋ค.
์ด๋ ๊ฒ ์๋ก์ด context๊ฐ ๊ตฌ์ฑ์ด ๋ฉ๋๋ค.
์ง๊ธ ์๊ธฐํ ๊ฒ์ด context์ ๋ํ ์ด์ผ๊ธฐ ์ ๋๋ค.
hidden node์ ๋ํ ์ด์ผ๊ธฐ๋ ์์ง ์ ํ์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ 3๋ฒ์งธ ๋ง์ง๋ง ๋ฌธ์ธ output gate์ ์์์ผ ๋๋์ด hidden node์ ๊ฐ์ ๊ณ์ฐํ๊ฒ ๋ฉ๋๋ค.
๋ ๋ค๋ฅธ output gate๋ฅผ ๋ง๋ค์ด์ฃผ๊ณ ,
context๋ฅผ activation ํ Output gate์ ํต๊ณผ์ํค๋ฉด์ hidden node์ output์ ์ด์ ์์ผ ๋ง๋ค์ด์ฃผ๊ฒ ๋ฉ๋๋ค.
๊ทธ๋์ ์ด ๊ฐ์ ์์ layer๋ก ์ฌ๋ ค์ฃผ๊ณ , t + 1์ ๋ค์ ๋จ๊ณ๋ก๋ ๋ณด๋ด์ค๋๋ค.
๊ทธ๋์ ์์ ๋์จ ์ด์ผ๊ธฐ๋ฅผ ๋ค ์์ฝํด๋ณด์๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
๊ฒฐ๊ตญ RNN๊ณผ ๊ฐ์ ๊ตฌ์กฐ์ธ๋ฐ,
Hidden node ์์์ operation๋ง ๋ ๋ง์์ง๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค.
์ด ๋ชจ๋ ๊ฒ์ ์ cell state ํน์ context ํน์ long-term memory๋ฅผ ๋ง๋ค์ด์ฃผ๊ธฐ ์ํจ์ ๋๋ค.
RNN์ ๊ธฐ๋ณธ์ ์ผ๋ก parameter๊ฐ 2 sets ํ์ํ๋ค๊ณ ํฉ๋๋ค.
๊ทธ๋ฐ๋ฐ, LSTM์ ๊ธฐ๋ณธ์ ์ผ๋ก parameter๊ฐ 8 sets ํ์ํ์ฌ, 4๋ฐฐ ์ด์ ์์๋ฉ๋๋ค.
RNN๋ computer๊ฐ ํ๋ค์ดํ๋ ์ฐ์ฐ์ด์ง๋ง, LSTM์ ๊ทธ๊ฒ์ ์ด๋ฆผ์ก์ 4๋ฐฐ ๋ ๋ง์ ์ฐ์ฐ์ด ๋ญ๋๋ค.
'Artificial Intelligence > Deep Learning' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[Deep Learning] Recurrent Neural Network (4) - Transformer (0) | 2023.05.20 |
---|---|
[Deep Learning] Recurrent Neural Network (3) - Attention (0) | 2023.05.19 |
[Deep Learning] Recurrent Neural Network (1) (1) | 2023.05.16 |
[Deep Learning] Convolutional Neural Network (2) (0) | 2023.05.03 |
[Deep Learning] Convolutional Neural Network (1) (1) | 2023.04.30 |