[Training Neural Networks] part 2
๐ง๐ป๐ป์ฉ์ด ์ ๋ฆฌ
Gradient Descent
MNIST handwritten digit classification
Back propagation
Computational graph
sigmoid
tanh
ReLU
Batch Normalization
Gradient Descent
- loss function์ ์ต์ํํ๋ Gradient ๊ฐ์ ์ฐพ๋๋ค.
- ๊ทธ๋ฌ๋ ๊ณง์ด ๊ณง๋๋ก gradient descent ๋ฅผ ์ฌ์ฉํ๋ค๋ฉด ๋นํจ์จ์ ์ผ๋ก ์ฌ์ฉ๋ ์ ์๋ค.
- ๊ทธ๋ฌ๋ฏ๋ก ์ฌ๋ฌ ๊ฐ์ง Gradient Descent Algorithm ๋ค์ด ์กด์ฌํฉ๋๋ค.
MNIST handwritten digit classification
-> mse loss ์ ์ฉ ์์
๋ชฉ์ : ํ์ต์ ํตํด ๊ฐ layer ๋ค์ ์กด์ฌํ๋ Parameter ๋ค.
- ํ์ต Data์ ๋ํด์ forward propagation์ ์ํ.
- parameter๋ค์ Random initialization์์๋ถํฐ ์์.
- ๋จผ์ , ์์ธก ๊ฐ๋ค์ ์ ๋ ฅ Data์ ground truth ๊ฐ๊ณผ ์์ดํ๋ค.
- ๊ณ์ํด์ ์ด ์ฐจ์ด๋ฅผ ์ค์ด๋ ๋ฐฉํฅ์ํธ ์์คํจ์๋ฅผ ์ ์ํ๋ค.
- ์์คํจ์๋ก๋ถํฐ ํธ๋ฏธ๋ถ ๊ฐ์ ๊ตฌํ๊ณ , ํธ๋ฏธ๋ถ ๊ฐ์ ํตํด parameter ๋ค์ updateํด์ผํฉ๋๋ค.
๊ฒฐ๊ตญ Loss function์ ๋ํ ํธ๋ฏธ๋ถ์ ๊ตฌํ๋ ๊ณผ์ ์ ๋ง์น neural network์์์ forward propagation์ ๋ฐ๋ ๋ฐฉํฅ์ผ๋ก layer ๋ณ๋ก ์์ฐจ์ ์ธ ๊ณ์ฐ์ ์ํํ๊ฒ ๋๋ back propagation์ด๋ผ๋ ๊ณผ์ ์ ํตํด Parameter ๋ค์ gradient ๊ฐ or loss function์ ํธ๋ฏธ๋ถ ๊ฐ์ ๊ณ์ฐํ ์๊ฐ ์์ต๋๋ค.
Neural Network ์ ์์ฐจ์ ์ธ ๊ณ์ฐ๊ณผ์ ์ ํ๋์
computational graph
๋ผ๋ ํํ๋ก ๋ํ๋ผ ์ ์์ต๋๋ค.
์ฒ์์ Parameter๊น์ง ํธ๋ฏธ๋ถ ๊ณผ์ ์ ๊ฑฐ์ณ ์ฌ๋ผ๊ฐ๊ฒ ๋๋ค๋ฉด,
์ฒ์์ parameter ๊ฐ์, learning rate ๊ฐ๊ณผ Gradient ๊ฐ์ ๋ํ ๊ฒ์ ๋นผ์ parameter ๊ฐ์ updateํฉ๋๋ค.
์ค๊ฐ ๊ณผ์ ์ค์ ๊ฒฐ๊ณผ๋ฌผ๋ก์ Gradient ๊ฐ์ด ๊ณ์ฐ ๋์ง๋ง ์ง์ ์ ์ธ gradient descent์๋ ์ฌ์ฉ๋์ง ์๋ ๊ทธ๋ฌํ ๋ ธ๋๋ค์ ๋๋ค.
ํ๋์ ํจ์์ ๋ํ ๊ฒฐ๊ณผ๋ฌผ๋ก ๋์จ๋ค๋ฉด, gradient descent ๊ณผ์ ์์ ํ๋์ ํจ์๋ฅผ ์ชผ๊ฐ์ด ํ๋์ฉ์ ๊ณ์ฐ๊ณผ์ ์ ํ๋ ๋์ ํ๋์ ํจ์๋ฅผ ํธ๋ฏธ๋ถํ์ฌ ๊ตฌํ ์ ์๊ฒ๋ฉ๋๋ค.
Sigmoid Activation
์ ๋ ฅ ์ ํธ๋ฅผ ์ ์ ํ ๊ฒฐํฉํด์ ๋ง๋ค์ด์ง ๊ฐ์ hard threshold๋ฅผ ์ ์ฉํด์ ์ต์ข ์ถ๋ ฅ ๊ฐ์ ๋ด์ด์คฌ๋ ๊ฒ์ ๋ถ๋๋ฌ์ด ํํ์ ํจ์๋ก ๊ทผ์ฌํํ ๊ฒ์ ๋๋ค.
- ์ ํ ๊ฒฐํฉ์ ๊ฒฐ๊ณผ๊ฐ ๋ง์ด๋์ค ๋ฌดํ๋๋ถํฐ ๋ฌดํ๋๊น์ง ๊ฐ์ง ์ ์๋ ์ด๋ฌํ ๊ฐ๋ค์ 0์์ 1์ฌ์ด์ ๊ฐ์ผ๋ก mapping ์์ผ์ค๋๋ค.
- ์ด ๊ฐ์ logistic regression์ ๊ฒฝ์ฐ positive class์ ๋์ํ๋ ์์ธก๋ ํ๋ฅ ๊ฐ์ผ๋ก ํด์ํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ด ์์ต๋๋ค.
- Back propagation์์ ํด๋น sigmoid function์ ํธ๋ฏธ๋ถ ์ 0์์ 1/4 ์ฌ์ด์ ๊ฐ์ ๊ฐ์ง๋ฏ๋ก, ์ด ๊ฐ์ด ๊ณฑํด์ ธ ์ ๋ ฅ ๊ฐ์ Gradient ๊ฐ ๊ฒฐ์ ๋๊ณ , ์ด๋ 1๋ณด๋ค ๋ง์ด ์์ ๊ฐ์ด๋ฏ๋ก sigmoid node๋ฅผ back propagation ํ ๋๋ง๋ค gradient ๊ฐ์ด ์ ์ฐจ ์์์ง๋ ์์์ด ๋ณด์ด๊ฒ ๋ฉ๋๋ค.
- ์ฌ๋ฌ Layer์ ์๋ Sigmoid node์ ๋ํด์ Back propagationํ ๋๋ง๋ค gradient ๊ฐ์ด ๊ณ์ํด์ ๊น์ฌ ๋๊ฐ๋๋ค.
- ๊ทธ๋์ Gradient๊ฐ 0์ ๊ฐ๊น์์ง๋ ๋ฌธ์ ์ ์ด ์ผ๊ธฐํ๊ฒ ๋ฉ๋๋ค.
- ๊ฒฐ๊ตญ ์ด๊ฒ์ผ๋ก Back Propagation ํ์ ๋, ์์ชฝ Layer์ ์๋ parameter๋ค์ ๋๋ฌํ๋ gradient ๊ฐ์ด ์ผ๋ฐ์ ์ผ๋ก ๊ต์ฅํ ์์ ๊ฐ์ ์ป๊ฒ ๋๊ณ learning rate๋ฅผ ์ฌ์ฉํ์ ๋, ์์ชฝ parameter๋ค์ Gradient ๊ฐ์ด ์์์ผ๋ก ์ธํด parameter๋ค์ update๊ฐ ๊ฑฐ์ ์ผ์ด๋์ง ์๊ฒ ๋๋ ํ์์ด ๋ฐ์ํฉ๋๋ค.
- ๊ทธ๋ฌ๋ฏ๋ก, Neural network์ ์ ์ฒด์ ์ธ ํ์ต์ด ๋๋ ค์ง๊ฒ ๋ฉ๋๋ค.
- ์ด๋ฅผ Gradient Vanishing์ด๋ผ๊ณ ํฉ๋๋ค.
- ์ด๋ฌํ ๋ฌธ์ ์ ์ ํด๊ฒฐํ๊ธฐ ์ํด์ ๋ค์๊ณผ ๊ฐ์ Activation function์ด ์ ์๋ฉ๋๋ค.
Tanh Activation
tanh๋ผ๋ ํจ์๋ฅผ sigmoid ํ์ฑ ํจ์ ๋์ ์
- ๋ง์ด๋์ค ๋ฌดํ๋๋ถํฐ ๋ฌดํ๋๊น์ง์ ๋ฒ์๋ฅผ ๋ง์ด๋์ค 1๋ถํฐ 1์ฌ์ด์ ๋ฒ์๋ก Mapping ์์ผ์ฃผ๊ฒ ๋ฉ๋๋ค.
- ํ์ต์ ์์ด ์๋๊ฐ ๋ ๋น ๋ฆ ๋๋ค.
- ์ด tanh activation function์ ํธ๋ฏธ๋ถํ๋ฉด 0์์ 1/2 ์ฌ์ด์ ๊ฐ์ด ๋์ค๋ฏ๋ก Back propagationํ ๋๋ง๋ค gradient ๊ฐ์ด ํจ์ฌ ๊น์ ๋๋ค.
- ๊ทธ๋์ Gradient vanishing ๋ฌธ์ ๋ฅผ ์ฌ์ ํ ๊ฐ์ง๊ฒ ๋ฉ๋๋ค.
- layer๋ค์ด ๋ง์ด ์์์ ๋, gradient vanishing ๋ฌธ์ ๋ฅผ ๊ทผ๋ณธ์ ์ผ๋ก ์ ํํ ์ ์๋ ๋ค์๊ณผ ๊ฐ์ Activation function์ด ์กด์ฌํฉ๋๋ค.
ReLU (Rectified Linear Unit)
- ์ ํ๊ฒฐํฉ์ผ๋ก ๋์จ ๊ฒฐ๊ณผ๊ฐ์ด ๋ง์ด๋์ค ๋ฌดํ๋๋ถํฐ ๋ฌดํ๋๊น์ง์ ๊ฐ์ ๊ฐ์ง ๋, 0๋ณด๋ค ์์ ๊ฒฝ์ฐ๋ 0์ผ๋ก, 0 ๋ณด๋ค ํฐ ๊ฒฝ์ฐ๋ ๊ทธ ๊ฐ์ ๊ณง์ด๊ณง๋๋ก ๋ด์ด์ฃผ๋ ํํ์ ๋๋ค.
- ReLU ํจ์๋ sigmoid๋ tanh function์ ๋นํด ๋ ๋น ๋ฅด๊ฒ ๊ณ์ฐ๋ ์ ์๋ค๋ ์ฅ์ ์ด ์์ต๋๋ค.
- ์ ํ๊ฒฐํฉ์ผ๋ก ๋์จ ๊ฐ์ด ์์๋ผ๋ฉด gradient๊ฐ 0์ด ๋๋ค๋ ๋จ์ ์ด ์์ต๋๋ค.
- ๊ทธ๋ฌ๋ gradient vanishing ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์๋ ์ธก๋ฉด์์ ์ฅ์ ์ด ์์ต๋๋ค.
- ๊ทธ๋ฌ๋ฏ๋ก, Lyaer๊ฐ ๋ง์ด ์์ฌ์ ธ ์์ ๋ Neural Network์ ํ์ต์ ํจ์ฌ ๋ ๋น ๋ฅด๊ฒ ๋ง๋ค์ด์ค ์ ์์ต๋๋ค.
Batch Normalization
- ํ์ต์ ์ต์ ํ๋ฅผ ๋๋ layer
- Tanh, sigmoid ๋ฑ ์ค์ ์ ๋์ ๊ฐ์ด ์๋ Satureted ๋ ๊ฐ์ ์์ด์ gradient ๊ฐ์ด 0์ ์๋ ดํ๊ฒ ๋๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํฉ๋๋ค.
- ๊ทธ๋ ๋ค๋ฉด Back Propagation ๊ณผ์ ์์ ํ์ต ์๊ทธ๋์ ์ค ์ ์๋ ์ํฉ์ด ๋ ์๊ฐ ์์ต๋๋ค.
- ๊ทน๋จ์ ์ธ ๊ฒฝ์ฐ ์ ํ๊ฒฐํฉ์ ๊ฒฐ๊ณผ๊ฐ ๋ชจ๋ ์์ ๊ฐ์ด ๋์ฌ ๊ฒฝ์ฐ ReLU function์์๋ gradient๊ฐ ๋ฌด์กฐ๊ฑด 0์ด ๋๊ธฐ ๋๋ฌธ์ ํด๋น node๊ฐ ํ์ต์ด ์ ํ ์งํ๋์ง ์๊ฒ ๋๋ ๋ฌธ์ ์ ์ด ์กด์ฌํฉ๋๋ค.
- forward propagation ๊ณผ์ ์์ ์ ํ๊ฒฐํฉ์ ๊ฐ์ด 0์ฃผ๋ณ์ ๊ฐ์ผ๋ก ์ด๋ฃจ์ด์ง๋ค๋ฉด ํ์ต์ ์ฉ์ดํ ๊ฒ์ ๋๋ค.
- ์ด๊ฒ์ด batch normalization์ ๊ธฐ๋ณธ์ ์ธ ์์ด๋์ด๊ฐ ๋ฉ๋๋ค.
- ์ด๋ค Mini batch๊ฐ ์ฃผ์ด์ ธ data๊ฐ, mini batch size๊ฐ 10์ด๋ผ๊ณ ํ๋ฉด, ํด๋น ๊ฐ๋ค์ด tanh์ ์
๋ ฅ์ผ๋ก ์ฃผ์ด์ง๊ธฐ ์ ์,
ํ๊ท ์ด 0, ๋ถ์ฐ์ด 1์ธ ๊ฐ์ Normalizationํ๋ค๋ฉด tanh์ ์ ๋ ฅ์ผ๋ก ์ฃผ์ด์ง๋ ๊ฐ์ ๋๋ต์ ์ธ ๋ฒ์๋ฅผ 0์ ์ค์ฌ์ผ๋ก ํ๋ ๊ทธ๋ฌํ ๋ถํฌ๋ก ๋ง๋ค ์ ์์ ๊ฒ์ ๋๋ค. - ๊ทธ ๊ฐ๋ค์ ๋ณํ ํญ์ด ๋๋ฌด ์์ผ๋ฉด tanh์ Output ๋ํ ๋ณํ ์์ฒด๊ฐ ๋๋ฌด ์์์ data item๋ค ๊ฐ์ ์ฐจ์ด๋ฅผ Neural network๊ฐ ๊ตฌ๋ถํ๊ธฐ ์ด๋ ค์ธ ๊ฒ์ ๋๋ค.
- ๊ทธ ๊ฐ๋ค์ ๋ณํ ํญ์ด ๋๋ฌด ํฌ๋ฉด ์ ์ ํ ์์ญ์ ๊ฐ์ผ๋ก ์ปจํธ๋กค์ ์ด๋ ค์์ด ์๊ฒ ๋ฉ๋๋ค.
- ์ด ๊ณผ์ ์ fully-connected layer ํน์ ์ ํ ๊ฒฐํฉ์ ์ํํ ์ดํ์ ํ์ฑ ํจ์๋ก ๋ค์ด๊ฐ๊ธฐ ์ง์ ์ ์ํ๋ฉ๋๋ค.
๊ทธ๋ฌ๋ data์ ์๋ฏธ๊ฐ ์ค์ํ ๊ฒฝ์ฐ์๋ ์์ ๊ฐ์ BN(Batch Normalization) ๊ณผ์ ์ ์ฐ๋ฆฌ์ data๋ฅผ ์์ด๋ฒ๋ฆฌ๊ฒ ๋ง๋๋ ๊ณผ์ ์ด๋ผ๊ณ ๋ณผ ์ ์์ต๋๋ค.
์ด๋ฌํ ์์ด๋ฒ๋ฆฐ ์ ๋ณด๋ฅผ ์ํ๋ ์ ๋ณด๋ก ๋ณต์ํ ์ ์๊ฒ๋ํ๋ ๊ทธ๋ฐ ๋ ๋ฒ์งธ ๋จ๊ณ๊ฐ BN์ ๋ ๋ฒ์งธ ๋จ๊ณ๊ฐ ๋ฉ๋๋ค.
๊ทธ๋ฌ๋ฏ๋ก BN ๊ณผ์ ์ดํ์๋ ํด๋น Data๊ฐ ๊ฐ์ง๋ ๊ณ ์ ์ ํ๊ท ๊ณผ ๋ถ์ฐ ๋ํ ๋ณต์ํด๋ผ ์ ์๋ ๋ฅ๋ ฅ์ ๋ถ์ฌํด ์ฃผ๊ฒ ๋ ๊ฒ์ ๋๋ค.
์ถ๊ฐ๋ก Gradient Vanishing ๋ฌธ์ ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ํด๊ฒฐํ ์ ์๋ ์ด๋ฐ ์ข์ ์ฅ์น๊ฐ ๋๋ ๊ฒ์ ๋๋ค.
'Artificial Intelligence > Deep Learning' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[XAI] ์ค๋ช ๊ฐ๋ฅํ AI (Explainable AI) (0) | 2023.01.26 |
---|---|
[Convolutional Neural Networks and Image Classification] Part 3 (0) | 2023.01.24 |
[Deep Neural Network] part 1 - 2 (0) | 2023.01.22 |
[Deep Neural Network] part 1 - 1 (0) | 2023.01.22 |
[Machine Learning] ์ ๊ฒฝ๋ง ๊ธฐ์ด 1 (0) | 2023.01.17 |