Модули PyTorch внутри Numba Jitclass - PullRequest
0 голосов
/ 06 февраля 2020

Я новичок в использовании Numba. Я пытаюсь создать jitclass с модулями pytorch, но не могу найти типы numba для назначения им. Ниже приведен пример кода.

from numba import jitclass, int32
import torch

spec=[('val',int32),('network',),('loss',),('optimizer',)]

@jitclass(spec)
class Test():
    def __init__(self, val, network, loss, optimizer):
        self.val=val
        self.network=network
        self.loss=loss
        self.opimizer=optimizer

val=10        
network= torch.nn.Sequential(torch.nn.Linear(10,50),torch.nn.ReLU(),
                     torch.nn.Linear(50,5),torch.nn.ReLU())
MseLoss= torch.nn.MSELoss()

optimizer= torch.optim.Adam(network.parameters(), lr=0.001)

obj=Test(val,network,MseLoss,optimizer)
...