Стохастический градиентный спуск в Python - PullRequest
1 голос
/ 09 апреля 2019

Я пытаюсь реализовать стохастический градиентный спуск с нуля в Python, чтобы предсказать конкретную полиномиальную функцию.Я чувствую, что получил правильную общую структуру, но мои веса (thetas), видимо, не обновляются правильно.Это мой код:

from matplotlib import pyplot as plt
import math
import numpy as np

def epsilon():
    '''Adds noise to the data points'''
    return np.random.normal(0, 0.3, 100)

def yFunction(x):
    '''Function to predict'''
    return np.sin(2 * math.pi * x) + epsilon()

def predict(x, thetas):
    '''Predict value of x with the given thetas'''
    prediction = 0
    for i in range(thetas.size):
        prediction += (x ** i) * thetas[i]
    return prediction

# learning rate
alpha = 0.1

# generate random data points
X = np.random.random_sample(100)
y = yFunction(X)

# init weights
thetas = np.random.normal(0, 0.5, 3)

# init loss history
lossHistory = []

for epoch in range(1000):
    # predict
    prediction = predict(X[epoch % 100], thetas)

    # calculate loss
    error = prediction - y[epoch % 100]
    loss = np.sum(error ** 2)

    # update thetas
    if error <= 0:
        thetas += alpha * loss
    else:
        thetas -= alpha * loss

    # log current loss
    lossHistory.append(loss)

# final predictions based on trained model
Y = predict(X, thetas)

# plot the original data along with our line of best fit
fig = plt.figure()
plt.scatter(X, y)
plt.plot(X, Y, "r-")
plt.suptitle("Prediction line over actual values")

# construct a figure that plots the loss over time
fig = plt.figure()
plt.plot(np.arange(0, len(lossHistory)), lossHistory)
fig.suptitle("Training Loss")
plt.show()

Вот такие графики:

Python script results

Я бы предположил, что мне нужно обновить весанезависимо, вместо того, чтобы добавлять / вычитать одну сумму за эпоху, но я не уверен, как правильно распределить потери.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...