Как я могу получить веса от нейронной сети и убедиться, что они все еще обучаемы? - PullRequest
0 голосов
/ 06 января 2019

Я пытаюсь обучить нейронную сеть с целевой функцией, состоящей из ошибок и регуляризации.

Чтобы выполнить регуляризацию, я хочу получить все веса в виде 1D-тензора (назовите этот тензор weights), выполнить некоторые операции и добавить это к целевой функции. Как мне получить веса, чтобы я мог продолжать их тренировать?

Пока что я пробовал:

  1. Использование tf.get_default_graph (). Get_tensor_by_name () - когда я вычисляю градиенты относительно weights, градиент условия ошибки всегда равен None.
  2. Использование tf.get_variable () - как указано выше, градиент условия ошибки всегда равен None
  3. Использование атрибута trainable_weights слоев - атрибут trainable_weights возвращает пустой список.

Регуляризация, которую я надеюсь запустить, представляет собой модель гауссовой смеси, в которой также обучаются сами параметры GMM.

Например, для третьей попытки мой код:

# Here I create the layers
layers = []
for L in range(len(units)):
    layer = tf.layers.Dense(units=units[L], activation=tf.nn.relu, name="lay"+str(L))
    layers.append(layer)       
layers.append(tf.layers.Dense(n_y, activation=None))

# Here I try to get the weights
weights = [L.trainable_weights for L in layers] # Returns empty lists
weights = tf.concat(weights,axis=0)

1 Ответ

0 голосов
/ 06 января 2019

В вашем случае я настоятельно рекомендую вам определять веса и слои с помощью самой базовой операции, такой как

weights = {
'h1': tf.Variable(tf.random_normal([num_input, n_hidden_1])),
'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_hidden_2, num_classes]))}

biases = {
'b1': tf.Variable(tf.random_normal([n_hidden_1])),
'b2': tf.Variable(tf.random_normal([n_hidden_2])),
'out': tf.Variable(tf.random_normal([num_classes]))}

вместо продвинутого API, такого как tf.layers.Dense(). Здесь у вас есть очень простой пример .

Тогда значение weights можно получить с помощью

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    weights_value = sess.run(weights['h1'])
    print(weights_value)

Надеюсь, это полезно.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...