Понимание того, когда использовать список Python в Pytorch - PullRequest
0 голосов
/ 15 марта 2019

В основном, поскольку этот поток обсуждает здесь , вы не можете использовать список python для обертывания ваших подмодулей (например, ваших слоев); в противном случае Pytorch не собирается обновлять параметры субмодулей внутри списка. Вместо этого вы должны использовать nn.ModuleList, чтобы обернуть ваши подмодули, чтобы убедиться, что их параметры будут обновлены. Теперь я также видел такие коды, как следующие, где автор использует список python для расчета потерь, а затем делает loss.backward() для обновления (в алгоритме подкрепления RL). Вот код:

 policy_loss = []
    for log_prob in self.controller.log_probability_slected_action_list:
        policy_loss.append(- log_prob * (average_reward - b))
    self.optimizer.zero_grad()
    final_policy_loss = (torch.cat(policy_loss).sum()) * gamma
    final_policy_loss.backward()
    self.optimizer.step()

Почему использование списка в этом формате работает для обновления параметров модулей, но первый случай не работает? Я очень смущен сейчас. Если я изменяю в предыдущем коде policy_loss = nn.ModuleList([]), выдается исключение, говорящее, что тензорное плавание не является субмодулем.

1 Ответ

2 голосов
/ 15 марта 2019

Вы неправильно понимаете, что такое Module s. A Module хранит параметры и определяет реализацию прямого прохода.

Вам разрешено выполнять произвольные вычисления с тензорами и параметрами, что приводит к появлению других новых тензоров. Modules не нужно знать об этих тензорах. Вы также можете хранить списки тензоров в списках Python. При вызове backward он должен быть в скалярном тензоре, таким образом, сумма конкатенации. Эти тензоры являются потерями, а не параметрами, поэтому они не должны быть атрибутами Module или заключены в ModuleList.

...