Извлечение функций из 2 авто-кодировщиков и передача их в MLP - PullRequest
0 голосов
/ 02 июня 2018

Я понимаю, что функции, извлеченные из автокодера, могут быть введены в mlp для классификации или регрессии.Это то, что я делал ранее.
Но что, если у меня есть 2 авто-кодера?Могу ли я извлечь функции из слоев узких мест 2 автокодировщиков и передать их в mlp, который выполняет классификацию на основе этих функций?Если да, то как?Я не уверен, как объединить эти два набора функций.Я попытался с numpy.hstack (), который выдает ошибку 'unhashable slice', тогда как использование tf.concat () дает мне ошибку 'Входные тензоры в модель должны быть тензорами Keras.'слои узких мест двух автокодеров имеют размерность (Нет, 100) каждый.Итак, по сути, если я сложу их горизонтально, я должен получить (Нет, 200).Скрытый слой mlp может содержать несколько (num_hidden = 100) нейронов.Может ли кто-нибудь помочь, пожалуйста?

x1 = autoencoder1.get_layer('encoder2').output
x2 = autoencoder2.get_layer('encoder2').output

#inp = np.hstack((x1, x2))
inp = tf.concat([x1, x2], 1)
x = tf.concat([x1, x2], 1)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
y = Dense(1, activation='sigmoid', name='prediction')(h)
mymlp = Model(inputs=inp, outputs=y)

# Compile model
mymlp.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train model
mymlp.fit(x_train, y_train, epochs=20, batch_size=8)

обновлено в соответствии с предложением @ twolffpiggott:

from keras.layers import Input, Dense, Dropout
from keras import layers
from keras.models import Model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import numpy as np

x1 = Data1
x2 = Data2
y = Data3

num_neurons1 = x1.shape[1]
num_neurons2 = x2.shape[1]

# Train-test split
x1_train, x1_test, x2_train, x2_test, y_train, y_test = train_test_split(x1, x2, y, test_size=0.2)

# scale data within [0-1] range
scalar = MinMaxScaler()
x1_train = scalar.fit_transform(x1_train)
x1_test = scalar.transform(x1_test)

x2_train = scalar.fit_transform(x2_train)
x2_test = scalar.transform(x2_test)

x_train = np.concatenate([x1_train, x2_train], axis =-1)
x_test = np.concatenate([x1_test, x2_test], axis =-1)

# Auto-encoder1

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons1,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded1 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded1)
decoded = Dense(num_neurons1, activation='sigmoid', name='decoder2')(decoded)

# this model maps an input to its reconstruction
autoencoder1 = Model(inputs=input_data, outputs=decoded)
autoencoder1.compile(optimizer='sgd', loss='mse')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    

# training
autoencoder1.fit(x1_train, x1_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x1_test, x1_test))

# Auto-encoder2

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons2,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded2 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded2)
decoded = Dense(num_neurons2, activation='sigmoid', name='decoder2')(decoded)


# this model maps an input to its reconstruction
autoencoder2 = Model(inputs=input_data, outputs=decoded)
autoencoder2.compile(optimizer='sgd', loss='mse')

# training
autoencoder2.fit(x2_train, x2_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x2_test, x2_test))

# MLP

num_hidden = 100

encoded1.trainable = False
encoded2.trainable = False

encoded1 = autoencoder1(autoencoder1.inputs)
encoded2 = autoencoder2(autoencoder2.inputs)

concatenated = layers.concatenate([encoded1, encoded2], axis=-1)
x = Dropout(0.2)(concatenated)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

# Compile model
myMLP.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Training
myMLP.fit(x_train, y_train, epochs=200, batch_size=8)

# Testing
myMLP.predict(x_test)

, сообщая мне ошибку: не подлежащий описанию тип: 'list' из строки: myMLP = Model (входные данные= [autoencoder1.inputs, autoencoder2.inputs], выходы = y)

Ответы [ 2 ]

0 голосов
/ 02 июня 2018

Я бы также согласился с первым подходом Даниэля (для простоты и эффективности), но если вам интересен второй;например, если вы заинтересованы в работе сквозной сети, вы должны подойти к ней следующим образом:

# make autoencoders not trainable
autoencoder1.trainable = False
autoencoder2.trainable = False

encoded1 = autoencoder1(kerasInputs1)
encoded2 = autoencoder2(kerasInputs2)

concatenated = layers.concatenate([encoded1, encoded2], axis=-1)
h = Dense(num_hidden, activation='relu', name='hidden')(concatenated)
y = Dense(1, activation='sigmoid', name='prediction')(h)

myMLP = Model([input_data1, input_data2], y)

myMLP.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Training
myMLP.fit([x1_train, x2_train], y_train, epochs=200, batch_size=8)

# Testing
myMLP.predict([x1_test, x2_test])

Ключевые правки

  1. Веса обоих автоэнкодеров должны быть заморожены до конца (иначе обновления градиента на ранней стадии от случайно инициализированной MLP, вероятно, приведут к потере большей части их обучения).
  2. Входные слои автоэнкодера должныназначаться отдельным переменным input_data1 и input_data2 для каждого автоэнкодера (вместо обеих input_data).Даже если autoencoder1.inputs возвращает tf-тензор, это является источником исключения unhashable type: list, и замена на [input_data1, input_data2] решает проблему.
  3. При установке MLP для сквозной модели,входными данными должен быть список x1_train и x2_train, а не составные входные данные.То же самое при прогнозировании.
0 голосов
/ 02 июня 2018

Проблема в том, что вы смешиваете массивы numy с тензорами keras.Это не может идти.

Есть два подхода.

  • Прогнозирование массивов-пустышек из каждого автоэнкодера, объединение массивов, отправка их в третью модель
  • Соединение всех моделей, возможно, сделать автоэнкодеры неуправляемыми, подходить с одним входом для каждого автоэнкодера.

Лично я бы пошел первым.(Предполагая, что автоэнкодеры уже обучены и не нуждаются в смене.)

Первый подход

numpyOutputFromAuto1 = autoencoder1.predict(numpyInputs1)    
numpyOutputFromAuto2 = autoencoder2.predict(numpyInputs2)

inputDataForThird = np.concatenate([numpyOutputFromAuto1,numpyOutputFromAuto2],axis=-1)

inputTensorForMlp = Input(inputsForThird.shape[1:])
h = Dense(num_hidden, activation='relu', name='hidden')(inputTensorForMlp)
y = Dense(1, activation='sigmoid', name='prediction')(h)

mymlp = Model(inputs=inputTensorForMlp, outputs=y)

....
mymlp.fit(inputDataForThird ,someY)

Второй подход

Это немного сложнее, и поначалуЯ не вижу особой причины для этого.(Но, конечно, могут быть случаи, когда это хороший выбор)

Теперь мы полностью забываем numpy и работаем с тензорами keras.

Создание mlp самостоятельно (хорошо, если вы будете использовать его позже без автоэнкодеров):

inputTensorForMlp = Input(input_shape_compatible_with_concatenated_encoder_outputs)
x = Dropout(0.2)(inputTensorForMlp)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

Возможно, нам нужны узкие места автоэнкодеров, верно?Если вам удалось правильно создать автоэнкодеры с помощью: модели кодера, модели декодера, соедините оба, то проще использовать только модель кодера.Иное:

encodedOutput1 = autoencoder1.layers[bottleneckLayer].outputs #or encoder1.outputs
encodedOutput2 = autoencoder1.layers[bottleneckLayer].outputs #or encoder2.outputs

Создание объединенной модели.В конкатенации должен использоваться слой keras (мы работаем с тензорами keras):

concatenated = Concatenate()([encodedOutput1,encodedOutput2])
output = myMLP(concatenated)

joinedModel = Model([autoencoder1.input,autoencoder2.input],output)
...