То, что вы пытаетесь сделать, можно сделать, но не следует, поскольку в большинстве случаев это совершенно не нужно. И это не более читаемое IMO и определенно против пути PyTorch.
В ваших forward
слоях каждый раз происходит повторная инициализация, и они не регистрируются в вашей сети.
Чтобы сделать это правильно, вы можете использовать Module
Функция add_module()
с защитой от переназначения (метод dynamic
ниже):
import torch
class Example(torch.nn.Module):
def __init__(self):
self.num_embeddings = 2020
self.embedding_dim = 51
def dynamic(self, name: str, module_class, *args, **kwargs):
if not hasattr(self, name):
self.add_module(name, module_class(*args, **kwargs))
return getattr(self, name)
def forward(self, x):
embedded = self.dynamic("embedding",
torch.nn.Embedding,
self.num_embeddings,
self.embedding_dim)(x)
return embedded
Вы можете структурировать это по-другому, но за этим стоит идея.
Реально Примером использования может быть случай, когда создание слоя каким-то образом зависит от данных, передаваемых в forward
, но это может указывать на некоторое fl aws в разработке программы.