Оптимизация кода: вычисления в Torch.Tensor - PullRequest
0 голосов
/ 30 апреля 2019

В настоящее время я реализую функцию для вычисления пользовательских кросс-энтропийных потерь. Определение функции следующее изображение.

cite from

мои коды следующие,

output = output.permute(0, 2, 3, 1)
target = target.permute(0, 2, 3, 1)

batch, height, width, channel = output.size()

total_loss = 0.
for b in range(batch): # for each batch
    o = output[b]
    t = target[b]
    loss = 0.
    for w in range(width):
        for h in range(height): # for every pixel([h,w]) in the image
            sid_t = t[h][w][0]
            sid_o_candi = o[h][w]
            part1 = 0. # to store the first sigma 
            part2 = 0. # to store the second sigma

            for k in range(0, sid_t):
                p = torch.sum(sid_o_candi[k:]) # to get Pk(w,h)
                part1 += torch.log(p + 1e-12).item()

            for k in range(sid_t, intervals):
                p = torch.sum(sid_o_candi[k:]) # to get Pk(w,h)
                part2 += torch.log(1-p + 1e-12).item()

            loss += part1 + part2

    loss /= width * height * (-1)
    total_loss += loss
total_loss /= batch
return torch.tensor(total_loss, dtype=torch.float32)

Мне интересно, можно ли с этим кодом провести какую-либо оптимизацию?

1 Ответ

2 голосов
/ 30 апреля 2019

Я не уверен, sid_t = t[h][w][0] одинаково для каждого пикселя или нет. Если это так, вы можете избавиться от всех for loop, которые повышают скорость вычислительных потерь.

Не используйте .item(), потому что он вернет значение Python, которое теряет дорожку grad_fn. Тогда вы не можете использовать loss.backward() для вычисления градиентов.

Если sid_t = t[h][w][0] не то же самое, вот некоторая модификация, которая поможет вам избавиться хотя бы от 1 for-loop:



batch, height, width, channel = output.size()

total_loss = 0.
for b in range(batch): # for each batch
    o = output[b]
    t = target[b]
    loss = 0.
    for w in range(width):
        for h in range(height): # for every pixel([h,w]) in the image
            sid_t = t[h][w][0]
            sid_o_candi = o[h][w]
            part1 = 0. # to store the first sigma 
            part2 = 0. # to store the second sigma

            sid1_cumsum = sid_o_candi[:sid_t].flip(dim=(0,)).cumsum(dim=0).flip(dims=(0,)) 
            part1 = torch.sum(torch.log(sid1_cumsum + 1e-12))

            sid2_cumsum = sid_o_candi[sid_t:intervals].flip(dim=(0,)).cumsum(dim=0).flip(dims=(0,)) 
            part2 = torch.sum(torch.log(1 - sid2_cumsum + 1e-12))

            loss += part1 + part2

    loss /= width * height * (-1)
    total_loss += loss
total_loss /= batch
return torch.tensor(total_loss, dtype=torch.float32)

Как это работает:

x = torch.arange(10); 
print(x)

x_flip = x.flip(dims=(0,)); 
print(x_flip)

x_inverse_cumsum = x_flip.cumsum(dim=0).flip(dims=(0,))
print(x_inverse_cumsum)

# output
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
tensor([45, 45, 44, 42, 39, 35, 30, 24, 17,  9])

Надеюсь, это поможет.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...