Как предотвратить вычисление градиента для определенных весов - PullRequest
0 голосов
/ 09 января 2019

Я хочу, чтобы градиент не рассчитывался для определенных весов или устанавливался на ноль. Как я не хочу их обновлять во время тренировок. Вот пример кода:

 import tensorflow as tf
 import tensorflow.contrib.eager as tfe
 import numpy as np

 tf.enable_eager_execution()


 model = tf.keras.Sequential([
   tf.keras.layers.Dense(2, activation=tf.sigmoid, input_shape=(2,)),
   tf.keras.layers.Dense(2, activation=tf.sigmoid)
 ])


 #set the weights
 weights=[np.array([[0, 0.25],     [0.2,0.3]]),np.array([0.35,0.35]),np.array([[0.4,0.5],[0.45, 0.55]]),np.array([0.6,0.6])]

 model.set_weights(weights)

 model.get_weights()

 features = tf.convert_to_tensor([[0.05,0.10 ]])
 labels =  tf.convert_to_tensor([[0.01,0.99 ]])

 #define the loss function
 def loss(model, x, y):
   y_ = model(x)
   return tf.losses.mean_squared_error(labels=y, predictions=y_)

 #define the gradient calculation
 def grad(model, inputs, targets):
   with tf.GradientTape() as tape:
     loss_value = loss(model, inputs, targets)
   return loss_value, tape.gradient(loss_value, model.trainable_variables) 

 #create optimizer an global Step
 optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
     global_step = tf.train.get_or_create_global_step()

Я хочу, чтобы первый вес, который равен 0, не был включен в расчет градиента. Я нашел tf.stop_gradient и tf.keras.backend.stop_gradient. Но не знаю, как и если вы можете применить их к моей проблеме

1 Ответ

0 голосов
/ 09 января 2019

Вы можете использовать layer.trainable = False, что остановит изменение веса в этом слое во время тренировки.

Чтобы получить слои в вашей модели, вы можете вызвать model.layers и выбрать первый или нулевой слой в объекте, который вы возвращаете, а затем установить обучаемое значение в false, что-то вроде:

      layers =  model.layers
      layers[0].trainable = False 
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...