from_tensor_slices () с большим массивом numpy при использовании tf.keras - PullRequest
1 голос
/ 12 марта 2019

У меня есть некоторые тренировочные данные в массиве - он помещается в память, но он больше 2 ГБ. Я использую tf.keras и API набора данных. Чтобы дать вам упрощенный, автономный пример:

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)
])

model.compile(optimizer=tf.train.AdamOptimizer(0.001),
          loss='mse',
          metrics=['mae'])

# generate some big input datasets, bigger than 2GB
data = np.random.random((1024*1024*8, 32))
labels = np.random.random((1024*1024*8, 1))
val_data = np.random.random((100, 32))
val_labels = np.random.random((100, 1))

train_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
train_dataset = train_dataset.batch(32).repeat()

val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
val_dataset = val_dataset.batch(32).repeat()

model.fit(train_dataset, epochs=10, steps_per_epoch=30,
      validation_data=val_dataset, validation_steps=3)

Таким образом, выполнение этого приводит к ошибке «Невозможно создать протор-тензор, содержание которого превышает 2 ГБ». В документации перечислено решение этой проблемы: https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays - просто используйте tf.placeholder и затем feed_dict при запуске сеанса.

Теперь главный вопрос: как это сделать с помощью tf.keras? Я не могу ничего подать для заполнителей, когда вызываю model.fit (), и фактически, когда я представил заполнители, я получил ошибки, говорящие: «Вы должны указать значение для тензора заполнителя».

1 Ответ

0 голосов
/ 12 марта 2019

Как и в Estimator API, вы можете использовать from_generator

data_chunks = list(np.split(data, 1024))
labels_chunks = list(np.split(labels, 1024))

def genenerator():
    for i, j in zip(data_chunks, labels_chunks):
        yield i, j

train_dataset = tf.data.Dataset.from_generator(genenerator, (tf.float32, tf.float32))
train_dataset = train_dataset.shuffle().batch().repeat()

Также посмотрите https://github.com/tensorflow/tensorflow/issues/24520

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...