Я хотел бы использовать предварительно обученную модель Keras как часть моей обработки данных для генерации обучающих данных для второй модели Keras. В идеале я хотел бы сделать это, вызвав первую модель в генераторе данных для второй модели.
Я использую tenorflow 1.15.
Простой пример того, что я пытаюсь сделать, таков:
import numpy as np
import tensorflow as tf
from tensorflow_core.python.keras import Sequential
from tensorflow_core.python.keras.layers import Dense
batch_size = 4
data_size = 16
model_generator = Sequential([Dense(data_size)])
model_generator.compile(optimizer='adam', loss='mae')
model_generator.build((batch_size, data_size))
sess = tf.keras.backend.get_session()
def generator():
while True:
data = np.random.random((batch_size, data_size))
targets = tf.random.uniform((batch_size, 1))
data = model_generator(data, training=False)
data = data.eval(session=sess)
yield (data, targets)
model_train = Sequential([Dense(data_size)])
model_train.compile(optimizer='adam', loss='mae')
model_train.build((batch_size, data_size))
output_types = (tf.float64, tf.float64)
output_shapes = (tf.TensorShape((batch_size, data_size)), tf.TensorShape((batch_size, 1)))
dataset = tf.data.Dataset.from_generator(
generator,
output_types=output_types,
output_shapes=output_shapes,
)
if next(generator()) is not None:
print("Generator works outside of model.fit()!")
model_train.fit(
dataset,
epochs=2,
steps_per_epoch=2
)
Приведенный выше фрагмент кода выдает следующее сообщение об ошибке когда вызывается .fit () .:
2020-01-28 17:35:56.705549: W tensorflow/core/framework/op_kernel.cc:1639] Invalid argument: ValueError: Tensor("dense/kernel/Read/ReadVariableOp:0", shape=(16, 16), dtype=float32) must be from the same graph as Tensor("sequential/dense/Cast:0", shape=(4, 16), dtype=float32).
Traceback (most recent call last):
Код будет работать нормально, если генератор не вызывает модель model_generator
. Например:
def generator():
while True:
data = np.random.random((batch_size, data_size))
targets = np.random.random((batch_size, 1))
yield (data, targets)
Я считаю, что вызов fit создает свой собственный тензорный граф, который не включает узлы, необходимые для model_generator
. Есть ли способ использовать одну модель в генераторе для обучения другой модели, как эта? Если да, как я могу изменить приведенный выше пример для достижения этой цели?