После намеков, которые Сакет Кумар Сингх дал в своем ответе, я написал следующее, что, похоже, решает вопрос:
Я создаю два пользовательских слоя.Возможно, Keras уже предлагает некоторые классы, которые им эквивалентны.
Первый из них - обучаемый вход:
class MyInputLayer(keras.layers.Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyInputLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(name='kernel',
shape=self.output_dim,
initializer='uniform',
trainable=True)
super(MyInputLayer, self).build(input_shape)
def call(self, x):
return self.kernel
def compute_output_shape(self, input_shape):
return self.output_dim
Второй получает вероятность интересующей метки:
class MySelectionLayer(keras.layers.Layer):
def __init__(self, position, **kwargs):
self.position = position
self.output_dim = 1
super(MySelectionLayer, self).__init__(**kwargs)
def build(self, input_shape):
super(MySelectionLayer, self).build(input_shape)
def call(self, x):
mask = np.array([False]*x.shape[-1])
mask[self.position] = True
return tf.boolean_mask(x, mask,axis=1)
def compute_output_shape(self, input_shape):
return self.output_dim
Я использовал их таким образом:
# Build the model
layer_flatten = keras.layers.Flatten(input_shape=(28, 28))
layerDense1 = keras.layers.Dense(128, activation=tf.nn.relu)
layerDense2 = keras.layers.Dense(10, activation=tf.nn.softmax)
model = keras.Sequential([
layer_flatten,
layerDense1,
layerDense2
])
# Compile the model
model.compile(optimizer=tf.train.AdamOptimizer(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# Train the model
# ...
# Freeze the model
layerDense1.trainable = False
layerDense2.trainable = False
# Build another model
class_index = 7
layerInput = MyInputLayer((1,784))
layerSelection = MySelectionLayer(class_index)
model_extended = keras.Sequential([
layerInput,
layerDense1,
layerDense2,
layerSelection
])
# Compile it
model_extended.compile(optimizer=tf.train.AdamOptimizer(),
loss='mean_absolute_error')
# Train it
dummyInput = np.ones((1,1))
target = np.ones((1,1))
model_extended.fit(dummyInput, target,epochs=300)
# Retrieve the weights of layerInput
layerInput.get_weights()[0]