Возможно, это не полный ответ на ваш вопрос. Но если вы посмотрите на исходный код flatten_parameters
, вы заметите, что он вызывает _cudnn_rnn_flatten_weight
в
...
NoGradGuard no_grad;
torch::_cudnn_rnn_flatten_weight(...)
...
- это функция, которая выполняет эту работу. Вы обнаружите, что фактически она копирует веса модели в vector<Tensor>
(проверьте объявление params_arr
) в:
// Slice off views into weight_buf
std::vector<Tensor> params_arr;
size_t params_stride0;
std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf);
MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
params{params_arr, params_stride0};
А весы копируются в
// Copy weights
_copyParams(weight, params);
Также обратите внимание, что они обновляют (или Reset
, как они явно говорят в документах) оригинальные указатели weights
с новыми указателями params
, выполняя операцию на месте .set_
(_
is их обозначения для операций на месте) в orig_param.set_(new_param.view_as(orig_param));
// Update the storage
for (size_t i = 0; i < weight.size(0); i++) {
for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin();
orig_param_it != weight[i].end() && new_param_it != params[i].end();
orig_param_it++, new_param_it++) {
auto orig_param = *orig_param_it, new_param = *new_param_it;
orig_param.set_(new_param.view_as(orig_param));
}
}
А согласно n2798 (черновик C ++ 0x)
© ISO / IECN3092
23.3.6 Шаблон класса вектор
Вектор - это контейнер последовательности, который поддерживает итераторы произвольного доступа. Кроме того, он поддерживает (амортизируется) операции вставки и удаления с постоянным временем в конце; вставить и стереть в середине взять линейное время. Управление хранилищем осуществляется автоматически, хотя могут быть даны советы для повышения эффективности. Элементы вектора хранятся непрерывно , что означает, что если v
- это вектор <T, Allocator>
, где T
- это какой-то тип, отличный от bool, то он подчиняется identity&v[n] == &v[0] + n
для всех 0 <= n < v.size()
.
В некоторых ситуациях
UserWarning: Вес модуля RNN не является частью одного непрерывного фрагмента памяти. Это означает, что они должны быть сжаты при каждом вызове, что может значительно увеличить использование памяти. Для уменьшения веса снова звоните flatten_parameters()
.
Они явно советуют людям в предупреждениях кода иметь непрерывный кусок памяти.