Я новичок в изучении питона с тензорным потоком. Я написал код для обучения модели линейной регрессии. Однако код не работает (он предсказывает очень большое значение для параметров). Хотя мне потребовалось несколько часов, я не смог решить проблему. Вот мой код. Пожалуйста, помогите мне ...
*** во время изучения кода, я исправил проблему. Когда я изменил стандартные отклонения «x_data» и «noise» на «1», код работал правильно. Тем не менее, я все еще не мог понять, почему код не прошел обучение модели регрессии для больших стандартных отклонений «x_data» и «noise».
import numpy as np
import tensorflow as tf
w_true = [[3]]
b_true = [[1]]
x_data = np.random.normal(0,10,[1000,1]) # change "10" to "1"
noise = np.random.normal(0,5,[1,1000]) # change "5" to "1"
y_data = np.matmul(w_true,x_data.T) + b_true + noise
NUM_STEPS = 10
g = tf.Graph()
with g.as_default():
x = tf.placeholder(dtype=tf.float32,shape=[None,1],name='x_data')
y = tf.placeholder(dtype=tf.float32,shape=None,name='y_data')
with tf.name_scope('inference') as scope:
w_fitted = tf.Variable(np.random.normal(0, 1, [1, 1]), dtype=tf.float32, name='w_fitted')
b_fitted = tf.Variable(np.random.normal(0, 1, [1, 1]), dtype=tf.float32, name='b_fitted')
y_pred = tf.matmul(w_fitted,tf.transpose(x)) + b_fitted
with tf.name_scope('loss') as scope:
delta = y - y_pred
square = tf.square(delta)
loss = tf.reduce_mean(square)
with tf.name_scope('train') as scope:
learning_rate = 0.5
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for iStep in range(NUM_STEPS):
sess.run(train,{x: x_data, y: y_data})
print(iStep,sess.run([w_fitted,b_fitted]))
вот результат кода
0 [array([[372.83685]], dtype=float32), array([[0.59495085]], dtype=float32)]
1 [array([[-36698.977]], dtype=float32), array([[36.981976]], dtype=float32)]
2 [array([[3642251.8]], dtype=float32), array([[-3573.9858]], dtype=float32)]
3 [array([[-3.614514e+08]], dtype=float32), array([[354772.47]], dtype=float32)]
4 [array([[3.5869893e+10]], dtype=float32), array([[-35207096.]], dtype=float32)]
5 [array([[-3.559673e+12]], dtype=float32), array([[3.4938926e+09]], dtype=float32)]
6 [array([[3.5325656e+14]], dtype=float32), array([[-3.467286e+11]], dtype=float32)]
7 [array([[-3.5056655e+16]], dtype=float32), array([[3.4408898e+13]], dtype=float32)]
8 [array([[3.4789694e+18]], dtype=float32), array([[-3.4146842e+15]], dtype=float32)]
9 [array([[-3.452476e+20]], dtype=float32), array([[3.3886797e+17]], dtype=float32)]