Как преобразовать массивы NumPy объектов и меток в набор данных TensorFlow, который можно использовать для model.fit ()? - PullRequest
2 голосов
/ 01 мая 2020

У меня есть две простые NumPy функции и метки массивов:

features = np.array([
    [6.4, 2.8, 5.6, 2.2],
    [5.0, 2.3, 3.3, 1.0],
    [4.9, 2.5, 4.5, 1.7],
    [4.9, 3.1, 1.5, 0.1],
    [5.7, 3.8, 1.7, 0.3],
])
labels = np.array([2, 1, 2, 0, 0])

Я преобразую эти два NumPy массива в набор данных TensorFlow следующим образом:

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

Я определяю и скомпилируйте модель:

model = keras.Sequential([
    keras.layers.Dense(5, activation=tf.nn.relu, input_shape=(4,)),
    keras.layers.Dense(3, activation=tf.nn.softmax)
])
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

Теперь я попытался обучить модель, используя fit() метод:

model.fit(dataset, epochs=100)

и получаю ошибку:

ValueError: Error when checking input: expected dense_input to have shape (4,) but got array with shape (1,)

Если я предоставьте NumPy массивов функций и меток непосредственно методу fit(), тогда все хорошо.

model.fit(features, labels, epochs=100)

Результаты:

Train on 5 samples
Epoch 1/100
5/5 [==============================] - 0s 84ms/sample - loss: 1.8017 - accuracy: 0.4000

Epoch 2/100
5/5 [==============================] - 0s 0s/sample - loss: 1.7910 - accuracy: 0.4000

...............................
Epoch 100/100
5/5 [==============================] - 0s 0s/sample - loss: 1.2484 - accuracy: 0.2000

Если я правильно понял, мне нужно создать набор данных TensorFlow, который вернет кортеж (features, labels). Итак, Как преобразовать массивы NumPy объектов и меток в набор данных TensorFlow, который можно использовать для model.fit()?

1 Ответ

1 голос
/ 01 мая 2020

Просто установите размер партии при создании Dataset:

batch_size = 2
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(batch_size)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...