Есть ли способ использовать tf.keras.Model.predict в tf.data.Dataset.map? - PullRequest
3 голосов
/ 12 июля 2020

У меня есть набор данных, который использует вызов модели keras на карте, как в этом примере игрушки:

ds=tf.data.Dataset.from_tensor_slices([tf.ones([1]) for i in range(10)])
model=tf.keras.models.Sequential([tf.keras.layers.Dense(1),tf.keras.layers.Dense(1)])

ds=ds.batch(4).map(lambda x:model(x))

Мне было интересно, можно ли в любом случае использовать встроенную модель model.predict (x) вместо , поскольку карта, использующая вызов модели, довольно медленная (в моем реальном проекте). Я пробовал

ds=tf.data.Dataset.from_tensor_slices([tf.ones([1]) for i in range(10)])
model=tf.keras.models.Sequential([tf.keras.layers.Dense(1),tf.keras.layers.Dense(1)])

ds=ds.batch(4).map(lambda x:model.predict(x))

и

ds=tf.data.Dataset.from_tensor_slices([tf.ones([1]) for i in range(10)])
model=tf.keras.models.Sequential([tf.keras.layers.Dense(1),tf.keras.layers.Dense(1)])

def predict(x):
  return model.predict(x)

def predict_wrapper(x):
  y=tf.py_function(predict,[x],tf.float32)
  y.set_shape([None,None])
  return y
ds=ds.batch(4).map(predict_wrapper)
for x in ds:
  print(x)

Возможно ли это? Будет ли разница в скорости? Я предполагаю, что, вероятно, нет, поскольку набор данных уже оптимизирован для распределенной стратегии, и это похоже на распределение операций внутри распределенных операций. Но поскольку я понятия не имею об этом, я подумал, что спрошу.

Также я работаю в Google Colab, если это имеет значение.

...