[Deep Learning] Recurrent Neural Network (3) - Attention

2023. 5. 19. 00:08
🧑🏻‍💻 Glossary
Neural Networks
Recurrent Neural Network
LSTM
Attention

 

Last time, we covered RNNs and then went on to LSTMs.

This time, we'll look at Attention, the idea that has made NLP so hot right now.

 

 

Attention Mechanism

The idea behind Attention is, first of all, that we start paying attention to things we had not been paying attention to.

In other words, we attend to the information that had been lost.

This brings us back to the long-term dependency problem.

 

 

Just as depth causes trouble in a CNN, a similar problem arises in an RNN when the sequence gets long.

We need an algorithm that addresses this problem.

 

 

For example, each Decoder output is computed from the output for the previous word and the hidden state.

By that point the Encoder is so far back that a long-term dependency arises.

This leads to the idea that feeding the information from the Encoder time steps back into every single output might help the prediction.

 

 

 

In short, the important parts of the distant Encoder information should be reflected in the Decoder as well.

If one of the Encoder's hidden vectors is similar to the Decoder's hidden vector, that Encoder hidden vector can be an important key for the prediction.

 

 

Let's look at an example of translating German into English.

As in the figure above, the hidden state needed to predict "beer" will be a vector that combines the available information to represent "beer" well.

And if, as above, we have the hidden vector for "bier" and the encoded hidden vector for "beer", the idea is to use that information too.

That should make decoding a little easier.

You can think of it as finding the words whose vector representations are similar.

 

Let's look a little closer.

As shown above, we can distinguish prediction with attention from prediction without it.

Without attention, the previous time step's hidden state and the previous time step's output are multiplied by weights, a bias is added, and the result becomes the state at the current time step.

With attention, we additionally feed in c_t.

That is, we build a context vector c_t and pass it in alongside the other inputs.

This c_t is easy to compute.
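The difference between the two update rules can be sketched in a few lines of plain Python. This is only a minimal illustration under stated assumptions: the names `decoder_step`, `W_s`, `W_y`, and `W_c` are made up for this sketch, and `tanh` is an assumed activation (the original figure with the exact formula is not reproduced here).

```python
import math

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def decoder_step(s_prev, y_prev, W_s, W_y, bias, c_t=None, W_c=None):
    """Return the current decoder state s_t.

    Without c_t this is the plain recurrent update; with c_t the
    context vector contributes one extra weighted term.
    """
    pre = [a + y + b for a, y, b in zip(matvec(W_s, s_prev), matvec(W_y, y_prev), bias)]
    if c_t is not None:
        pre = [p + c for p, c in zip(pre, matvec(W_c, c_t))]
    return [math.tanh(p) for p in pre]  # tanh is an assumed activation

# Toy values, chosen only for illustration.
W = [[0.5, 0.0], [0.0, 0.5]]
s_prev, y_prev, bias = [0.1, -0.2], [1.0, 0.0], [0.0, 0.0]
s_plain = decoder_step(s_prev, y_prev, W, W, bias)                       # no attention
s_attn = decoder_step(s_prev, y_prev, W, W, bias, c_t=[0.3, 0.3], W_c=W)  # with c_t
```

The only change attention makes to the update is the extra `W_c` term; everything else stays the same.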

 

 

 

First, we compute the similarity between the Decoder's hidden state at the previous time step, s_{t-1}, and each of the Encoder's hidden states.

That is the formula shown on the left.

s_{t-1} belongs to the Decoder, and h_j belongs to the Encoder.

We also introduce three learnable parameters: V, W, and U.

We don't yet know what the similarity will look like,

but the idea is that with these three parameters and an activation, good training should lead to good predictions.

Here, j is the Encoder index (four positions in the figure above), and t is the Decoder index, which corresponds to one of the positions on the right.
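Since the formula itself lived in a figure, here is a hedged sketch of what a score with three parameters V, W, U and an activation commonly looks like: the additive form V · tanh(W s_{t-1} + U h_j). Treat that form, the `tanh` activation, and the toy numbers as assumptions rather than as the exact formula from the figure.

```python
import math

def additive_score(s_prev, h_j, V, W, U):
    """Similarity between decoder state s_prev and encoder state h_j.

    Assumed form: V . tanh(W s_prev + U h_j), which returns a scalar.
    """
    Ws = [sum(w * x for w, x in zip(row, s_prev)) for row in W]
    Uh = [sum(u * x for u, x in zip(row, h_j)) for row in U]
    hidden = [math.tanh(a + b) for a, b in zip(Ws, Uh)]
    return sum(v * z for v, z in zip(V, hidden))

# Four encoder hidden states (j = 1..4) and one decoder state s_{t-1}.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.0]]
s_prev = [1.0, 0.0]
identity = [[1.0, 0.0], [0.0, 1.0]]  # untrained placeholder for W and U
V = [1.0, 1.0]                       # untrained placeholder for V

# One score per encoder position j, all for the same decoder step t.
scores = [additive_score(s_prev, h, V, identity, identity) for h in encoder_states]
```

With placeholder parameters the scores carry no real meaning; they only become useful similarities once V, W, and U have been trained.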

 

 

We shouldn't just pass these raw similarity scores along as they are, right?

We normalize them with the softmax function.

This scales the similarity for each Encoder hidden node to a value between 0 and 1, so that it behaves like a probability.

The normalized value of a is, in effect, a weight.

So the weights tell us which Encoder state is the most similar.
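As a small illustration of the normalization step, the sketch below applies softmax to four made-up similarity scores. The score values are hypothetical; subtracting the maximum before exponentiating is a standard numerical-stability trick and does not change the result.

```python
import math

def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

raw_scores = [2.0, 0.5, 0.1, -1.0]  # hypothetical scores for 4 encoder states
weights = softmax(raw_scores)

# The encoder position with the largest weight is the "most similar" one.
most_similar = weights.index(max(weights))
```

Each weight lies strictly between 0 and 1, and the ordering of the raw scores is preserved, so the largest score still gets the largest weight.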

 

 

 

 

As shown above, the index i in a_i is the Decoder's index i.

So these are the weights used to encode the attention for the hidden state s_i.

Then, as shown on the right, we build the context vector as the weighted sum of all of them.

That is, we take a linear combination of the four Encoder hidden states using the weights computed earlier.

These weights are based on similarity, and they have been scaled to lie between 0 and 1 and to sum to 1.

 

 

As a result, the most relevant value has the largest influence on the context vector.

And the least similar one gets a weight close to 0, so it has practically no effect on the context vector.

If we compute this context vector for every Decoder step and concatenate the results, we get a context matrix.
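The weighted sum and the resulting context matrix can be sketched as follows. The encoder states and attention weights here are made-up numbers, and `context_vector` is just an illustrative helper name.

```python
def context_vector(weights, encoder_states):
    """Weighted sum of encoder hidden states (one weight per state)."""
    dim = len(encoder_states[0])
    return [sum(w * h[k] for w, h in zip(weights, encoder_states)) for k in range(dim)]

# Four encoder hidden states, as in the running example.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]

# Hypothetical attention weights: one row per decoder step, each row sums to 1.
attention = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
]

# Stacking one context vector per decoder step gives the context matrix.
context_matrix = [context_vector(row, encoder_states) for row in attention]
```

In the first row the first encoder state dominates, while the fourth state (weight 0.1, and all zeros anyway) contributes almost nothing, which is the behavior described above.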

 

 

As in the formula on the left, this gives us the expression for making predictions with Attention.

The point is to feed in the c_t obtained from the computation above.

There are also learnable parameters here,

and if the structure above is built LSTM-style, the number of learnable parameters becomes much larger.

 

 

Above, we can see the Encoder-Decoder structure.

The similarity differs for each Decoder position, so the weight values differ as well.

 

 

In the end, we can tell which of the Encoder inputs has the greatest influence on a prediction,

and which factors the model relies on to make it.

Neural models have generally been hard to explain,

but with Attention, the weights give us a way to explain a prediction.

So Attention was introduced to make predictions more accurate, but along the way it turned out to be a great help for explanation as well.
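To make the explanation angle concrete, the sketch below reads an attention matrix and reports, for each output word, the input word that received the largest weight. The sentence pair and the weights are entirely hypothetical and only echo the bier/beer example; real weights would come from a trained model.

```python
# Hypothetical German source and English target tokens.
source = ["ich", "trinke", "gern", "bier"]
target = ["i", "like", "drinking", "beer"]

# Hypothetical attention weights: one row per target word, one column per source word.
attention = [
    [0.85, 0.05, 0.05, 0.05],
    [0.05, 0.15, 0.75, 0.05],
    [0.05, 0.80, 0.10, 0.05],
    [0.02, 0.03, 0.05, 0.90],
]

def most_attended(attention, source, target):
    """Map each target word to the source word with the largest weight."""
    result = {}
    for word, row in zip(target, attention):
        result[word] = source[row.index(max(row))]
    return result

alignment = most_attended(attention, source, target)
```

Looking only at the largest weight is a simplification; in practice the whole row is usually inspected, often as a heatmap.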

 

Finally, shown above are score functions that can be used to implement the similarity.
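The figure listing the score functions is not reproduced here, so the sketch below shows three commonly used choices (dot product, scaled dot product, and a bilinear "general" form with a learnable matrix). These are assumptions about what such a list typically contains, not necessarily the functions in the original figure.

```python
import math

def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def dot_score(s, h):
    """Similarity as the raw dot product of decoder and encoder states."""
    return dot(s, h)

def scaled_dot_score(s, h):
    """Dot product divided by sqrt(dimension) to keep magnitudes moderate."""
    return dot(s, h) / math.sqrt(len(s))

def general_score(s, h, W):
    """Bilinear form s . (W h), where W is a learnable matrix."""
    Wh = [dot(row, h) for row in W]
    return dot(s, Wh)

s = [1.0, 2.0]
h = [3.0, 4.0]
identity = [[1.0, 0.0], [0.0, 1.0]]  # placeholder for a trained W

examples = {
    "dot": dot_score(s, h),
    "scaled_dot": scaled_dot_score(s, h),
    "general": general_score(s, h, identity),
}
```

Whichever score function is used, the rest of the pipeline (softmax, then the weighted sum into the context vector) stays the same.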

 

 
