Я следую статье Траска о создании голой нейронной сети в Python. Хотя он строит 1-слойную сеть (которая отображает 3 входа на один выход) и 2-слойную сеть (которая имеет 3 входа, 4 скрытых нейрона и выходной слой с одним нейроном).
Моей задачей было построить сеть, которая может приближаться к функции Y = X1 + X2 + X3 . Я предоставляю сети Y, и она угадывает значения x1, x2 и x3.
Для этого я модифицировал вышеуказанную сеть. Я попытался инвертировать обе вышеупомянутые сети, т.е. попытался сопоставить один вход с 3 выходами. Я сделал это с помощью Tensorflow API, но хочу реализовать его без таких API высокого уровня. Моя сеть выглядит так:
Вот как я ее реализую:
import numpy as np
# sigmoid function
def nonlin(x,deriv=False):
if(deriv==True): #Find gradient
return x*(1-x)
return 1/(1+np.exp(-x))
#Training Data
##OUTPUT
y=np.random.randint(1,255,size=(50,3)).astype(int) #dims:(m,3), m is training examples
##INPUT
X = np.sum(y, axis = 1, keepdims=True) #dims:(m,1)
#Weights for synapses
##between Input layer and hidden layer
syn0 = 2*np.random.random((1,4)) - 1
##between hidden layer Output layer
syn1 = 2*np.random.random((4,3)) - 1
#Training
for iter in range(100):
# forward propagation
l0 = X
l1 = nonlin(np.dot(l0,syn0))
l2 = nonlin(np.dot(l1,syn1))
# how much did we miss?
l2_error = y-l2
#Visualizing the error change
if (iter% 100) == 0:
print ("Error:" + str(np.mean(np.abs(l2_error))))
l2_delta = l2_error*nonlin(l2, deriv=True)
l1_error = l2_delta.dot(syn1.T)
# multiply how much we missed by the
# slope of the sigmoid at the values in l1
l1_delta = l1_error * nonlin(l1,True)
# update weights
syn1 += l1.T.dot(l2_delta)
syn0 += l0.T.dot(l1_delta)
Но я замечаю, что весане обновляются, как они должны. Они скоро остановятся. Я не уверен, в чем может быть проблема. Я пытался проверить сеть на тестовом значении = 100.
test_case = np.array([[300]])
l1_out = nonlin(np.dot(test_case,syn0))
l2_out = nonlin(np.dot(l1_out,syn1))
print(l2_out)
Но я получаю странные значения x1 , x2 , x3 . Хотя я должен получить 100, 100 и 100 соответственно.
В чем может быть проблема такой простой реализации?
(PS: я думаю, что проблема с нормализацией. Если да, как я должен это реализовать?)