Что делает flatten_parameters ()? - PullRequest
       3

Что делает flatten_parameters ()?

0 голосов
/ 09 ноября 2018

Я видел много примеров Pytorch, использующих flatten_parameters в функции forward RNN

self.rnn.flatten_parameters()

Я видел это RNNBase , и написано, что оно

Сбрасывает указатель на данные параметров, чтобы они могли использовать более быстрые пути кода

Что это значит?

1 Ответ

0 голосов
/ 08 мая 2019

Возможно, это не полный ответ на ваш вопрос. Но если вы посмотрите на исходный код flatten_parameters, вы заметите, что он вызывает _cudnn_rnn_flatten_weight в

...
NoGradGuard no_grad;
torch::_cudnn_rnn_flatten_weight(...)
...

- это функция, которая выполняет эту работу. Вы обнаружите, что фактически она копирует веса модели в vector<Tensor> (проверьте объявление params_arr) в:

  // Slice off views into weight_buf
  std::vector<Tensor> params_arr;
  size_t params_stride0;
  std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf);

  MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
                    params{params_arr, params_stride0};

А весы копируются в

  // Copy weights
  _copyParams(weight, params);

Также обратите внимание, что они обновляют (или Reset, как они явно говорят в документах) оригинальные указатели weights с новыми указателями params, выполняя операцию на месте .set_ (_ is их обозначения для операций на месте) в orig_param.set_(new_param.view_as(orig_param));

  // Update the storage
  for (size_t i = 0; i < weight.size(0); i++) {
    for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin();
         orig_param_it != weight[i].end() && new_param_it != params[i].end();
         orig_param_it++, new_param_it++) {
      auto orig_param = *orig_param_it, new_param = *new_param_it;
      orig_param.set_(new_param.view_as(orig_param));
    }
  }

А согласно n2798 (черновик C ++ 0x)

© ISO / IECN3092

23.3.6 Шаблон класса вектор

Вектор - это контейнер последовательности, который поддерживает итераторы произвольного доступа. Кроме того, он поддерживает (амортизируется) операции вставки и удаления с постоянным временем в конце; вставить и стереть в середине взять линейное время. Управление хранилищем осуществляется автоматически, хотя могут быть даны советы для повышения эффективности. Элементы вектора хранятся непрерывно , что означает, что если v - это вектор <T, Allocator>, где T - это какой-то тип, отличный от bool, то он подчиняется identity&v[n] == &v[0] + n для всех 0 <= n < v.size() .


В некоторых ситуациях

UserWarning: Вес модуля RNN не является частью одного непрерывного фрагмента памяти. Это означает, что они должны быть сжаты при каждом вызове, что может значительно увеличить использование памяти. Для уменьшения веса снова звоните flatten_parameters().

Они явно советуют людям в предупреждениях кода иметь непрерывный кусок памяти.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...