Как обновить код для использования функциональной модели Keras - PullRequest
0 голосов
/ 01 марта 2019

Мне нужно обновить код, который я нашел на github, чтобы я мог правильно построить модель.Мне нужно объединить некоторые слои (на старых версиях Keras это было сделано через Merge (method = 'concat'), но теперь мне нужно использовать функцию concatenate. Мне нужно использовать функциональную модель, чтобы сделать это.

В качестве примера:

model1 = Sequential()
model1.add(Dense(300, input_dim=40, activation='relu', name='layer_1'))

будет "обновлено до":

model1_in = Input(shape=(27, 27, 1))
model1_out = Dense(300, input_dim=40, activation='relu', name='layer_1')(model1_in)
model1 = Model(model1_in, model1_out)

Код, который мне нужно обновить, следующий:

embed_quarter_hour = Sequential()
embed_quarter_hour.add(Embedding(metadata['n_quarter_hours'], embedding_dim, input_length=1))
embed_quarter_hour.add(Reshape((embedding_dim,)))

Весь код для обновления:

# Arbitrary dimension for all embeddings
embedding_dim = 10

# Quarter hour of the day embedding
embed_quarter_hour = Sequential()
embed_quarter_hour.add(Embedding(metadata['n_quarter_hours'], embedding_dim, input_length=1))
embed_quarter_hour.add(Reshape((embedding_dim,)))

#Quarter hour of the day embedding

# Day of the week embedding
embed_day_of_week = Sequential()
embed_day_of_week.add(Embedding(metadata['n_days_per_week'], embedding_dim, input_length=1))
embed_day_of_week.add(Reshape((embedding_dim,)))


# Week of the year embedding
embed_week_of_year = Sequential()
embed_week_of_year.add(Embedding(metadata['n_weeks_per_year'], embedding_dim, input_length=1))
embed_week_of_year.add(Reshape((embedding_dim,)))


# Client ID embedding
embed_client_ids = Sequential()
embed_client_ids.add(Embedding(metadata['n_client_ids'], embedding_dim, input_length=1))
embed_client_ids.add(Reshape((embedding_dim,)))


# Taxi ID embedding
embed_taxi_ids = Sequential()
embed_taxi_ids.add(Embedding(metadata['n_taxi_ids'], embedding_dim, input_length=1))
embed_taxi_ids.add(Reshape((embedding_dim,)))


# Taxi stand ID embedding
embed_stand_ids = Sequential()
embed_stand_ids.add(Embedding(metadata['n_stand_ids'], embedding_dim, input_length=1))
embed_stand_ids.add(Reshape((embedding_dim,)))



# GPS coordinates (5 first lat/long and 5 latest lat/long, therefore 20 values)
coords = Sequential()
coords.add(Dense(1, input_dim=20, init='normal'))


model = Sequential()
model.add(Merge([
            embed_quarter_hour,
            embed_day_of_week,
            embed_week_of_year,
            embed_client_ids,
            embed_taxi_ids,
            embed_stand_ids,
            coords
        ]),method='concat')

# Simple hidden layer
model.add(Dense(500))
model.add(Activation('relu'))

# Determine cluster probabilities using softmax
model.add(Dense(len(clusters)))
model.add(Activation('softmax'))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...