Как использовать керасы для мультиклассовой классификации mutli label - PullRequest
3 голосов
/ 15 марта 2020

У меня есть набор данных с 6 возможными метками типа:

Class 1: Near-large 
Class 2: Far- Large 
Class 3: Near - Medium 
Class 4: Far - Medium 
Class 5: Near - small 
Class 6: far - small 

И я хотел бы изменить задачу, чтобы разделить метки, чтобы каждый образец классифицировался как дальний / ближний и маленький / средний / маленький независимо, учитывая различные особенности для каждой классификации в качестве входных данных.

Моя первая идея состояла в том, чтобы обучить 2 разные модели для каждой вложенной метки, а затем создать пользовательскую функцию для присоединения к прогнозам, но мне интересно, есть ли более быстрый способ сделать это в рамках Keras.

Я знаю, что могу использовать функциональный API для создания двух моделей с независимыми входами и двумя независимыми выходами. Это дало бы мне два прогноза для двух разных суб-меток. Если я один раз закодирую суб-метки, выходные данные этих моделей будут выглядеть примерно так:

Model1.output = [ 0,1 ] or [1,0]  ( far/near) 
Model2.output = [ 1, 0, 0 ] or [0,1,0] or [0,0,1](small/medium/large)

Но тогда как я могу объединить эти два выхода, чтобы создать 6 тусклых векторов для полных меток?

Model_merged.output = [1, 0, 0, 0, 0 ,0 ] , [010000], ...., [000001] (class1,... ,Class6) 

1 Ответ

0 голосов
/ 15 марта 2020

Можно вывести reshape модель1, чтобы расширить ось, умножить ее на выход модели2 и сгладить их обе.

from keras.models import Model

reshaped = keras.layers.Reshape((2,-1))(model1.output)
combined = keras.layers.Multiply()([reshaped,model2.output])
flattened = keras.layers.Reshape((6,))(combined)


Combined_model = Model([model1.input,model2.input], flattened)

Простым numpy примером вышеупомянутого будет:

model1_output = np.array([0,1])[:,None] #Reshaped

#array([[0],
#       [1]])

model2_output = np.array([1,0,0])

# array([1, 0, 0])

combined = model1_output*model2_output

#array([[0, 0, 0],
#       [1, 0, 0]])

combined.ravel()

#array([0, 0, 0, 1, 0, 0])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...