восстановить тензор факела из его пакетной версии - PullRequest
0 голосов
/ 26 сентября 2019

Это очень хороший пример того, как построить 3D-тензор:

import torch
y = torch.rand(100, 1)
batch_size = 10
batched_data = y.contiguous().view(batch_size, -1, y.size(-1)).transpose(0,1)
batched_data.shape

вывод:

torch.Size([10, 10, 1])

хорошо, теперь я хочу начать сbatched_data Я хочу построить y.Наоборот.Любое хорошее предложение с мощным оптимизированным кодом Pytorch?

==== Дополнительный ввод =====

Я использую это для RNN, и теперь у меня есть некоторые сомнения, потому что, если вы считаете,следующий код:

import torch
y = torch.arange(100).view(100,1)
batch_size = 10
batched_data = y.contiguous().view(batch_size, -1, y.size(-1)).transpose(0,1)
batched_data.shape

Вывод:

tensor([[[ 0],
         [10],
         [20],
         [30],
         [40],
         [50],
         [60],
         [70],
         [80],
         [90]],

        [[ 1],
         [11],
         [21],
         [31],
         [41],
         [51],
         [61],
         [71],
         [81],
         [91]],

Чего я не ожидал.Я бы ожидал что-то вроде: [[1,2,3,4,5,6,7,8,9,10],[11,12,13,14,15,16,17,18,19,20],....

Ответы [ 2 ]

1 голос
/ 26 сентября 2019

Если вы хотите подготовить входные данные для RNN, вам нужно знать, что RNN принимает трехмерные тензоры формы, seq_len, batch, input_size.Здесь input_size относится к числу объектов, а в вашем сценарии оно равно 1. Таким образом, входной тензор формы 10, 10, 1 все еще может быть допустимым входным значением для RNN.

Пример

rnn = nn.RNN(input_size=1, hidden_size=20, num_layers=1)
input = torch.randn(10, 10, 1)
output, hn = rnn(input)
print(output.size()) # 10, 10, 20

Выход РНН имеет форму, seq_len, batch, num_directions * hidden_size.

0 голосов
/ 26 сентября 2019

Предположим, вы хотите сделать что-то вроде этого, чтобы восстановить y:

rebuilded_y = batched_data.transpose(0,1).view(*y.shape)

Чтобы ввод был таким, как вы ожидали, вам нужно удалить транспонирование и дополнительное измерение в batched_data:

batched_data = y.contiguous().view(batch_size, -1)
...