Как создать компилируемую модель tf.keras с несколькими тензорными входами? - PullRequest
0 голосов
/ 07 мая 2020

Это с tf 2.1.0

Следующее работает до тех пор, пока вы не попытаетесь вызвать скомпилированную модель. Есть ли что-то, что нужно сделать, чтобы методы .compile и .fit работали для нескольких тензорных входов?

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras

tf.keras.backend.set_floatx('float64')

m = 250  # samples
n_x = 1  # dim of x
n_tau = 11

x = (2 * np.random.rand(m, n_x).astype(np.float64) - 1) * 2
i = np.argsort(x[:, 0])
x = x[i]  # to make plotting nicer
A = np.random.randn(n_x, 1)
y = x ** 2 + 0.3 * x + 0.4 * np.random.randn(m, 1).astype(np.float64)
y = y.dot(A)  # y is 1d
y = y[:, :, None]
tau = np.linspace(1.0 / n_tau, 1 - 1.0 / n_tau, n_tau).astype(np.float64)
tau = tau[None, :, None]

def loss(tau_y, u):
    tau = tau_y[0]
    y = tau_y[1]
    u = y - u
    res = u ** 2 * (tau - tf.where(u <= np.float64(0.0), np.float64(1.0), np.float64(0.0)))
    return tf.reduce_sum(tf.reduce_mean(res, axis=[1, 2]), axis=0)

tf.keras.backend.set_floatx('float64')
class My(tf.keras.models.Model):
   def __init__(self):
       super().__init__()
       self._my_layer = tf.keras.layers.Dense(1, dtype=tf.float64)
   def call(self, inputs):
       tau = inputs[0]
       y = inputs[1]
       tf.print(tau.shape, y.shape)
       return self._my_layer(tau)


model = My()
u = model((tau, y)) # calling model works
l = loss((tau, y), model((tau, y))) # call loss works
opt = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(optimizer=opt, loss=loss)

# this fails with the error below
model.fit((tau, y), (tau, y))
# ValueError: Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), for inputs ['output_1'] but instead got the following list of 2 arrays: [array([[[0.09090909],
#         [0.17272727],
#         [0.25454545],
#         [0.33636364],
#         [0.41818182],
#         [0.5       ],
#         [0.58181818],
#         [0.66363636],
#         [0.74545455],
#  ...
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...