Это базовая c модель многозадачного обучения, которая имеет 2 задачи. Поскольку будет только 2 задачи, я могу продублировать код для каждой задачи как self.tower1 и self.tower2, а затем выполнить их в прямом проходе.
class Multitask_Network(nn.Module):
def __init__(self):
super(Multitask_Network, self).__init__()
in_ch = 20
shared_layer_size = 10
tower_hidden_size = 10
output_size = 2
self.sharedlayer = nn.Sequential(
nn.Linear(in_ch, shared_layer_size),
nn.ReLU(),
nn.Dropout()
)
self.tower1 = nn.Sequential(
nn.Linear(shared_layer_size, tower_hidden_size),
nn.ReLU(),
nn.Dropout(),
nn.Linear(tower_hidden_size, output_size)
)
self.tower2 = nn.Sequential(
nn.Linear(shared_layer_size, tower_hidden_size),
nn.ReLU(),
nn.Dropout(),
nn.Linear(tower_hidden_size, output_size)
)
def forward(self, x):
x = self.sharedlayer(x)
out1 = self.tower1(x)
out2 = self.tower2(x)
return out1, out2
model = Multitask_Network()
x = torch.ones((4, 20))
y = model(x)
y
Но для моего случая использования я необходимо определить количество задач модели во время выполнения. Я могу сделать это с for-l oop, но я думаю, что неэффективно работать на GPU. Я думаю, что это делает прямой проход для задачи № 1, затем следует задача № 2, к задаче № п, один за другим.
Как я могу добиться этого без для для * oop в прямом проходе ?
class Multitask_Network(nn.Module):
def __init__(self, num_tasks):
super(Multitask_Network, self).__init__()
in_ch = 20
shared_layer_size = 10
tower_hidden_size = 5
output_size = 2
self.num_tasks = num_tasks
self.sharedlayer = nn.Sequential(
nn.Linear(in_ch, shared_layer_size),
nn.ReLU(),
nn.Dropout()
)
self.conv_towers = []
for i in range(num_tasks):
tower = nn.Sequential(
nn.Linear(shared_layer_size, tower_hidden_size),
nn.ReLU(),
nn.Dropout(),
nn.Linear(tower_hidden_size, output_size)
)
self.conv_towers.append(
tower
)
self.conv_towers = nn.ModuleList(self.conv_towers)
def forward(self, x):
x = self.sharedlayer(x)
output = []
for i in range(self.num_tasks):
output.append(self.conv_towers[i](x))
return output
model = Multitask_Network(num_tasks=3)
x = torch.ones((4, 20))
y = model(x)
y