How does a (py)torch DDQN know which action it is updating?

I am trying to work through the example posted here:

https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html#training

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

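To check the shapes without running the whole environment, here is a minimal sketch (my own toy tensors, not part of the tutorial) of how the Bellman target r + GAMMA * max_a' Q_target(s', a') is assembled in the code above:

import torch

BATCH_SIZE, GAMMA = 8, 0.999
reward_batch = torch.ones(BATCH_SIZE)                   # made-up rewards, shape (8,)
non_final_mask = torch.tensor([True] * 7 + [False])     # pretend the last transition ended its episode
target_q = torch.rand(int(non_final_mask.sum()), 2)     # stands in for target_net(non_final_next_states)

next_state_values = torch.zeros(BATCH_SIZE)
next_state_values[non_final_mask] = target_q.max(1)[0]  # best next-state Q-value, left at 0 for the final state

expected_state_action_values = (next_state_values * GAMMA) + reward_batch
print(expected_state_action_values.unsqueeze(1).shape)  # torch.Size([8, 1]), matching the printout below
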
If I print out the values in state_action_values and expected_state_action_values, I can see that they are both (8, 1) arrays / tensors.

BATCH_SIZE = 8 for this short example

------------------state_action_values
torch.Size([8, 1])
tensor([[0.2089],
        [0.9491],
        [0.2925],
        [0.9807],
        [0.4564],
        [0.6373],
        [0.2620],
        [0.2527]], grad_fn=<GatherBackward>)
------------------expected_state_action_values
torch.Size([8, 1])
tensor([[1.0474],
        [1.0480],
        [1.0478],
        [1.0471],
        [1.0469],
        [1.0487],
        [1.0478],
        [1.0476]])
------------------loss
torch.Size([])
tensor(0.1909, grad_fn=<SmoothL1LossBackward>)

Where is the connection to the actions?

In the case of Gym's CartPole we have 2 possible actions [LEFT | RIGHT], so the DDQN should predict Q-values for 2 actions.
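To make that concrete, here is a minimal sketch (again with my own toy tensors) of what .gather(1, action_batch) does when the network outputs 2 Q-values per state, one for LEFT and one for RIGHT:

import torch

q_values = torch.rand(8, 2)                 # stands in for policy_net(state_batch): one Q-value per action
action_batch = torch.randint(0, 2, (8, 1))  # the actions stored in the replay memory (0 = LEFT, 1 = RIGHT)

# For every row, gather picks the column whose index is the action that was
# actually taken, which gives the (8, 1) tensor shown in the printout above.
state_action_values = q_values.gather(1, action_batch)
print(state_action_values.shape)            # torch.Size([8, 1])
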

How does the DDQN know which action it should perform the update for?

...