Почему я не могу добавить тензор PyTorch с torch.cat? - PullRequest
0 голосов
/ 25 марта 2020

У меня есть:

import torch

input_sliced = torch.rand(180, 161)
output_sliced = torch.rand(180,)

batched_inputs = torch.Tensor()
batched_outputs = torch.Tensor()

print('input_sliced.size', input_sliced.size())
print('output_sliced.size', output_sliced.size())

batched_inputs = torch.cat((batched_inputs, input_sliced))
batched_outputs = torch.cat((batched_outputs, output_sliced))

print('batched_inputs.size', batched_inputs.size())
print('batched_outputs.size', batched_outputs.size())

Это выводит:

input_sliced.size torch.Size([180, 161])
output_sliced.size torch.Size([180])

batched_inputs.size torch.Size([180, 161])
batched_outputs.size torch.Size([180])

Мне нужно добавить batched, но torch.cat не работает. Что я делаю не так?

1 Ответ

1 голос
/ 25 марта 2020

Предполагая, что вы делаете это на всех oop, я бы сказал, что лучше сделать это так:

import torch

batch_input, batch_output = [], []
for i in range(10):  # assuming batch_size=10
    batch_input.append(torch.rand(180, 161))
    batch_output.append(torch.rand(180,))

batch_input = torch.stack(batch_input)
batch_output = torch.stack(batch_output)

print(batch_input.shape)   # output: torch.Size([10, 180, 161])
print(batch_output.shape)  # output: torch.Size([10, 180])

Если вы знаете результирующую batch_* форму априори , вы можете предварительно выделить финальный Tensor и просто назначить каждый образец в соответствующие позиции в пакете. Было бы более эффективно использовать память.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...