Есть ли удобный способ сбросить данные running_stats для модели pytorch? - PullRequest
1 голос
/ 22 сентября 2019

Я пишу C-версию модели pytorch для запуска на моем специальном оборудовании.Пока все выглядит хорошо, за исключением running_mean и running_var в каждом слое batchnorm.

У нас есть код Python для вывода всех named_parameters, но ничего не нужно делать для running_stats, хотя нам нужно использовать его в вычислениях пересылки.

Так есть ли способ выгрузить его с помощью встроенных функций?Я искал документацию, без помощи по моей задаче.В противном случае мне может понадобиться написать код регулярного выражения, чтобы распознать и вывести их.

Большое спасибо./ Патрик

for name, param in model.named_parameters():
    # here can dump weight and bias, but not running_stats
    names.append(name)
    shapes.append(list(param.data.numpy().shape))
    values.append(param.data.numpy().flatten().tolist())

1 Ответ

0 голосов
/ 22 сентября 2019

running_mean и другие registered_buffers в PyTorch.Вы можете сохранить (как вы говорите, дамп) их с помощью torch.nn.Module state_dict:

torch.save(model.state_dict(), PATH) 

Вы можете перебирать именованные буферы и сохранять каждый из них так, как вам нравится, аналогичнок параметрам:

for name, buffer in model.named_buffers():
    # do your thing with them
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...