Как получить доступ к весам слоев внутри модели Keras? - PullRequest
1 голос
/ 08 апреля 2020

Я пытаюсь получить доступ к весам слоя Keras и использовать сами значения весов в качестве входных данных для другого слоя.

Вот примерный план того, чего я надеюсь достичь:

def generate_myModel(SEQUENCE_LENGTH, FILT_NUM, FILT_SIZE):

  ip = keras.layers.Input(shape = (SEQUENCE_LENGTH,1))

  conv_layer = keras.layers.Conv1D(filters = FILT_NUM, kernel_size = FILT_SIZE)
  y = conv_layer(ip)

  y = keras.layers.GlobalMaxPooling1D()(y)

  out_y = keras.layers.Dense(units = 1, activation = 'linear')(y)

  # Acquire the actual weights from the previous convolution layers
  w1 = <WEIGHTS FROM conv_layer - THIS IS THE PART IN QUESTION>
  out_w1 = keras.layers.Lambda( lambda x: K.std(x)/K.abs(K.mean(x)) )(w1)

  myModel = keras.models.Model(inputs = ip, outputs = [out_y, out_w1])

  return myModel

Я знаю, что, когда у вас есть экземплярная модель, вы можете использовать model.layers[i].get_weights(), но я бы хотел сделать это в реальной архитектуре модели.

Возможно ли это?

РЕДАКТИРОВАТЬ ------------------------------ ------

Пытаясь найти решение из комментариев, я добавляю layer.get_weights() в архитектуру модели следующим образом:

def generate_myModel(SEQUENCE_LENGTH, FILT_NUM, FILT_SIZE):

  ip = keras.layers.Input(shape = (SEQUENCE_LENGTH,1))

  conv_layer = keras.layers.Conv1D(filters = FILT_NUM, kernel_size = FILT_SIZE)
  y = conv_layer(ip)

  # Acquire the actual weights from the previous convolution layers
  w1 = K.constant(conv_layer.get_weights()) # layer.get_weights returns a Numpy 
                                            # array, I need a Keras Tensor - 
                                            # so I use K.constant()

  y = keras.layers.GlobalMaxPooling1D()(y)

  out_y = keras.layers.Dense(units = 1, activation = 'linear')(y)
  out_w1 = keras.layers.Lambda( lambda x: K.std(x)/K.abs(K.mean(x)) )(w1)

  myModel = keras.models.Model(inputs = ip, outputs = [out_y, out_w1])

  return myModel

, но это оставляет меня со следующим ошибка:

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

Любое руководство будет с благодарностью!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...