Задание
Запуск keras.model.fit_generator
с use_multiprocessing=True
и несколькими работниками в генераторе данных, который сам содержит модель тензорного потока или кераса.
Эта проблема очень связана: https://github.com/tensorflow/tensorflow/issues/5448#issuecomment-258934405
def create_minimal_keras_model():
##### Create Model A #####
in1 = keras.layers.Input(shape=(1,))
d = keras.layers.Dense(1)(in1)
a = keras.Model(inputs=in1, outputs=d)
opt = keras.optimizers.Adam(lr=0.01)
loss = keras.losses.mse
a.compile(opt, loss)
#####
return a
class TestGenerator(keras.utils.Sequence):
def __init__(self):
self.len = int(1e2)
self.model = None
# self.init_model()
def init_model(self):
self.graph = tf.Graph()
with self.graph.as_default():
self.model = create_minimal_keras_model()
def __len__(self):
"""
Number of batches for generator.
"""
return self.len
def __getitem__(self, index):
"""
Keras sequence method for generating batches.
"""
if not self.model:
self.init_model()
if self.model:
with self.graph.as_default():
res = self.model.predict(np.array([1]))
return (np.array([index]), np.array([-index/2 + 3]))
Ошибка
Тренировка висит в начале 2-й эпохи.
Что я пробовал
- модель init при инициализации генератора данных (основной процесс)
- модель инициализации при вызове цикла первого поколения (подпроцесс)
- вызов tf.Session () и других функций вызывает тупик в начале 1-й эпохи
Полный пример кода:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import os
def create_minimal_keras_model():
##### Create Model A #####
in1 = keras.layers.Input(shape=(1,))
d = keras.layers.Dense(1)(in1)
a = keras.Model(inputs=in1, outputs=d)
opt = keras.optimizers.Adam(lr=0.01)
loss = keras.losses.mse
a.compile(opt, loss)
#####
return a
class TestGenerator(keras.utils.Sequence):
def __init__(self):
self.len = int(1e2)
self.model = None
# self.init_model()
def init_model(self):
self.graph = tf.Graph()
with self.graph.as_default():
self.model = create_minimal_keras_model()
def __len__(self):
"""
Number of batches for generator.
"""
return self.len
def __getitem__(self, index):
"""
Keras sequence method for generating batches.
"""
if not self.model:
self.init_model()
if self.model:
with self.graph.as_default():
res = self.model.predict(np.array([1]))
return (np.array([index]), np.array([-index/2 + 3]))
os.environ['CUDA_VISIBLE_DEVICES'] = ''
a = create_minimal_keras_model()
a.summary()
##########################################
##### Funcions Halt Before 1st Epoch #####
##########################################
# tf.Session()
# a.save_weights('tmp_model_weights.h5')
# a.load_weights('tmp_model_weights.h5')
# a.save('tmp_model.h5')
# keras.models.load_model('tmp_model.h5')
##########################################
##########################################
##########################################
##########################################
##### Functions Causing NO Deadlocks #####
##########################################
tf.get_default_session()
tf.Graph()
keras.__version__
with tf.device('/cpu:0'):
_ = tf.constant(0)
keras.utils.plot_model(a, to_file=('tmp_plot_model.png'), show_shapes=True)
[a.get_layer(l_name).output for l_name in [a.layers[-1].name]]
_ = keras.backend.variable(4)
_ = keras.backend.image_data_format()
_ = keras.backend.shape(tf.constant(1, shape=(5,5,5)))
_ = a.layers[0].get_config()
tf.random.set_random_seed(0)
##########################################
##########################################
##########################################
# The training will stop at second epoch
a.fit_generator(generator=TestGenerator(), steps_per_epoch=100, epochs=5, workers=4, use_multiprocessing=True)
Вопрос
Какие варианты у меня есть для того, чтобы иметь генератор данных, который запускает модель тензорного потока внутри, во время многопроцессного обучения.
Опция, которую я определил:
- Использование pytorch или другой среды для генератора данных
- Запустить модель в док-контейнере как сервис https://towardsdatascience.com/deploy-tensorflow-models-9813b5a705d5
- Предварительно запустите модель и создайте модифицированную копию набора данных
- Создайте автономный скрипт на python, обслуживающий модель, и обменивайтесь данными через ZeroMQ или аналогичный