Думаю, это довольно просто, или я неправильно понял ваш запрос.
x
, t
- ваши входные переменные.
Теперь давайте определим сеть M
, которая будет принимать вход t и выход theta
.
M = nn.Sequential(....) # declare network here
Далее определяем сеть Y
. Это может быть сложно, поскольку вы хотите использовать тэту в качестве параметров. Возможно, будет проще и интуитивно понятнее работать с функциональными аналогами модулей, заявленных в nn
(см. https://pytorch.org/docs/stable/nn.functional.html). Я попытаюсь привести пример этого, предполагая, что тета - это параметры линейного модуля.
class Y(nn.Module):
def __init__(self):
# declare any modules here
def forward(self, theta, x):
return nn.functional.linear(input=x, weight=theta, bias=None)
Общий проход вперед будет
def forward(t, x, M, Y):
theta = M(t)
output = Y(theta, x)
return output