Pytorch печатает Nan, хотя у него есть значения - PullRequest
0 голосов
/ 20 февраля 2020

У меня проблема с кодом ниже. Когда я запускаю код, я получаю значение для MSE Nan. Я проверил все промежуточные функции, они прекрасно работают. Я не знаю откуда Наны входят в сеть.

def train_model(self, X, Y, lr = 0.01):
    print("No of Layers: ", self.no_of_layers)
    self.intermediate = []

    for epoch in range(100):
        overall_loss = 0
        print("Epoch:",epoch+1)
        for i in range(len(X)):
            self.layers = []
            inp = X.iloc[i,:]
            output = Y.iloc[i]
            ## PreProcessing for one input
            a = torch.tensor(inp)
            # a = torch.from_numpy(trimf(a,[0,3,10]))
            a = a.to(device = torch.device('cuda'))
            a = a.view(-1,1)
            output = torch.tensor(output, device = self.device)
            # print("Input Number:",i)
            ## Feed Forward
            for layer in range(self.no_of_layers):
                if layer == 0:
                    inpt = a
                else:
                    inpt = self.layers[layer-1]
                ## Problem is with this inp. Layer 0 tensor of 1st exaple is passed as input to all the other examples for succesive layers.
                # print("CL:", inpt)
                Model.compute_layer(self,self.weights[layer],inpt,layer)
            # print("Overall Layers",self.layers)


            pred = self.layers[-1]
            pred != pred
            # print(pred)
            pred = torch.sigmoid(pred)

            ## Loss Function
            loss = (pred - output)**2
            overall_loss +=torch.sum(loss)

            ## Back Propgation
            loss.backward(torch.empty(loss.size(), device = self.device), retain_graph = True)
            j = 0
            for w in self.weights:
                    # print("BU", w.data)
                    w.data -= (lr*((w.grad-torch.min(w.grad))/(torch.max(w.grad)-torch.min(w.grad))))
                    w.grad != w.grad

                    # w.data -= (lr*w.grad)
                    # w.data != w.data
                    # print("AU",w.data)
        overall_loss = overall_loss/len(X)
        print("MSE: ", overall_loss.item())

Вывод:

Code Output

...