003.01 梯度下降

003.01 梯度下降

建檔日期: 2019/09/17
更新日期: None

相关软件信息:

Win 10 Python 3.7.2 matplotlib 3.2.1

说明:所有内容欢迎引用,只需注明来源及作者,本文内容如有错误或用词不当,敬请指正.

主题: 003.01 梯度下降

  • 如何找到y=f(x)曲线的最低点 ??

A. 地毯式搜索的方法

  1. 格点搜索

    给出一串固定距离的X值, 计算Y值, 找出在X0处, 有最小值的Y0, 得到最低点(X0, Y0)

  2. 随机搜索

    随机产生一串X值, 计算Y值, 找出在X0处, 有最小值的Y0, 得到最低点(X0, Y0)


    搜索式方法最大的缺点, 就是不能保证找到全局的最低点, 甚至是局部的最低点,

B. 解析的方法

基本的原则是斜率为0的点, 就是该局部的最低点或最高点. 可以使用导数算出某点的斜率. d表示小变化, 导数dy/dx就是指x小变化所造成的y小变化, 这就是斜率的定义. 而dy/dx再次导数, 以d2y/dx2表示, 代表dy/dx的变化, 也就是斜率的变化. 如果斜率的变化为正, 表示该点是局部最低点, 反之为局部最高点.

C. 数值分析的方法

本方法采用沿着曲线, 一点一点地寻找到我们的目标点. 最常用的方法就是梯度下降. 大部份的作法都是采用该方法或其修改版. 其方法说明如下:

  1. 随意取一个起点x
  2. 计算出y值
  3. 新的x点为x-(alpha*dy/dx)
    dy/dx代表曲线的变化, 变化越大, 离最低点越远, 变化越小进最低点越近, 因此以dy/dx作为x变化的比例, 再给个常数alpha, 控制一下变化的比例, alpha就称为学习速率, 用来控制步进的大小, 如果太小就必须花更久的时间找到最低点, 如果太大可能会错过我们要找的最低点
  4. 重复步骤3, 直到y值收敛, 也就是说y值不会再因x的变化而变化, 因为dy/dx=0

记得, 我们的目标是要找出全局的最低点, 这才是整个曲线的最低点. 所以光使用这个方法, 还不足达到我们追求的目标.

D. 数值分析的范例:

输出:

003.01 梯度下降


import matplotlib.pyplot as plt

import numpy as np

def function(x):
    y = x*x + 5*x + 8
    return y

def slope(x):
    s = 2*x + 5
    return s

def gradient_descent(x, alpha, num_iterations):
    for i in range(num_iterations):
        y = function(x)
        y_history[i,:] = [x, y]
        update = slope(x)
        x = x - alpha * update

num_iteration =20
start_x = -7
start_y = function(start_x)
alpha = 0.3

y_history = np.zeros((num_iteration, 2))
result = gradient_descent(start_x, alpha, num_iteration)

y_min = np.min(y_history[:,1])
x_min = y_history[np.argmin(y_history[:,1]),0]

x_range = np.linspace(-10,5,100)

plt.plot(x_range, function(x_range), linestyle='-', color='blue', linewidth=2)
plt.axis([-10, 5, 0, 58])
plt.xlabel('x')
plt.ylabel('y=x^2+5x+8')
plt.annotate('Start Point ({:.2f},{:.2f})'.format(start_x, start_y), xy=(start_x, start_y), xytext=(-6, 50), arrowprops=dict(facecolor='black', shrink=0.05),)
plt.annotate('Local Min.({:.2f},{:.2f})'.format(x_min, y_min), xy=(x_min, y_min), xytext=(-2, 40), arrowprops=dict(facecolor='black', shrink=0.05),)
plt.scatter(y_history[:,0], y_history[:,1], s=100 ,color="red")
plt.show()
本作品采用《CC 协议》,转载必须注明作者和本文链接
Jason Yang
讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

讨论应以学习和精进为目的。请勿发布不友善或者负能量的内容,与人为善,比聪明更重要!