Я пытаюсь выучить RL и тензор потока. К сожалению, есть проблема с кодом, которую я не могу понять, как ее решить. Следующий вызов завершается неудачно:
train_loss, _, train_summary = session.run([loss, opt, all_summary], feed_dict={x_ph: X, y_ph: y})
Я получаю следующую ошибку:
TypeError: Аргумент Fetch Ни один из них не имеет недопустимого класса типов NoneType
Я использую windows 10 ОС.
Что я делаю не так?
Мне действительно нужна помощь, спасибо.
Вот полный код:
import tensorflow as tf
import tensorflow.compat.v1 as tfc
import numpy as np
from datetime import datetime
np.random.seed(10)
tfc.set_random_seed(10)
# Line: y = W*X + b
W, b = 0.5, 1.4
# 100 item data sample set
X = np.linspace(0, 100, num=100)
# add random noise to y
y = np.random.normal(loc=W * X + b, scale=2.0, size=len(X))
# Tensorflow
gr = tf.Graph()
with gr.as_default():
x_ph = tfc.placeholder(shape=[None, ], dtype=tf.float32)
y_ph = tfc.placeholder(shape=[None, ], dtype=tf.float32)
v_weight = tfc.get_variable("weight", shape=[1], dtype=tf.float32)
v_bias = tfc.get_variable("bias", shape=[1], dtype=tf.float32)
# Line computation
out = v_weight * x_ph + v_bias
# compute mean squared error
loss = tf.reduce_mean((out - y_ph) ** 2)
# minimize MSE loss
opt = tfc.train.AdamOptimizer(0.4).minimize(loss)
tf.summary.scalar('MSEloss', loss)
tf.summary.histogram('model_weight', v_weight)
tf.summary.histogram('model_bias', v_bias)
# merge summary
all_summary = tfc.summary.merge_all()
# log summary to file
now = datetime.now()
clock_time = f'{now.day}_{now.hour}.{now.minute}.{now.second}'
file_writer = tfc.summary.FileWriter('log_dir\\' + clock_time, tfc.get_default_graph())
# create session
session = tfc.Session(graph=gr)
session.run(tfc.global_variables_initializer())
# loop to train the parameters
for ep in range(210):
# run optimizer
train_loss, _, train_summary = session.run([loss, opt, all_summary], feed_dict={x_ph: X, y_ph: y})
file_writer.add_summary(train_summary, ep)
# print epoch and loss
if ep % 40 == 0:
print(f'Epoch: {ep}'.ljust(13) + f'MSE: {train_loss:.4f}'.ljust(16) + f'W: {session.run(v_weight)[0]:.3f}'.ljust(11) + f'b: {session.run(v_bias)[0]:.3f}')
print(f'Final weight: {session.run(v_weight)[0]:.3f}, bias: {session.run(v_bias)[0]:.3f}')
file_writer.close()
session.close()