Один из самых простых способов решить эту проблему - передать входные данные непосредственно в модель, а не использовать метод model.predit
. Причина этого в том, что model.predict
возвращает numpy.ndarray
. Это вызывает ошибку, потому что tf.data
использует выполнение графа, что означает, что лучше иметь любую операцию ввода И выводить тензор в этом графе.
Ниже приведен быстрый рабочий пример этого.
import tensorflow as tf
# Create example model
inputs = tf.keras.Input((1,))
out = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, out)
def map_fn(row):
return model(row)
# Create some input data
a = tf.constant([1, 2])
# Create the dataset
ds = tf.data.Dataset.from_tensor_slices(a).batch(1)
model_mapped_ds = ds.map(lambda x: map_fn(x))
for el in model_mapped_ds:
print(el)
Наконец, ниже показано, как это будет выглядеть при использовании.
def pass_image_through_model(img):
return model(img) # this returns a tensor
@tf.function
def load_data(..., model):
# code to load an image
files = tf.data.Dataset.from_tensor_slices(file_list).batch(1) # Don't forget batch size!
images = files.map(load_image_from_file)
dataset = images.map(pass_image_through_model)
return dataset