Как get_updates () optimizers.SGD используется в Keras во время обучения? - PullRequest
0 голосов
/ 08 марта 2019

Я не знаком с внутренней работой Keras и не могу понять, как Keras использует функцию get_updates() optimizer.SGD во время обучения.

Я довольно долго искал в интернете, но получил только несколько деталей. В частности, я понимаю, что правило обновления параметров / весов SGD определено в функции get_updates(). Но похоже, что get_updates() не буквально вызывается на каждой итерации во время обучения; в противном случае «моменты» не будут переноситься из одной итерации в другую, чтобы правильно реализовать импульс, так как он сбрасывается при каждом вызове, c.f. optimizers.py:

shapes = [K.get_variable_shape(p) for p in params]
moments = [K.zeros(shape) for shape in shapes]
self.weights = [self.iterations] + moments
for p, g, m in zip(params, grads, moments):
    v = self.momentum * m - lr * g  # velocity
    self.updates.append(K.update(m, v))

Как указано в https://github.com/keras-team/keras/issues/7502, get_updates () определяет только «символьный граф вычислений». Я не уверен, что это значит. Может кто-нибудь дать более подробное объяснение того, как это работает?

Например, как 'v', вычисленное в одной итерации, передается 'моментам' в следующей итерации для реализации импульса? Буду также признателен, если кто-нибудь подскажет мне, как это работает.

Большое спасибо! (Кстати, я использую тензор потока, если это имеет значение.)

1 Ответ

1 голос
/ 08 марта 2019

get_updates () определяет графовые операции, которые обновляют градиенты. Когда график оценивается для обучения, он будет выглядеть примерно так:

  • проходы вперед вычисляют значение прогнозирования
  • потеря вычисляет стоимость
  • обратные проходы вычисляют градиенты
  • градиенты обновлены

Обновление градиентов - это само вычисление графа; то есть фрагмент кода, который вы цитируете, определяет, как выполнить операцию, указав, какие тензоры используются и какие математические операции выполняются. Сами математические операции в этот момент не выполняются.

моменты - векторы тензоров, определенные в коде выше. Код создает графовую операцию, которая обновляет каждый элемент моментов.

Каждая итерация графика будет запускать эту операцию обновления.

Следующая ссылка пытается объяснить концепцию вычислительного графа в TensorFlow: https://www.tensorflow.org/guide/graphs

Keras использует те же основные идеи, но отвлекает пользователя от необходимости иметь дело с деталями низкого уровня. Определение модели в традиционном API TensorFlow 1.0 требует гораздо более высокого уровня детализации.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...