У меня есть код Python со следующими двумя классами.
import torch
import torch.nn as nn
import torch.nn.functional as F
class QNet_baseline(nn.Module):
"""
A MLP with 2 hidden layer
observation_dim (int): number of observation features
action_dim (int): Dimension of each action
seed (int): Random seed
"""
def __init__(self, observation_dim, action_dim, seed):
super(QNet_baseline, self).__init__()
self.seed = torch.manual_seed(seed)
self.fc1 = nn.Linear(observation_dim, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, action_dim)
def forward(self, observations):
"""
Forward propagation of neural network
"""
x = F.relu(self.fc1(observations))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
class QNet_3hidden(nn.Module):
"""
A MLP with 3 hidden layer
observation_dim (int): number of observation features
action_dim (int): Dimension of each action
seed (int): Random seed
"""
def __init__(self, observation_dim, action_dim, seed):
super(QNet_3hidden, self).__init__()
self.seed = torch.manual_seed(seed)
self.fc1 = nn.Linear(observation_dim, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 64)
self.fc4 = nn.Linear(64, action_dim)
def forward(self, observations):
"""
Forward propagation of neural network
"""
x = F.relu(self.fc1(observations))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc4(x)
return x
Я использовал один и тот же код для создания экземпляров обоих классов. QNet_baseline
работает нормально, но я получил следующую ошибку для QNet_3hidden
. Почему QNet_baseline
работает, но QNet_3hidden
имеет ошибку? Что я здесь пропустил? Спасибо!
/home/workspace/QNetworks.py in __init__(self, observation_dim, action_dim, seed)
44
45 def __init__(self, observation_dim, action_dim, seed):
---> 46 super(QNet_3hidden, self).__init__()
47 self.seed = torch.manual_seed(seed)
48 self.fc1 = nn.Linear(observation_dim, 128)
TypeError: super(type, obj): obj must be an instance or subtype of type
Кроме того, ниже показано, как создаются два класса:
class DDQN_Agent():
"""Interacts with and learns from the environment.
Attributes:
state_size (int): dimension of each state
action_size (int): dimension of each action
seed (int): random seed
"""
def __init__(self, state_size, action_size, seed, qnet="baseline", filename=None):
"""Initialize an Agent object.
Args:
filename: path of .pth file with trained weights
"""
self.state_size = state_size
self.action_size = action_size
self.seed = random.seed(seed)
# Q-Network
if qnet=="3hidden":
self.qnetwork_local = QNet_3hidden(state_size, action_size, seed).to(device)
self.qnetwork_target = QNet_3hidden(state_size, action_size, seed).to(device)
self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
else:
self.qnetwork_local = QNet_baseline(state_size, action_size, seed).to(device)
self.qnetwork_target = QNet_baseline(state_size, action_size, seed).to(device)
self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
if filename:
weights = torch.load(filename)
self.qnetwork_local.load_state_dict(weights)
self.qnetwork_target.load_state_dict(weights)
# Replay memory
self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
# Initialize time step (for updating every UPDATE_EVERY steps)
self.t_step = 0