Как использовать «Предсказание Кона» Keras в tf.data.Dataset.map ()? - PullRequest
0 голосов
/ 03 апреля 2019

Я хотел бы найти способ использовать Keras 'predict_on_batch внутри tf.data.Dataset.map() в TF2.0.

Допустим, у меня есть набор данных numpy

n_data = 10**5
my_data    = np.random.random((n_data,10,1))
my_targets = np.random.randint(0,2,(n_data,1))

data = ({'x_input':my_data}, {'target':my_targets})

и tf.keras модель

x_input = Input((None,1), name = 'x_input')
RNN     = SimpleRNN(100,  name = 'RNN')(x_input)
dense   = Dense(1, name = 'target')(RNN)

my_model = Model(inputs = [x_input], outputs = [dense])
my_model.compile(optimizer='SGD', loss = 'binary_crossentropy')

Я могу создать пакетную dataset с

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)
prediction_dataset = dataset.map(transform_predictions)

, где transform_predictions - это пользовательская функция, которая получает прогнозы от predict_on_batch

def transform_predictions(inputs, outputs):
    predictions = my_model.predict_on_batch(inputs)
    # predictions = do_transformations_here(predictions)
    return predictions

Это выдает ошибку от predict_on_batch:

AttributeError: 'Tensor' object has no attribute 'numpy'

Насколько я понимаю, predict_on_batch ожидает массивный массив, и он получает тензорный объект из набора данных.

Кажется, что одно из возможных решений - это обернуть predict_on_batch в функцию `tf.py_function, хотя я также не смог заставить это работать.

Кто-нибудь знает, как это сделать?

1 Ответ

0 голосов
/ 03 апреля 2019

Dataset.map () возвращает <class 'tensorflow.python.framework.ops.Tensor'>, который не имеет метода numpy ().

Перебор возвратов набора данных <class 'tensorflow.python.framework.ops.EagerTensor'>, который имеет метод numpy ().

Подача нетерпеливого тензора для семейства методов предиката () работает нормально.

Вы можете попробовать что-то вроде этого:

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)

for x,y in dataset:
    predictions = my_model.predict_on_batch(x['x_input'])
    #or 
    predictions = my_model.predict_on_batch(x)
...