Я использую API tf.keras в TensorFlow2. У меня есть 100 000 изображений или около того, которые сохраняются как TFRecords (128 изображений на запись). Каждая запись имеет входное изображение, целевое изображение и индекс кадра. Я не могу найти чистый способ сохранить индекс кадра с предсказанием.
Вот пример, за исключением того, что я строю набор данных с массивами NumPy вместо чтения из TFRecords:
import tensorflow as tf
from tensorflow import keras
import numpy as np
# build dummy tf.data.Dataset
x = np.random.random(10000).astype(np.float32)
y = x + np.random.random(10000).astype(np.float32) * 0.1
idx = np.arange(10000, dtype=np.uint16)
np.random.shuffle(idx) # frames are random in my TFRecord files
ds = tf.data.Dataset.from_tensor_slices((x, y, idx))
# pretend ds returned from TFRecord
ds = ds.map(lambda f0, f1, f2: (f0, f1)) # strip off idx
ds = ds.batch(32)
# build and train model
x = keras.Input(shape=(1,))
y_hat = keras.layers.Dense(1)(x) # i.e. linear regression
model = keras.Model(x, y_hat)
model.compile('sgd', 'mse')
history = model.fit(ds, epochs=5)
# predict 1 batch
model.predict(ds, steps=1)
Если не считать повторного чтения набора данных для извлечения индексов (что подвержено ошибкам), существует ли чистый способ сохранения соответствия предсказания с индексом изображения? В TF1.x все было просто. Но я бы хотел воспользоваться преимуществами чистого Keras API compile (), fit (), pregnet () в TF2.