обновление подмножества параметров в dynet - PullRequest
0 голосов
/ 05 июня 2018

Есть ли способ обновить поднабор параметров в dynet?Например, в следующем примере с игрушкой сначала обновите h1, затем h2:

 model = ParameterCollection()
 h1 = model.add_parameters((hidden_units, dims))
 h2 = model.add_parameters((hidden_units, dims))
 ...
 for x in trainset:
    ...
    loss.scalar_value()
    loss.backward()
    trainer.update(h1)
    renew_cg()

 for x in trainset:
    ...
    loss.scalar_value()
    loss.backward()
    trainer.update(h2)
    renew_cg()

Я знаю, что update_subset интерфейс существует для этого и работает на основе заданногоиндексы параметров.Но тогда нигде не задокументировано, как мы можем получить индексы параметров в dynet Python.

1 Ответ

0 голосов
/ 06 июня 2018

Решением является использование флага update = False при создании выражений для параметров (включая параметры поиска):

import dynet as dy
import numpy as np

model = dy.Model()
pW = model.add_parameters((2, 4))
pb = model.add_parameters(2)
trainer = dy.SimpleSGDTrainer(model)

def step(update_b):
    dy.renew_cg()
    x = dy.inputTensor(np.ones(4))
    W = pW.expr()
    # update b?
    b = pb.expr(update = update_b)

    loss = dy.pickneglogsoftmax(W * x + b, 0)
    loss.backward()
    trainer.update()
    # dy.renew_cg()

print(pb.as_array())
print(pW.as_array())
step(True)
print(pb.as_array()) # b updated
print(pW.as_array())
step(False)     
print(pb.as_array()) # b not updated
print(pW.as_array())
  • Для update_subset я бы предположил, что индексы являются целыми числамисуффикс в конце имен параметров (.name()). В документе мы должны использовать функцию get_index.
  • Другой вариант: dy.nobackprop(), который предотвращает распространение градиента за пределы определенного узла.на графике.
  • И еще один вариант - обнулить градиент параметра, который не нужно обновлять (.scale_gradient(0)).

Эти методы эквивалентны обнулениюградиент до обновления.Таким образом, параметр все равно будет обновляться, если оптимизатор использует свой импульс от предыдущих шагов обучения (MomentumSGDTrainer, AdamTrainer, ...).

...