Как проверить, есть ли NamedTuple в списке? - PullRequest
0 голосов
/ 10 января 2020

Я пытался проверить, равен ли inctance NamedTuple «Перехода» какому-либо объекту в списке «self.memory».

Вот код, который я пытался запустить:

from typing import NamedTuple
import random
import torch as t

Transition = NamedTuple('Transition', state=t.Tensor, action=int, reward=int, next_state=t.Tensor, done=int, hidden=t.Tensor)


class ReplayMemory:

    def __init__(self, capacity):
        self.memory = []
        self.capacity = capacity
        self.position = 0

    def store(self, *args):
        print(self.memory == Transition(*args))
        if Transition(*args) in self.memory:
            return
    if len(self.memory) < self.capacity:
        self.memory.append(None)
    self.memory[self.position] = Transition(*args)
    ...

А вот вывод:

False
False

И ошибка, которую я получил:

   ...
        if Transition(*args) in self.memory:
    RuntimeError: bool value of Tensor with more than one value is ambiguous

Это кажется мне странным, потому что печать говорит мне, что "== "операция возвращает логическое значение.

Как это можно сделать правильно?

Спасибо

Редактировать:

* args - это кортеж, состоящий из

torch.Size([16, 12])
int
int
torch.Size([16, 12])
int
torch.Size([4])

1 Ответ

1 голос
/ 10 января 2020

Я считаю, что вы должны четко определить равенство.

from typing import NamedTuple
import random
import torch as t


class Sample(NamedTuple):
    state: t.Tensor
    action: int

    def __eq__(self, other):
        return bool(t.all(self.state == other.state)) and self.action == other.action
...