Обновите значения веса с заданным std :: vector - PullRequest
0 голосов
/ 06 мая 2019

Мне нужно асинхронно обновлять веса нескольких копий одной сети в большинстве алгоритмов RL. Я пытался написать функцию класса, в которой существует экземпляр toch::nn::seqential. Используя named_parameters(), я могу получить доступ к параметрам в сети. Теперь вопрос в том, могу ли я назначить другой тензор той же формы для p.value()? Например, предположим, что у меня есть тензор w, который имеет те же характеристики, что и p.value(). p.value() = w назначает значения в w для p.value()? Я проверил эту процедуру, как показано ниже, и она не работает для меня:

torch::autograd::GradMode::set_enabled(false);
int m=0;
for (auto &p : net->named_parameters()) {
    auto z = p.value(); // note that z is a Tensor, same as &p : net->parameters
    auto w = torch::zeros_like(p.value());
    if (z.dim()==1){
        int first =  m;
        int last = m + z.size(0);
        m += z.size(0);
        auto v = slice(weights, first, last);
        w+= torch::tensor(v);//.to(cpu_device);
        p.value() = w;
    }
    else if (z.dim()==2){
        int first = m;
        int last = m + z.size(0)*z.size(1);
        m += z.size(0)*z.size(1);
        auto v = slice(weights, first, last);
        w += torch::reshape(torch::tensor(v), {z.size(0),z.size(1)});//.to(cpu_device);
        p.value() = w;
    }
}

, в котором weights - это std::vector<float>, а функция slice возвращает правильный срез вектора weights.

Спасибо, Afshin

...