Вывести вес из модели Pytorch - PullRequest
0 голосов
/ 29 марта 2020

Вот очень базовая c модель:

class LinearDeepQNetwork(nn.Module):
    def __init__(self, lr, n_actions, input_dims):
        super(LinearDeepQNetwork, self).__init__()

        self.fc1 = nn.Linear(*input_dims, 128)
        self.fc2 = nn.Linear(128, n_actions)

        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        layer1 = F.relu(self.fc1(state))
        actions = self.fc2(layer1)

        return actions

Имейте в виду, что я использую Pytorch, а не Keras или Tensorflow. В моем классе Agent() я создаю экземпляр self.Q_eval = LinearDeepQNetwork(self.lr, self.n_actions, self.input_dims). После того, как я подготовил своего агента к нескольким эпизодам, мне нужно вывести веса self.Q_eval. Как я могу это сделать?

1 Ответ

0 голосов
/ 29 марта 2020

Мне нужно было ввести веса из сети Q_eval в сеть Q_next. Я сделал следующую функцию:

def replace_target_network(self):
        self.Q_next.load_state_dict(self.Q_eval.state_dict())
        self.Q_next.eval()

Ответ Я могу получить веса с Q_eval.state_dict().

...