Conv2d не принимает тензор в качестве входных данных, говоря, что он не тензор - PullRequest
0 голосов
/ 31 мая 2019

Я хочу пропустить тензор через 2 сверточных слоя.Я не могу выполнить его, так как получаю ошибку типа, даже если я преобразовал свой массив numpy в тензор.

Я пытался использовать tf.convert_to_tensor () для решения этой проблемы.Не работает

import numpy as np
import tensorflow as tf

class Generator():

  def __init__(self):

    self.conv1 = nn.Conv2d(1, 28, kernel_size=3, stride=1, padding=1)
    self.pool1 = nn.MaxPool2d(kernel_size=3, stride=0, padding=1)

    self.fc1 = nn.Linear(100, 10)
    self.fc2 = nn.Linear(10, 5)

  def forward_pass(self, x):                                                                       #Why do we pass the object itself in every method?

    x = self.conv1(x)
    print(x)
    x = self.pool1(x)
    print(x)

    x = self.fc1(x)
    print(x)
    x = self.fc2(x)
    print(x)

    return x

arr = tf.convert_to_tensor(np.random.random((3,28,28)))

gen = Generator()
gen.forward_pass(arr)


Сообщение об ошибке -

TypeError                                 Traceback (most recent call last)

<ipython-input-31-9fa8e764dcdb> in <module>()
      1 gen = Generator()
----> 2 gen.forward_pass(arr)

2 frames

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    336                             _pair(0), self.dilation, self.groups)
    337         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 338                         self.padding, self.dilation, self.groups)
    339 
    340 

TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not Tensor

1 Ответ

0 голосов
/ 01 июня 2019

Вы пытаетесь передать тензор TensorFlow в функцию PyTorch.TensorFlow и PyTorch - это отдельные проекты с различными структурами данных, которые, как правило, не могут использоваться взаимозаменяемо таким образом.

Чтобы преобразовать массив NumPy в тензор PyTorch, вы можете использовать:

import torch
arr = torch.from_numpy(np.random.random((3,28,28)))
...