[Supervised Learning] ์ง๋ ํ์ต
2023. 2. 8. 17:55
๐ง๐ป๐ป์ฉ์ด ์ ๋ฆฌ
Supervised Learning
์ง๋ ํ์ต
- ํน์ง : ์ด๋ฏธ ์ธก์ ๋ ๋ฐ์ดํฐ์ ์์ฑ ์ฌ์ด์ ๊ด๊ณ
- "์ด๋ฏธ ์ธก์ ๋"
- ์ฌ๋ฌ ์ซ์๋ฅผ ๋ชจ์ ํ์ผ๋ก ๋ํ๋ด์ด ndarray ํํ๋ก ๋ฐ์ดํฐ์ ๋ฐฐ์น batch๋ฅผ ๋ง๋ค์ด ๋ชจ๋ธ์ ์ ๋ ฅ.
- ๊ฐ ๊ด์ฐฐ์ ๋ํ ์์ธก ๊ฒฐ๊ณผ๊ฐ ๋ด๊ธด ํ์ผ๋ก ๊ตฌ์ฑ๋ ndarray ๊ฐ์ฒด๋ฅผ ๋ฐํ.
- ํ์ ๊ธธ์ด๋ ๋ฐ์ดํฐ์ ํน์ง, ์ฆ feature์ ๊ฐ์์ด๋ค.
- feature๋ category์ ๋ฐ๋ผ ๋๋์ด ๋ถ๋ฅํ๋ค.
- ์ด๋ค ๋ฌธ์ ๋ฅผ ํด๊ฒฐํด์ผ ํ๋๊ฐ - target
์ง๋ ํ์ต ๋ชจ๋ธ
- ์ง๋ ํ์ต์ ๋ชฉ์ ์ ndarray ๊ฐ์ฒด๋ฅผ ์ ๋ ฅ๋ฐ์ ๋ ๋ค๋ฅธ ndarray ๊ฐ์ฒด๋ฅผ ์ถ๋ ฅํ๋ ํํ๋ฅผ ๊ฐ์ง ํจ์๋ฅผ ์ฐพ๋ ๊ฒ์ด๋ค.
- ์ด ํจ์๋ ์ฐ๋ฆฌ๊ฐ ์ ์ํ ํน์ง๊ฐ์ ๋ด์ ndarray๊ฐ์ฒด๋ฅผ ์ ๋ ฅ ๋ฐ๊ณ , ๋ชฉํ ์์ฑ์ ํน์ง๊ณผ ๊ฐ๊น์ด ๊ฐ์ ๋ด์ ndarray ๊ฐ์ฒด๋ฅผ ๋ฐํํ๋ ํํ๋ก ๊ด์ฐฐ์ ์์ฑ๊ฐ์ ๋ชฉํ ์์ฑ๊ณผ ๋์์ํฌ ์ ์์ด์ผ ํ๋ค.
์ ํํ๊ท (linear regression)
linear combination๊ณผ ์์ธก๊ฐ์ ๊ธฐ์ค์ ์ ์กฐ์ ํ๋ ํญ์ผ๋ก ๋ชฉํ์พ์ ๊ณ์ฐํ ์ ์๋ค๋ ์์ด๋์ด๋ฅผ ๋ํ๋ธ ๊ฒ.
X์ W์ ํ๋ ฌ๊ณฑ ์ฐ์ฐ.
์์ ๊ฐ์ด ์ฌ๋ฌ ๊ด์ฐฐ์ด ๋ด๊ธด ๋ฐฐ์น์ ๋ํ ์์ธก๊ฐ๋ ๋ง์ฐฌ๊ฐ์ง๋ก ํ๋ ฌ๊ณฑ ์ฐ์ฐ ํ ๋ฒ์ผ๋ก ๊ณ์ฐํ ์ ์๋ค.
์ด๋ ๊ฒ ํ๋ ฌ๊ณฑ ์ฐ์ฐ์ ํตํด ์ฌ๋ฌ ๊ด์ธก์ ๋ํด ์ ํํ๊ท ๋ชจํ์ ๋ฐ๋ฅด๋ ์์ธก๊ฐ์ ํ ๋ฒ์ ๊ณ์ฐํ ์ ์๋ค.
์ ํํ๊ท ๋ชจ๋ธ ํ์ตํ๊ธฐ
- ๋ชจ๋ธ์ ํ์ตํ๋ค๋ ๊ฑด, ์ถ์์ ์ผ๋ก ๋ณด๋ฉด ๋ชจ๋ธ์ ๋ฐ์ดํฐ๋ฅผ ์ ๋ ฅ๋ฐ์ ํ๋ผ๋ฏธํฐ parameter๋ฅผ ์ด์ฉํด ์ด ๋ฐ์ดํฐ๋ฅผ ๋ชจ์ข ์ ๋ฐฉ๋ฒ์ผ๋ก ์ทจํฉํ์ฌ ์์ธก๊ฐ์ ๊ณ์ฐํ๋ค.
- ์ ํํ๊ท ๋ชจํ์ ๋ฐ์ดํฐ X์ ํ๋ผ๋ฏธํฐ W๋ฅผ ์ ๋ ฅ๋ฐ์ ํ๋ ฌ๊ณฑ ์ฐ์ฐ์ ํตํด ์์ธก๊ฐ P๋ฅผ ๊ณ์ฐํ๋ค.
- ๋ชจ๋ธ ํ์ต์ ์ํด ๋ชจ๋ธ์ด ๊ณ์ผํ ์์ธก๊ฐ์ด ์ ํํ์ง์ ๋ํ ์ ๋ณด๊ฐ ๋ ํ์ํ๋ค.
- ์ด๋ฅผ ์ํด, Xbatch์ ๊ฐ ๊ด์ฐฐ์ ๋์ํ๋ ์ณ์ ์์ธก๊ฐ์ ๋ชจ์ ์๋ก์ด ๋ฒกํฐ Ybatch๋ฅผ ๋์ ํ๋ค.
- ๊ทธ๋ฆฌ๊ณ Ybatch์ Pbatch์ ์ํด ๊ฒฐ์ ๋๋ ๋ ๋ค๋ฅธ ๊ฐ์ ๊ณ์ฐํ๋ค.
- ์ด ๊ฐ์ ๋ชจ๋ธ์ ์์ธก์ด ์ผ๋ง๋ ์ ํํ๋์ง๋ฅผ ๋ฐ์ ธ ๊ทธ ์ ํ๋์ ๋ฐ๋ผ ๋ชจ๋ธ์ ๋ถ์ฌ๋๋ '๋ฒ์ '์ญํ ์ ํ๋ค.
- ์ด ๋ฒ์ ์ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ผ๋ก MSE (Mean Squared Error, ํ๊ท ์ ๊ณฑ์ค์ฐจ)๊ฐ ์ ๋นํ๋ค.
- ํ๊ท ์ ๊ณฑ์ค์ฐจ๋ ๊ฐ ๊ด์ฐฐ์ ์์ธก๊ฐ๊ณผ ์ ๋ต์ ์ค์ฐจ์ ์ ๊ณฑ์ ํ๊ท ๋ธ ๊ฐ์ด๋ค.
- ์ด ๊ฐ์ ํตํด W์ ๊ฐ ์์์ ๋ํ ์ด ๊ฐ์ ๊ธฐ์ธ๊ธฐ๋ฅผ ์ ์ ์๋ค.
์ ๊ณผ์ ์ ํธํฅ b๊น์ง ๋ํด์ฃผ๋ฉด ์์ ๊ฐ์ ์์์ด ์์ฑ๋๋ค.
์ ๊ณผ์ ์ ์ฝ๋๋ก ์ดํด๋ณด์.
def forward_linear_regression(X_batch: ndarray,
y_batch: ndarray,
weights: Dict[str, ndarray]
)-> Tuple[float, Dict[str, ndarray]]:
'''
์ ํํ๊ท์ ์๋ฐฉํฅ ๊ณ์ฐ ๊ณผ์
'''
# X์ y์ ๋ฐฐ์น ํฌ๊ธฐ๊ฐ ๊ฐ์์ง ํ์ธ
assert X_batch.shape[0] == y_batch.shape[0]
# ํ๋ ฌ๊ณฑ ๊ณ์ฐ์ด ๊ฐ๋ฅํ์ง ํ์ธ
assert X_batch.shape[1] == weights['W'].shape[0]
# B์ ๋ชจ์์ด 1x1์ธ์ง ํ์ธ
assert weights['B'].shape[0] == weights['B'].shape[1] == 1
# ์๋ฐฉํฅ ๊ณ์ฐ ์ํ
N = np.dot(X_batch, weights['W'])
P = N + weights['B']
loss = np.mean(np.power(y_batch - P, 2))
# ์๋ฐฉํฅ ๊ณ์ฐ ๊ณผ์ ์ ์ค๊ฐ๊ฐ ์ ์ฅ
forward_info: Dict[str, ndarray] = {}
forward_info['X'] = X_batch
forward_info['N'] = N
forward_info['P'] = P
forward_info['y'] = y_batch
return loss, forward_info
'Artificial Intelligence > Deep Learning' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[Deep Learning] - MLP(Multilayer Perceptron) (0) | 2023.03.26 |
---|---|
[Deep Learning] - Neural Networks (0) | 2023.03.26 |
[Machine Learning] ์ ๊ฒฝ๋ง ๊ธฐ์ด 3 (0) | 2023.02.08 |
[Machine Learning] ์ ๊ฒฝ๋ง ๊ธฐ์ด 2 (2) | 2023.02.06 |
[XAI] ์ค๋ช ๊ฐ๋ฅํ AI (Explainable AI) (0) | 2023.01.26 |