Я не уверен, sid_t = t[h][w][0]
одинаково для каждого пикселя или нет. Если это так, вы можете избавиться от всех for loop
, которые повышают скорость вычислительных потерь.
Не используйте .item()
, потому что он вернет значение Python, которое теряет дорожку grad_fn
. Тогда вы не можете использовать loss.backward()
для вычисления градиентов.
Если sid_t = t[h][w][0]
не то же самое, вот некоторая модификация, которая поможет вам избавиться хотя бы от 1 for-loop
:
batch, height, width, channel = output.size()
total_loss = 0.
for b in range(batch): # for each batch
o = output[b]
t = target[b]
loss = 0.
for w in range(width):
for h in range(height): # for every pixel([h,w]) in the image
sid_t = t[h][w][0]
sid_o_candi = o[h][w]
part1 = 0. # to store the first sigma
part2 = 0. # to store the second sigma
sid1_cumsum = sid_o_candi[:sid_t].flip(dim=(0,)).cumsum(dim=0).flip(dims=(0,))
part1 = torch.sum(torch.log(sid1_cumsum + 1e-12))
sid2_cumsum = sid_o_candi[sid_t:intervals].flip(dim=(0,)).cumsum(dim=0).flip(dims=(0,))
part2 = torch.sum(torch.log(1 - sid2_cumsum + 1e-12))
loss += part1 + part2
loss /= width * height * (-1)
total_loss += loss
total_loss /= batch
return torch.tensor(total_loss, dtype=torch.float32)
Как это работает:
x = torch.arange(10);
print(x)
x_flip = x.flip(dims=(0,));
print(x_flip)
x_inverse_cumsum = x_flip.cumsum(dim=0).flip(dims=(0,))
print(x_inverse_cumsum)
# output
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
tensor([45, 45, 44, 42, 39, 35, 30, 24, 17, 9])
Надеюсь, это поможет.