получение 'AttributeError: объект' tuple 'не имеет атрибута' dtype '' при использовании tf.metrics.mean_absolute_error - PullRequest
0 голосов
/ 25 января 2019

Я хочу обучить очень простую сеть с одним скрытым слоем, но я не могу обучить сеть. Я продолжаю получать ошибку в названии. Хотя, когда я определяю потери как просто y - a2, это не проблема (за исключением того, что результат - все Nan, а не то, что я ожидаю). Чего мне не хватает?

import tensorflow as tf
import numpy as np

# import data
X = np.array([[0,0,1], #XOR prob
              [0,1,1],
              [1,0,1],
              [1,1,1],])


# output dataset, same as before
y = np.array([[0,1,1,0]]).T


# ----------------design network architecture
# define variables

X = tf.convert_to_tensor(X, dtype=tf.float32) # convert np X to a tensor
y = tf.convert_to_tensor(y, dtype=tf.float32) # convert np y to a tensor
W1 = tf.Variable(tf.random_normal([3, 4]))
W2 = tf.Variable(tf.random_normal([4, 1]))
a1 = tf.matmul(X, W1)
a2 = tf.matmul(a1, W2)

# define operations

# ---------------define loss and select training algorithm
loss = tf.metrics.mean_absolute_error(labels=y, predictions=a2)
#loss = y - a2
optimizer = tf.train.GradientDescentOptimizer(0.1)
train = optimizer.minimize(loss)

# ----------------run graph to train and get result
with tf.Session() as sess:

    #initialize variables
    sess.run(tf.initialize_all_variables())

    for i in range(60000):
        sess.run(train)
        if i % 10000 == 0:
            print("Loss: ", sess.run(loss))

    print("Activation: ", sess.run(a2))
    print("Loss: ", sess.run(loss))

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...