[Deep Learning] Recurrent Neural Network (3) - Attention
๐ง๐ป๐ป์ฉ์ด ์ ๋ฆฌ
Neural Networks
Recurrent Neural Network
LSTM
Attention
์ง๋ ์๊ฐ์ RNN์ ์ด์ด LSTM๊น์ง ์ดํด๋ณด์์ต๋๋ค.
์ด๋ฒ์๋ NLP๋ฅผ ํ์ฌ ์์ฒญ๋๊ฒ ํซํ๊ฒ ํด์ค Attention์ ๋ํด ์ดํด๋ณด๊ฒ ์ต๋๋ค.
Attention Mechanism
์ด Attention ๊ฐ๋ ์ ํตํด์
์ฐ์ , ์ฐ๋ฆฌ๊ฐ ์ฃผ๋ชฉํ์ง ์์๋ ๊ฒ๋ค์ ๋ํด์ ์ฃผ๋ชฉํ๊ธฐ ์์ํ๋ค๋ ๊ฒ์ ๋๋ค.
์ด๊ฒ์ ๊ฒฐ๊ตญ ์์ด๋ฒ๋ ธ๋ ๊ฒ์ ๋ํด์ ์ฃผ๋ชฉ์ ํ๊ฒ ๋ค๋ ๊ฒ์ ๋๋ค.
์ด๊ฒ์ Long-term Dependency์ ๋ํ ๋ฌธ์ ๋ก ๋์์ต๋๋ค.
CNN์์์ depth์ฒ๋ผ RNN์์์ Sequence๊ฐ ๊ธธ์ด์ก์ ๋ ๋ฐ์ํ๋ ๋ฌธ์ ๋ ๋น์ทํฉ๋๋ค.
์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ Algorithm์ด ํ์ํฉ๋๋ค.
์ด๋ฅผํ ๋ฉด,
Decoder์ ํ ์ถ๋ ฅ์ ๋ด๋ณด๋ด๋ ๋ฐ ์์ด, ์ด์ ๋จ์ด์ ์ถ๋ ฅ๊ณผ hidden state์ ์ฐ์ฐ์ผ๋ก ์ด๋ฃจ์ด์ง๋๋ฐ,
์ฌ๊ธฐ์ Encoder๊ฐ ๋๋ฌด๋ ๋ฉ์ด์ง๋ Long-term Dependency๊ฐ ๋ฐ์ํ๋ ๊ฒ์ ๋๋ค.
์ฌ๊ธฐ์ Encoder ์์ ์์์ ์ ๋ณด๋ฅผ ๋ค์ ์ถ๋ ฅ ํ๋ํ๋์ ํฌํจ์์ผ์ค๋ค๋ฉด ์ด๊ฒ๋ค์ด ์์ธกํ๋๋ฐ ๋์์ ์ฃผ์ง ์์๊นํ๋ ์์ด๋์ด๊ฐ ๋์ค๊ฒ ๋ฉ๋๋ค.
๊ฒฐ๊ตญ, ๋ฉ๋ฆฌ ๋จ์ด์ ธ์๋ Encoder์ ์ ๋ณด ์ค ์ค์ํ ์ ๋ณด๋ Decoder์์๋ ๋ฐ์ํด์ฃผ์๋ ๊ฒ์ ๋๋ค.
Encoder์ hidden vector ์ค decoder์ hidden vector์ ์ ์ฌํ ๊ฒ ์๋ค๋ฉด, ๊ทธ encoder์ hidden vector๋ ์์ธก์ ์ค์ํ key๊ฐ ๋ ์ ์์ต๋๋ค.
๋ค์๊ณผ ๊ฐ์ด ๋ ์ผ์ด๋ฅผ ์์ด๋ก ๋ฒ์ญํ๋ ์์๋ฅผ ๋ด ์๋ค.
์์ ๊ฐ์ด beer๋ฅผ ์์ธกํ๋ ๋ฐ ํ์ํ hidden state์ ๋ํด์, ์ด ๊ฒ์ ์ ๋ณด๋ฅผ ์กฐํฉํ์ฌ beer๋ผ๋ ๊ฐ์ ์ ๋ํ๋ด๋ vector๊ฐ ๋ ๊ฒ์ ๋๋ค.
๊ทธ๋ฆฌ๊ณ , ์์ ๊ฐ์ด bier์์์ hidden vector์ beer์์์ encoding๋ hidden vector๊ฐ ์๋ค๋ฉด, ์ด ์ ๋ณด๋ ํจ๊ป ์ด์ฉํ์๋ ๊ฒ์ ๋๋ค.
๊ทธ๋ ๋ค๋ฉด ์กฐ๊ธ ๋ decoding ์ ๋์์ด ๋์ง ์์๊น ํ๋ ๊ฒ์ ๋๋ค.
๋จ์ด๊ฐ ์ ์ฌํ๊ฒ vector๋ก ํํ๋ ๊ฒ์ ์ฐพ์๋ด์๋ ์๋ฏธ๋ก ๋ณด์๋ฉด ๋๊ฒ ์ต๋๋ค.
์กฐ๊ธ ๋ ์์ธํ ๋ด ์๋ค.
์์ ๊ฐ์ด attention์ด ์๋ ์์ธก๊ณผ ์๋ ์์ธก์ผ๋ก ๋๋์ด ๋ณผ ์ ์์ต๋๋ค.
attention์ด ์๋ค๋ฉด, ์ ์์ ์ hidden state์ ์ ์์ ์ ์ถ๋ ฅ๊ฐ์ด๋ weight ์ฐ์ฐ ํ bias๋ฅผ ํฌํจ์์ผ์ ํ์ฌ ์์ ์ state๊ฐ ๋๋๋ก ํฉ๋๋ค.
attention์ด ํฌํจ๋์์ ๋, ์ถ๊ฐ๋ก c t๋ฅผ ์ถ๊ฐํด์ค๋๋ค.
c t๋ผ๋ context vector๋ฅผ ๋ง๋ค์ด ๊ฐ์ด ๋ฃ์ด์ฃผ๋ ๊ฒ์ ๋๋ค.
์ฌ๊ธฐ์ c t๋ ์ฝ๊ฒ ๊ตฌํด๋ผ ์ ์์ต๋๋ค.
์ผ๋จ, ์ ์์ ์ decoder์์์ hidden state s t-1๊ณผ encoder์ hidden state๋ค ๊ฐ์ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํฉ๋๋ค.
๊ทธ๊ฒ์ด ๋ฐ๋ก ์ข์ธก์ ๋์์๋ ์์์ ๋๋ค.
s t-1 ์ Decoder์, h j๋ Encoder์ ์๋ ๊ฒ์ ๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ฌ๊ธฐ์ runnable parameter 3๊ฐ V, W, U๋ฅผ ๋ก๋๋ค.
์ฌ๊ธฐ์ ์ ์ฌ๋ ์ด๋ป๊ฒ ๋๋์ง ์์ง ๋ชจ๋ฅด์ง๋ง,
parameter 3๊ฐ์ activation์ ํ๋ค๋ฉด, ์ฐ๋ฆฌ๊ฐ ํ์ต์ ์ ์ํจ๋ค๋ฉด ์์ธก์ ์ ์ํฌ ์ ์์ง ์์๊น ๋ณด๋ ๊ฒ์ ๋๋ค.
์ฌ๊ธฐ์, j๋ผ๋ encoder์ index, ์ ๊ทธ๋ฆผ์ผ๋ก๋ณด๋ฉด 4๊ฐ, t๋ decoder์ index๋ก ์ฐ์ธก ๋ถ๋ถ ์ค ํ๋์ ๊ฐ์ ํด๋นํฉ๋๋ค.
์ด ์ ์ฌ๋์ ๋ํด ๊ณ์ฐํ ๊ฒ์ ๊ทธ๋ฅ ๋ด๋ณด๋ด๋ฉด ์ ๋๊ฒ ์ฃ ?
softmax function์ผ๋ก normalization์ ํฉ๋๋ค.
๊ฒฐ๊ตญ ๊ฐ encoder์ hidden node์ ๋ํ ์ ์ฌ๋๋ฅผ 0 ~ 1 ์ฌ์ด์ ๊ฐ์ผ๋ก scalingํ์ฌ ํ๋ฅ ๊ฐ์ฒ๋ผ ๋ฐ๊ฟ์ค๋๋ค.
๊ฒฐ๊ตญ a์ ๋ํ nornalization๋ ๊ฐ์ ๋ฐ์ง๊ณ ๋ณด๋ฉด ๊ฐ์ค์น๊ฐ ๋ฉ๋๋ค.
๊ทธ๋์ ์ด๋ค ๊ฒ์ด ์ ์ผ ์ ์ฌํ์ง ๊ฐ์ค์น๋ก์ ์ ์ ์๊ฒ ๋ฉ๋๋ค.
๊ทธ๋์ ์์ ๊ฐ์ด a i์ i index๋ decoder์ i index์ ๋๋ค.
๊ทธ๋์ s i๋ผ๋ hidden state์ attention์ encoding ์์ผ์ฃผ๊ธฐ ์ํ ๊ฐ์ค์น๊ฐ ๋ค์ด๊ฐ๊ฒ ๋ฉ๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ฐ์ธก๊ณผ ๊ฐ์ด context vector๋ผ๋ ๊ฒ์ ๋ง๋ค์ด ์ด๊ฒ๋ค์ ๋ค ๊ฐ์คํฉ์ ํฉ๋๋ค.
encoder์์์ hidden state 4๊ฐ์ ๋ํด์ ์์์ ๊ตฌํ ๊ฐ์ค์น์ ๋ํด์ ์ ํ์กฐํฉ์ ํ์ฌ ์ฐ์ฐํฉ๋๋ค.
์ด ๊ฐ์ค์น๋ ๊ฒฐ๊ตญ ์ ์ฌ๋์ ๊ธฐ๋ฐํ ๊ฐ์ด๋ฉฐ, ์ด ๊ฐ์ 0 ~ 1 ์ฌ์ด์ ๊ฐ์ผ๋ก scalingํ์ฌ ํฉ์ณ์ 1์ด ๋๊ฒ ๋ง๋ค์์ต๋๋ค.
๊ทธ๋์ ์ด๋ ๊ฒ ๋๋ฉด, ๊ฐ์ฅ ๊ด๋ จ์ด ๊น์ ๊ฐ์ด context vector์ ๊ฐ์ฅ ๋ง์ด ์ํฅ์ ๋ฏธ์น ๊ฒ์ ๋๋ค.
๊ทธ๋ฆฌ๊ณ ๊ฐ์ฅ ์ ์ฌํ์ง ์์ ๊ฒ์ 0์ ๊ฐ๊น์ด ๊ฐ์ค์น๋ฅผ ๊ฐ๊ฒ ๋์ด ์ฌ์ค์ context vector์ ๊ทธ๋ค์ง ์ํฅ์ ๋ฏธ์น์ง ์์ ๊ฒ์ ๋๋ค.
์ด context vector๋ฅผ ์ ์ฒด Decoder์ ํ์ด์ concatenateํ๋ฉด, context matrix๊ฐ ๋์ค๊ฒ ๋ฉ๋๋ค.
์ข์ธก์ ์๊ณผ ๊ฐ์ด Attention์ ์ด์ฉํ์ฌ prediction์ ํด์ฃผ๋ ์์ด ๋์ค๊ฒ ๋ฉ๋๋ค.
์ ๊ณ์ฐ์ผ๋ก๋ถํฐ ๋์จ C_t๋ฅผ ๋ฃ์ด์ฃผ์๋ ๊ฒ์ ๋๋ค.
๋ํ runnable parameter๊ฐ ์กด์ฌํ๋ฉฐ,
๊ทธ๋ฆฌ๊ณ ์ ๊ตฌ์กฐ๊ฐ LSTM style๋ก ๊ตฌ์ฑ๋๋ฉด runnable parameter๊ฐ ํจ์ฌ ๋ง์์ง๊ฒ ๋ฉ๋๋ค.
์์ ๊ฐ์ด Encoder, Decoder ๊ตฌ์กฐ๋ฅผ ๋ณผ ์ ์์ต๋๋ค.
Decoder์ ๊ฐ ๊ฐ์ ๋ํด์ ์ ์ฌ๋๊ฐ ๋ฌ๋ผ์ง๋ฏ๋ก weight ๊ฐ์ด ๋ฌ๋ผ์ง๊ฒ ๋ฉ๋๋ค.
๊ฒฐ๊ตญ์๋ ์ฌ๋ฌ encoder์ ์ ๋ ฅ ์ค ์ด๋ค ๊ฒ์ด prediction์ ๊ฐ์ฅ ํฐ ์ํฅ๋ ฅ์ ๋ฏธ์น๋์ง,
์ด๋ ํ ์ธ์๋ฅผ ๊ฐ์ง๊ณ ์์ธก์ ํด๋ด๋์ง ์ ์ ์์ต๋๋ค.
์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ์๋ ์ค๋ช ์ด ๋ถ๊ฐ๋ฅ ํ์ง๋ง,
Attention์ ๋์ ํ๋ฉด ์ค๋ช ์ด ๊ฐ๋ฅํด์ง๋๋ค.
๊ฒฐ๊ตญ prediction์ ์ ํํ ํ๊ฒ ์ํด Attention์ ๋์ ํ์ง๋ง, ํ๋ค ๋ณด๋ Attention์ด ์ค๋ช ์ ๊ต์ฅํ ํฐ ๋์์ด ๋๋ค๋ผ๋ ๊ฒ์ ๋๋ค.
๊ทธ๋ฆฌ๊ณ ์์ ๊ฐ์ด ์ ์ฌ๋๋ฅผ ๊ตฌํํ ์ ์๋ score function๋ค์ ๋๋ค.
'Artificial Intelligence > Deep Learning' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[Deep Learning] Autoencoders (0) | 2023.05.23 |
---|---|
[Deep Learning] Recurrent Neural Network (4) - Transformer (0) | 2023.05.20 |
[Deep Learning] Recurrent Neural Network (2) - LSTM (0) | 2023.05.17 |
[Deep Learning] Recurrent Neural Network (1) (1) | 2023.05.16 |
[Deep Learning] Convolutional Neural Network (2) (0) | 2023.05.03 |