Я думаю об этом как о частичной ситуации приложения - полезно иметь возможность "связывать" многие переменные конфигурации с объектом функции потерь.В большинстве случаев ваша функция потерь должна принять prediction
и ground_truth
в качестве аргументов.Это делает для довольно унифицированного базового API функций потери.Однако они отличаются в деталях.Например, не каждая функция потерь имеет параметр reduction
.BCEWithLogitsLoss
имеет параметры weight
и pos_weight
;PoissonNLLLoss
имеет log_input
, eps
.Удобно написать функцию типа
def one_epoch(model, dataset, loss_fn, optimizer):
for x, y in dataset:
model.zero_grad()
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
, которая может работать как с экземплярами BCEWithLogitsLoss
, так и с PoissonNLLLoss
.Но он не может работать с их функциональными аналогами из-за необходимости ведения бухгалтерского учета.Вместо этого вам придется сначала создать
loss_fn_packed = functools.partial(F.binary_cross_entropy_with_logits, weight=my_weight, reduction='sum')
, и только потом вы сможете использовать его с one_epoch
, определенным выше.Но эта упаковка уже предоставляется с API объектно-ориентированных потерь, а также с некоторыми прибамбасами (поскольку подкласс потерь nn.Module
, вы можете использовать перемотки вперед и назад, перемещать вещи между процессором и процессором и т. Д. И т. Д.).