Объект tf.data в модель Keras с несколькими входами - PullRequest
0 голосов
/ 12 октября 2019

У меня возникли проблемы при использовании tf.Data с Keras для множественного ввода.

Я читаю данные из таблицы PostgreSQL с помощью генератора Python, который возвращает три массива:

class PairGenerator(object):

    return {'numerical_inputs': features, 'cat_input': idx}, response

Использование .from_generator, я создаю объект набора данных:

training_generator = PairGenerator(sql_query = sql_query, config_file = 'config.json', column_dtypes = ColsDtypes, n_steps = n_steps, num_obs = 1000, batch_size = batch_size)

train_dataset = tf.data.Dataset.from_generator(lambda: training_generator, output_types=({'numerical_inputs': tf.float32, 'cat_input': tf.string}, tf.int32), output_shapes=({'numerical_inputs': tf.TensorShape([None, 10, 36]), 'cat_input': tf.TensorShape([None,10])}, tf.TensorShape([None,10, 1]))).prefetch(1)
#<DatasetV1Adapter shapes: ({numerical_inputs: (None, 10, 36), cat_input: (None, 10)}, (None, 10, 1)), types: ({numerical_inputs: tf.float32, cat_input: tf.string}, tf.int32)>

Это хорошо работает, когда я печатаю несколько примеров

for epoch in range(3):
    for example_batch, label_batch in train_dataset:
        print(len(example_batch))
        print(label_batch.shape)
    print("End of epoch: ", epoch)

Поэтому я определяю модель, используя Keras из Tensorflow 2.0

batch_size = 32
num_obs = 1000
num_cats = 1 # number of categorical features
n_steps = 10 # number of timesteps in each sample
n_numerical_feats = 36 # number of numerical features in each sample
cat_size = 32465 # number of unique categories in each categorical feature
embedding_size = 1 # embedding dimension for each categorical feature

numerical_inputs = keras.layers.Input(shape=(n_steps, n_numerical_feats), name='numerical_inputs')
#<tf.Tensor 'numerical_inputs:0' shape=(?, 10, 36) dtype=float32>

cat_input = keras.layers.Input(shape=(n_steps,), name='cat_input')
#<tf.Tensor 'cat_input:0' shape=(None, 10) dtype=float32>

cat_embedded = keras.layers.Embedding(cat_size, embedding_size, embeddings_initializer='uniform')(cat_input)
#<tf.Tensor 'embedding_1/Identity:0' shape=(None, 10, 1) dtype=float32>

merged = keras.layers.concatenate([numerical_inputs, cat_embedded])
#<tf.Tensor 'concatenate_1/Identity:0' shape=(None, 10, 37) dtype=float32>

lstm_out = keras.layers.LSTM(64, return_sequences=True)(merged)
#<tf.Tensor 'lstm_2/Identity:0' shape=(None, 10, 64) dtype=float32>

Dense_layer1 = keras.layers.Dense(32, activation='relu', use_bias=True)(lstm_out)
#<tf.Tensor 'dense_4/Identity:0' shape=(None, 10, 32) dtype=float32>
Dense_layer2 = keras.layers.Dense(1, activation='linear', use_bias=True)(Dense_layer1 )
#<tf.Tensor 'dense_5/Identity:0' shape=(None, 10, 1) dtype=float32>

model = keras.models.Model(inputs=[numerical_inputs, cat_input], outputs=Dense_layer2)

enter image description here

Затем я компилирую модель

#compile model
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='mse',
          optimizer=optimizer,
          metrics=['mae', 'mse'])
EPOCHS =5

Теперь она подходит для модели. Начиная с Tensorflow 1.9, объект tf.data.Dataset можно передать непосредственно в keras.Model.fit().

#fit the model
history = model.fit(train_dataset,
                epochs=EPOCHS,
                verbose=1,
                initial_epoch=0)

Однако на этом этапе ничего не происходит. Ядро Jupyter включено, кажется, оно работает, но никаких результатов не появляется!

Если я не использую объект tf.data.Dataset и напрямую передаю данные из пустых массивов, он работает как шарм!

#fit the model
#you can use input layer names instead
history = model.fit({'numerical_inputs': X_numeric, 
       'cat_input': X_cat1.reshape(-1, n_steps)}, 
                y = target,
                batch_size = batch_size
                epochs=EPOCHS,  
                verbose=1,
                initial_epoch=0)

Эта проблема существует на github без решений! https://github.com/tensorflow/tensorflow/issues/20698

Я действительно не знаю, что еще делать. Может ли кто-нибудь помочь мне об этом?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...