Передача> 2 ГБ данных в tf.estimator - PullRequest
0 голосов
/ 08 декабря 2018

У меня есть x_train и y_train массивы numpy, каждый из которых> 2 ГБ.Я хочу обучить модель с использованием API tf.estimator, но получаю ошибки:

ValueError: Cannot create a tensor proto whose content is larger than 2GB

Я передаю данные, используя:

def input_fn(features, labels=None, batch_size=None,
             shuffle=False, repeats=False):
    if labels is not None:
        inputs = (features, labels)
    else:
        inputs = features
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    if shuffle:
        dataset = dataset.shuffle(shuffle)
    if batch_size:
        dataset = dataset.batch(batch_size)
    if repeats:
        # if False, evaluate after each epoch
        dataset = dataset.repeat(repeats)
    return dataset

train_spec = tf.estimator.TrainSpec(
    lambda : input_fn(x_train, y_train,
                      batch_size=BATCH_SIZE, shuffle=50),
    max_steps=EPOCHS
)

eval_spec = tf.estimator.EvalSpec(lambda : input_fn(x_dev, y_dev))

tf.estimator.train_and_evaluate(model, train_spec, eval_spec)

Документация tf.data упоминает об этой ошибке и предоставляет решение с использованием традиционного API TenforFlow с заполнителями.К сожалению, я не знаю, как это можно перевести в API tf.estimator?

1 Ответ

0 голосов
/ 10 декабря 2018

Решение, которое работало для меня, было использовать

tf.estimator.inputs.numpy_input_fn(x_train, y_train, num_epochs=EPOCHS,
                                   batch_size=BATCH_SIZE, shuffle=True)

вместо input_fn.Единственная проблема заключается в том, что tf.estimator.inputs.numpy_input_fn вызывает предупреждения об устаревании, поэтому, к сожалению, это также перестанет работать.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...