Tensorflow: ошибка «логиты и метки должны быть транслируемыми» - PullRequest
0 голосов
/ 02 марта 2020

Я пытаюсь сделать очень простой классификатор нейронной сети для набора данных Iris. Но я получаю эту ошибку:

InvalidArgumentError (see above for traceback): logits and labels must be broadcastable: logits_size= 
[150,10] labels_size=[1,150]

Вот мой код:

iris = datasets.load_iris()
X_train = iris.data
y_train = iris.target
d=4 #dimensions of input features
n_hidden = 10 #units in hiddden layer

X = tf.placeholder(tf.float32, (None, d))
y = tf.placeholder(tf.float32, (None,))

W = tf.Variable(tf.random_normal([d, n_hidden]))
b = tf.Variable(tf.random_normal((n_hidden, )))
affine_transformation = tf.add(tf.matmul(X, W), b)
activation = tf.nn.relu(affine_transformation)

loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = y, logits = activation))

optimization = tf.train.AdamOptimizer(0.01).minimize(loss_op)

init = tf.global_variables_initializer()  

with tf.Session() as sess:
  sess.run(init)
  fd = {X: X_train, y: y_train}
  sess.run(optimization, feed_dict=fd)

У кого-нибудь есть идеи, почему это происходит?

...