Форма ввода tf.data.Dataset не принимается model.fit () - PullRequest
1 голос
/ 13 июля 2020

Я хотел бы наполнить данными мою модель, применив tf.data.Dataset.

Проверив документацию TF 2.0, я обнаружил, что функция .fit() (https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit ) принимает:

x - набор данных tf.data. Должен возвращать кортеж из (входы, цели) или (входы, цели, sample_weights).

Итак, я написал следующее минимальное доказательство концептуального кода:

from sklearn.datasets import make_blobs
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.metrics import Accuracy, AUC

X, Y = make_blobs(n_samples=500, n_features=2, cluster_std=3.0, random_state=1)

def define_model():
    model = Sequential()
    model.add(Dense(units=1, activation="sigmoid", input_shape=(2,)))
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=[AUC(), Accuracy()])
    return model

model = define_model()

X_ds = tf.data.Dataset.from_tensor_slices(X)
Y_ds = tf.data.Dataset.from_tensor_slices(Y)
dataset = tf.data.Dataset.zip((X_ds, Y_ds))

for elem in dataset.take(1):
    print(type(elem))
    print(elem)

model.fit(x=dataset) #<-- does not work
#model.fit(x=X, y=Y) <-- does work without any problems....

Как упоминалось во втором комментарии, код, который не применяет tf.data.Dataset, работает нормально.

Однако при применении объекта Dataset я получаю следующее сообщение об ошибке:

<class 'tuple'>
(<tf.Tensor: shape=(2,), dtype=float64, numpy=array([-10.42729974,  -0.85439721])>, <tf.Tensor: shape=(), dtype=int64, numpy=1>)
... other output here...
ValueError: Error when checking input: expected dense_19_input to have
shape (2,) but got array with shape (1,)

From Насколько я понимаю в документации, созданный мной набор данных должен быть именно тем объектом кортежа, который ожидает метод fit.

Я не понимаю это сообщение об ошибке.

Что я здесь делаю не так?

1 Ответ

1 голос
/ 13 июля 2020

Когда вы передаете набор данных в fit, ожидается, что он будет генерировать пакеты напрямую, а не отдельные примеры. Вам просто нужно сгруппировать набор данных перед обучением.

dataset = dataset.batch(batch_size)
model.fit(x=dataset)
...