[Machine Learning] Neural Network Basics 2

2023. 2. 6. 02:10
🧑🏻‍💻 Terminology Notes

Chain Rule

The chain rule gives the derivative of a composite function f2(f1(x)) and is written as follows.

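Formula

$$\frac{d}{dx}\, f_2\big(f_1(x)\big) = \frac{df_2}{du}\big(f_1(x)\big) \times \frac{df_1}{dx}(x)$$

where $u = f_1(x)$ is the output of the inner function: differentiate the outer function at the inner function's output, then multiply by the derivative of the inner function.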
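As a worked example, the rule can be applied by hand to the two chains plotted later in the post, `[square, sigmoid]` (f1 = square, f2 = sigmoid) and `[sigmoid, square]` (f1 = sigmoid, f2 = square). The derivation uses the standard identity for the sigmoid derivative; the closed forms are what the numerical code should approximate.

```latex
\begin{align*}
\sigma'(u) &= \sigma(u)\bigl(1 - \sigma(u)\bigr) \\
\frac{d}{dx}\,\sigma\bigl(x^2\bigr)
  &= \underbrace{\sigma\bigl(x^2\bigr)\Bigl(1 - \sigma\bigl(x^2\bigr)\Bigr)}_{f_2'(f_1(x))}
     \cdot \underbrace{2x}_{f_1'(x)} \\
\frac{d}{dx}\,\sigma(x)^2
  &= \underbrace{2\,\sigma(x)}_{f_2'(f_1(x))}
     \cdot \underbrace{\sigma(x)\bigl(1 - \sigma(x)\bigr)}_{f_1'(x)}
\end{align*}
```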

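Code

import numpy as np
import matplotlib.pyplot as plt
from numpy import ndarray
from typing import Callable, List

# Type aliases: a function mapping ndarray -> ndarray, and a list (chain) of such functions
Array_Function = Callable[[ndarray], ndarray]
Chain = List[Array_Function]

# Minimal helper definitions so the code runs on its own
# (the post uses square, deriv and chain_length_2 without defining them)
def square(x: ndarray) -> ndarray:
    '''
    Squares each element of the input ndarray.
    '''
    return np.power(x, 2)

def deriv(func: Array_Function,
          input_: ndarray,
          delta: float = 0.001) -> ndarray:
    '''
    Approximates the derivative of func at each element of input_
    with a central difference.
    '''
    return (func(input_ + delta) - func(input_ - delta)) / (2 * delta)

def chain_length_2(chain: Chain,
                   x: ndarray) -> ndarray:
    '''
    Evaluates two functions in sequence: f2(f1(x)).
    '''
    assert len(chain) == 2, \
    "the chain argument must have length 2"
    f1 = chain[0]
    f2 = chain[1]
    return f2(f1(x))

def sigmoid(x: ndarray) -> ndarray:
    '''
    Computes the sigmoid function for each element of the input ndarray.
    '''
    return 1 / (1 + np.exp(-x))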

def chain_deriv_2(chain: Chain,
                  input_range: ndarray) -> ndarray:
    '''
    Uses the chain rule to compute the derivative of a composite
    function made of two functions:
    (f2(f1(x)))' = f2'(f1(x)) * f1'(x)
    '''

    assert len(chain) == 2, \
    "the chain argument must have length 2"

    assert input_range.ndim == 1, \
    "input_range must be a 1-D ndarray"

    f1 = chain[0]
    f2 = chain[1]

    # f1(x)
    f1_of_x = f1(input_range)

    # df1/dx
    df1dx = deriv(f1, input_range)

    # df2/du, evaluated at u = f1(x)
    df2du = deriv(f2, f1_of_x)

    # Multiply the two derivatives point by point
    return df1dx * df2du
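Because `deriv` only estimates each factor numerically, one way to sanity-check `chain_deriv_2` is to compare it against the closed-form derivatives of the two chains used later in the post. The sketch below is self-contained; `square`, `deriv` (a central difference with `delta=0.001`) and the type aliases are minimal stand-ins assumed for illustration, and `chain_deriv_2` is a condensed copy of the function above.

```python
import numpy as np
from numpy import ndarray
from typing import Callable, List

# Assumed type aliases, mirroring the names used in the post
Array_Function = Callable[[ndarray], ndarray]
Chain = List[Array_Function]


def square(x: ndarray) -> ndarray:
    # Elementwise x ** 2
    return np.power(x, 2)


def sigmoid(x: ndarray) -> ndarray:
    # Elementwise logistic sigmoid
    return 1 / (1 + np.exp(-x))


def deriv(func: Array_Function, input_: ndarray, delta: float = 0.001) -> ndarray:
    # Central-difference estimate of func'(input_) (assumed helper)
    return (func(input_ + delta) - func(input_ - delta)) / (2 * delta)


def chain_deriv_2(chain: Chain, input_range: ndarray) -> ndarray:
    # Condensed chain_deriv_2: f2'(f1(x)) * f1'(x), both factors estimated numerically
    f1, f2 = chain
    return deriv(f2, f1(input_range)) * deriv(f1, input_range)


x = np.arange(-3, 3, 0.01)

# Closed forms, using sigmoid'(u) = sigmoid(u) * (1 - sigmoid(u))
s_sq = sigmoid(square(x))
exact_1 = s_sq * (1 - s_sq) * 2 * x   # d/dx sigmoid(x^2)
s = sigmoid(x)
exact_2 = 2 * s * s * (1 - s)         # d/dx sigmoid(x)^2

err_1 = np.max(np.abs(chain_deriv_2([square, sigmoid], x) - exact_1))
err_2 = np.max(np.abs(chain_deriv_2([sigmoid, square], x) - exact_2))
print(f"max |error|, sigmoid(square(x)): {err_1:.1e}")
print(f"max |error|, square(sigmoid(x)): {err_2:.1e}")
```

Both errors should be far smaller than the derivative values themselves, since the truncation error of a central difference shrinks with the square of `delta`.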

 

def plot_chain(ax,
               chain: Chain, 
               input_range: ndarray) -> None:
    '''
    Plots the composite function made of two ndarray -> ndarray
    mappings over the interval input_range.
    
    ax: the matplotlib subplot to draw on
    '''
    
    assert input_range.ndim == 1, \
    "input_range must be a 1-D ndarray"

    output_range = chain_length_2(chain, input_range)
    ax.plot(input_range, output_range)

def plot_chain_deriv(ax,
                     chain: Chain,
                     input_range: ndarray) -> None:
    '''
    Computes the derivative of the composite function with the chain rule
    and plots it.
    
    ax: the matplotlib subplot to draw on
    '''
    output_range = chain_deriv_2(chain, input_range)
    ax.plot(input_range, output_range)
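`chain_length_2` and `chain_deriv_2` handle exactly two functions. Applying the chain rule repeatedly extends the same idea to a chain of any length: the derivative is the product of each function's derivative, evaluated at that function's own input. A minimal sketch follows; `chain_apply` and `chain_deriv_n` are hypothetical names introduced here, and `deriv` is the same assumed central-difference helper.

```python
import numpy as np
from numpy import ndarray
from typing import Callable, List

Array_Function = Callable[[ndarray], ndarray]
Chain = List[Array_Function]


def deriv(func: Array_Function, input_: ndarray, delta: float = 0.001) -> ndarray:
    # Central-difference estimate of func'(input_) (assumed helper)
    return (func(input_ + delta) - func(input_ - delta)) / (2 * delta)


def square(x: ndarray) -> ndarray:
    return np.power(x, 2)


def sigmoid(x: ndarray) -> ndarray:
    return 1 / (1 + np.exp(-x))


def chain_apply(chain: Chain, x: ndarray) -> ndarray:
    # Hypothetical helper: computes f_n(...f_2(f_1(x))...)
    for f in chain:
        x = f(x)
    return x


def chain_deriv_n(chain: Chain, x: ndarray) -> ndarray:
    # Hypothetical helper: repeated chain rule.
    # With u_1 = x and u_{k+1} = f_k(u_k), the derivative is the product of f_k'(u_k).
    assert len(chain) >= 1, "chain must contain at least one function"
    result = np.ones_like(x, dtype=float)
    u = x
    for f in chain:
        result = result * deriv(f, u)  # multiply in f_k'(u_k)
        u = f(u)                       # the output becomes the next function's input
    return result


x = np.arange(-3, 3, 0.01)
chain_3 = [sigmoid, square, np.sin]   # sin(sigmoid(x) ** 2)

numeric = chain_deriv_n(chain_3, x)
s = sigmoid(x)
exact = np.cos(s ** 2) * 2 * s * s * (1 - s)  # sin'(u_3) * square'(u_2) * sigmoid'(u_1)
print(f"max |error|: {np.max(np.abs(numeric - exact)):.1e}")
```

For a two-element chain, `chain_deriv_n` computes the same product as `chain_deriv_2`, only with the factors multiplied in the opposite order.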

 

 

fig, ax = plt.subplots(1, 2, sharey=True, figsize=(16, 8))  # 1 row, 2 columns

chain_1 = [square, sigmoid]  # f(x) = sigmoid(square(x))
chain_2 = [sigmoid, square]  # f(x) = square(sigmoid(x))

PLOT_RANGE = np.arange(-3, 3, 0.01)
plot_chain(ax[0], chain_1, PLOT_RANGE)
plot_chain_deriv(ax[0], chain_1, PLOT_RANGE)

ax[0].legend(["$f(x)$", "$\\frac{df}{dx}$"])
ax[0].set_title("Function and derivative of $f(x) = sigmoid(square(x))$")

plot_chain(ax[1], chain_2, PLOT_RANGE)
plot_chain_deriv(ax[1], chain_2, PLOT_RANGE)
ax[1].legend(["$f(x)$", "$\\frac{df}{dx}$"])
ax[1].set_title("Function and derivative of $f(x) = square(sigmoid(x))$")

plt.show()  # display the figure when running as a script

 

 

 

Running the code above draws each composite function together with its derivative: sigmoid(square(x)) on the left and square(sigmoid(x)) on the right.
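The closed-form derivatives explain the shapes in the two panels. For sigmoid(square(x)), the factor 2x makes the derivative negative for x < 0, zero at x = 0 and positive for x > 0. For square(sigmoid(x)), the derivative 2σ(x)²(1 − σ(x)) is positive everywhere, so that curve only rises. The sketch below checks these signs numerically; it repeats minimal stand-in versions of the helpers (assumed, not necessarily the post's exact definitions) so it runs on its own.

```python
import numpy as np
from numpy import ndarray
from typing import Callable, List

Chain = List[Callable[[ndarray], ndarray]]


def square(x: ndarray) -> ndarray:
    return np.power(x, 2)


def sigmoid(x: ndarray) -> ndarray:
    return 1 / (1 + np.exp(-x))


def deriv(func, input_: ndarray, delta: float = 0.001) -> ndarray:
    # Central-difference estimate of func'(input_) (assumed helper)
    return (func(input_ + delta) - func(input_ - delta)) / (2 * delta)


def chain_deriv_2(chain: Chain, input_range: ndarray) -> ndarray:
    # f2'(f1(x)) * f1'(x)
    f1, f2 = chain
    return deriv(f2, f1(input_range)) * deriv(f1, input_range)


x = np.arange(-3, 3, 0.01)
d1 = chain_deriv_2([square, sigmoid], x)  # left panel: sigmoid(square(x))
d2 = chain_deriv_2([sigmoid, square], x)  # right panel: square(sigmoid(x))

# Left panel: the 2x factor makes the slope change sign at x = 0
print("d1 < 0 for x < 0:", bool(np.all(d1[x < -0.01] < 0)))
print("d1 > 0 for x > 0:", bool(np.all(d1[x > 0.01] > 0)))
print("d1 at x = 0:", chain_deriv_2([square, sigmoid], np.array([0.0]))[0])
# Right panel: 2 * sigmoid(x)^2 * (1 - sigmoid(x)) > 0, so the curve only rises
print("d2 > 0 everywhere:", bool(np.all(d2 > 0)))
```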
