Я новичок в использовании Numba. Я пытаюсь создать jitclass с модулями pytorch, но не могу найти типы numba для назначения им. Ниже приведен пример кода.
from numba import jitclass, int32
import torch
spec=[('val',int32),('network',),('loss',),('optimizer',)]
@jitclass(spec)
class Test():
def __init__(self, val, network, loss, optimizer):
self.val=val
self.network=network
self.loss=loss
self.opimizer=optimizer
val=10
network= torch.nn.Sequential(torch.nn.Linear(10,50),torch.nn.ReLU(),
torch.nn.Linear(50,5),torch.nn.ReLU())
MseLoss= torch.nn.MSELoss()
optimizer= torch.optim.Adam(network.parameters(), lr=0.001)
obj=Test(val,network,MseLoss,optimizer)