Есть ли способ использовать tf.keras.model.predict в конвейере tf.data? - PullRequest
0 голосов
/ 09 апреля 2020

У меня есть обученная модель, которую я хотел бы использовать в tf.data конвейере для второй модели. Когда я пытаюсь сделать это, я получаю ValueError: Unknown graph. Aborting. Я не знаю, что делать с этим сообщением об ошибке.

Мой код выглядит примерно так:

def load_data(..., model):
    # code to load an image
    files = tf.data.Dataset.from_tensor_slices(file_list)
    images = files.map(load_image_from_file) 

    def pass_image_through_model(img):
        return model.predict(img, steps=1)

    dataset = images.map(pass_image_through_model)
    return dataset

Что не так с этим? Я получаю ошибку:

    /home/.../code/dataloader.py:236 pass_image_through_model  *
        return model.predict(img, steps=1)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:1013 predict
        use_multiprocessing=use_multiprocessing)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:728 predict
        callbacks=callbacks)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:189 model_iteration
        f = _make_execution_function(model, mode)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:571 _make_execution_function
        return model._make_execution_function(mode)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2131 _make_execution_function
        self._make_predict_function()
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2121 _make_predict_function
        **kwargs)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3760 function
        return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3644 __init__
        raise ValueError('Unknown graph. Aborting.')

    ValueError: Unknown graph. Aborting.

Ответы [ 2 ]

0 голосов
/ 09 апреля 2020

Один из самых простых способов решить эту проблему - передать входные данные непосредственно в модель, а не использовать метод model.predit. Причина этого в том, что model.predict возвращает numpy.ndarray. Это вызывает ошибку, потому что tf.data использует выполнение графа, что означает, что лучше иметь любую операцию ввода И выводить тензор в этом графе.

Ниже приведен быстрый рабочий пример этого.

import tensorflow as tf

# Create example model
inputs = tf.keras.Input((1,))
out = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, out)

def map_fn(row):
    return model(row)


# Create some input data 
a = tf.constant([1, 2])

# Create the dataset
ds = tf.data.Dataset.from_tensor_slices(a).batch(1)
model_mapped_ds = ds.map(lambda x: map_fn(x))

for el in model_mapped_ds:
    print(el)

Наконец, ниже показано, как это будет выглядеть при использовании.


def pass_image_through_model(img):
    return model(img) # this returns a tensor 

@tf.function
def load_data(..., model):
    # code to load an image
    files = tf.data.Dataset.from_tensor_slices(file_list).batch(1) # Don't forget batch size!
    images = files.map(load_image_from_file) 

    dataset = images.map(pass_image_through_model)
    return dataset
0 голосов
/ 09 апреля 2020

Ошибка, которую вы получаете, может быть тихой, если вы впервые имеете дело с tf.data.Dataset() объектом.

Все операции в tf.data.Dataset() фактически выполняются в графическом режиме, и вы не можете использовать какие-либо функции за пределами предопределенных в tf.*.

Единственный способ смешать произвольный код Python с вашим tf.data.Dataset() - это использовать tf.py_function(), в противном случае будет выдано сообщение об ошибке.

Помните, что смешивание кода Python с оптимизированным кодом tf.data.Dataset() приведет к снижению производительности по времени.

Единственный способ проверить это получить набор данных, используйте as_numpy_iterator() для извлечения данных. и прогнозировать с вашей моделью, поэтому вне процесса отображения.

...