Вот что я придумала на данный момент:
class GridDataset:
    """Dataset over pairs of fixed-size chunks of x and y coordinates.

    The x and y tensors are each cut into consecutive, non-overlapping chunks
    of ``chunk_size`` elements. Trailing elements that do not fill a whole
    chunk are ignored (floor division). Item ``idx`` enumerates the
    ``chunks_x() * chunks_y()`` chunk pairs in row-major order: the x-chunk
    varies slowest and the y-chunk fastest.

    Each item is a float32 tensor of shape ``(chunk_size, 2 + len(z) + 1)``.
    Its i-th row is ``[x_i, y_i, *z, sqrt(x_i**2 + y_i**2)]``, where ``x_i``
    and ``y_i`` are the i-th elements of the selected x- and y-chunks. This is
    element-wise pairing, not a Cartesian product of the two chunks.
    """

    def __init__(self, chunk_size=5, x=None, y=None, z=None):
        """Create the dataset.

        Args:
            chunk_size: number of elements per chunk along each axis (> 0).
            x: 1-D tensor of x coordinates; defaults to ``torch.arange(100)``.
            y: 1-D tensor of y coordinates; defaults to ``torch.arange(100)``.
            z: 1-D tensor appended unchanged to every output row; defaults
                to ``torch.tensor([0.22, 0.22, 0.45, 0.788, 0.013])``.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        self.chunk_size = chunk_size
        self.x = torch.arange(100) if x is None else x
        self.y = torch.arange(100) if y is None else y
        self.z = (
            torch.tensor([0.22, 0.22, 0.45, 0.788, 0.013]) if z is None else z
        )

    def chunks_x(self):
        """Return the number of whole chunks along x."""
        return self.x.size(0) // self.chunk_size

    def chunks_y(self):
        """Return the number of whole chunks along y."""
        return self.y.size(0) // self.chunk_size

    def __len__(self):
        """Return the number of (x-chunk, y-chunk) pairs."""
        return self.chunks_x() * self.chunks_y()

    def __getitem__(self, idx):
        """Return the feature tensor for chunk pair ``idx``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(
                f"index out of range for GridDataset of length {n_items}"
            )
        # Row-major decomposition: the quotient selects the x-chunk and the
        # remainder the y-chunk. The divisor must be the number of y-chunks
        # so that non-square grids index correctly.
        x_idx, y_idx = divmod(idx, self.chunks_y())
        cs = self.chunk_size
        x_chunk = self.x[cs * x_idx : cs * (x_idx + 1)].double().unsqueeze(1)
        y_chunk = self.y[cs * y_idx : cs * (y_idx + 1)].double().unsqueeze(1)
        # Euclidean distance of each (x_i, y_i) pair from the origin.
        r = torch.sqrt(x_chunk ** 2 + y_chunk ** 2)
        # One copy of z per row, via broadcasting instead of a Python loop.
        z_rows = self.z.double().expand(cs, -1)
        # Compute in float64, then return float32.
        return torch.cat((x_chunk, y_chunk, z_rows, r), dim=1).float()
Если я вывожу один элемент датасета, получается:
torch.Size([5])
tensor([[0.0000, 5.0000, 0.2200, 0.2200, 0.4500, 0.7880, 0.0130, 5.0000],
[1.0000, 6.0000, 0.2200, 0.2200, 0.4500, 0.7880, 0.0130, 6.0828],
[2.0000, 7.0000, 0.2200, 0.2200, 0.4500, 0.7880, 0.0130, 7.2801],
[3.0000, 8.0000, 0.2200, 0.2200, 0.4500, 0.7880, 0.0130, 8.5440],
[4.0000, 9.0000, 0.2200, 0.2200, 0.4500, 0.7880, 0.0130, 9.8489]])