Elsaの技術日記(徒然なるままに)

主に自分で作ったアプリとかの報告・日記を記載

MENU

勾配について理解してみる

機械学習を勉強してきて、
”勾配って何だろう??”
と分からなくなってきました。。。

勾配は関数の傾きであり偏微分することで計算可能であることは分かります。
分かりますが、具体的にどんなもの??
と疑問に思ってしまいました。。

っということで、勾配を可視化してみることにしました!!


■勾配を確認するにあたって用いる関数

今回用いる関数はそれぞれ、
f:id:Elsammit:20201110231352p:plain
f:id:Elsammit:20201110231517p:plain

で確認してみたいと思います!!

■勾配グラフ化ソースコード

pythonで勾配を可視化していきます。
グラフの範囲はx、yともに-5 ~ 5の範囲とします。

勾配のグラフを作成するにはmatplotlibを用いればOKです。
それぞれコードはこんな感じです。
f:id:Elsammit:20201110231352p:plain
の場合

import numpy as np
import matplotlib.pyplot as plt

x1 = np.arange(-5,5,0.5)
y1 = np.arange(-5,5,0.5)

def f(x,y):
    return y*x**2 + y**3

def diffX(a,b):
    return 2*a*b

def diffY(a,b):
    return 3*b**2

xx,yy = np.meshgrid(x1, y1)
diffX = diffX(xx,yy)
diffY = diffY(xx,yy)

plt.quiver(xx, yy, diffX, diffY)
plt.show()

f:id:Elsammit:20201110231517p:plain
の場合

import numpy as np
import matplotlib.pyplot as plt

x1 = np.arange(-5,5,0.5)
y1 = np.arange(-5,5,0.5)

def f(x,y):
    return (x**2)/20 + 9*y**2

def diffX(a,b):
    return a/10

def diffY(a,b):
    return 18*b

xx,yy = np.meshgrid(x1, y1)

diffX = diffX(xx,yy)
diffY = diffY(xx,yy)

plt.quiver(xx, yy, diffX, diffY)
plt.show()

■勾配グラフ結果

勾配をグラフ化した結果はこんな感じになります。
f:id:Elsammit:20201110231352p:plain
f:id:Elsammit:20201110232554p:plain
f:id:Elsammit:20201110231517p:plain
f:id:Elsammit:20201110232641p:plain

■確認関数のグラフ化

勾配を確認したグラフについて3次元で表示してみます。
それぞれこんな感じになります。
f:id:Elsammit:20201110231352p:plain
f:id:Elsammit:20201110233029p:plain

f:id:Elsammit:20201110231517p:plain
f:id:Elsammit:20201110233049p:plain

確かに勾配(傾き)の分布とグラフの形状が一致しているのがわかるかと思います。


3次元のグラフですが、こちらのコードを利用しました。
※今回
f:id:Elsammit:20201110231352p:plain
のみ記載。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt_2D
from mpl_toolkits.mplot3d import Axes3D

x1 = np.arange(-5,5,0.5)
y1 = np.arange(-5,5,0.5)

def f(x,y):
    return y*x**2 + y**3

def diffX(a,b):
    return 2*a*b

def diffY(a,b):
    return 3*b**2

xx,yy = np.meshgrid(x1, y1)

diffX = diffX(xx,yy)
diffY = diffY(xx,yy)

z = f(xx,yy)

fig = plt.figure()
ax = Axes3D(fig)
ax.plot_surface(xx, yy, z)

plt.show()

■最後に

今回は勾配についてグラフ化してみました。
誤差逆伝搬法について分からない部分あるので、実際に手を動かして学んでいきたいと思います!!


■参考
https://rightcode.co.jp/blog/information-technology/machine-learning-gradient-differential-understand
https://qiita.com/kazukiii/items/8a20fd38d08657c4a36d