Как можно выбрать произвольные модели pytorch, которые используют лямбда-функции? - PullRequest
0 голосов
/ 29 апреля 2020

У меня в настоящее время есть модуль нейронной сети:

import torch.nn as nn

class NN(nn.Module):
    def __init__(self,args,lambda_f,nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Params stuff etc...

    def forward(self, x):
        #some code using fields
        return out

Я пытаюсь проверить его, но поскольку pytorch сохраняет с помощью state_dict s, это означает, что я не могу сохранить лямбда-функции, которые я фактически использовал, если Я проверяю контрольно-пропускной пункт с pytorch torch.save et c. Я буквально хочу сохранить все без проблем и перезагрузить, чтобы потом тренироваться на GPU. В настоящее время я использую это:

def save_ckpt(path_to_ckpt):
    from pathlib import Path
    import dill as pickle
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / Path('db')

    ## Pickle args
    db['crazy_mdl'] = crazy_mdl
    with open(ckpt_path_plus_path , 'ab') as db_file:
        pickle.dump(db, db_file)

в настоящее время он не выдает ошибок, когда я проверяю его, и он сохранил его.

Я обеспокоен тем, что при обучении может возникнуть незначительная ошибка, даже если не происходит обучения исключений / ошибок или может произойти что-то неожиданное (например, странное сохранение на дисках в кластерах и т. Д. c, кто знает).

Безопасно ли это делать с классами pytorch / nn моделями? Особенно, если мы хотим возобновить обучение с графическими процессорами?

Кросс-пост:

Ответы [ 2 ]

1 голос
/ 30 апреля 2020

Я dill автор. Я использую dillklepto) для сохранения классов, которые содержат обученные ANN внутри лямбда-функций. Я склонен использовать комбинации mystic и sklearn, поэтому я не могу напрямую говорить с pytorch, но могу предположить, что он работает одинаково. Вы должны быть осторожны, если у вас есть лямбда, которая содержит указатель на объект, внешний по отношению к лямбде ... например, y = 4; f = lambda x: x+y. Это может показаться очевидным, но dill будет перебирать лямбду и, в зависимости от остального кода и варианта сериализации, может не сериализовать значение y. Итак, я видел много случаев, когда люди сериализовали обученную оценку внутри некоторой функции (или лямбда-выражения, или класса), и тогда результаты не были «правильными», когда они восстанавливали функцию из сериализации. Основная причина в том, что функция не была инкапсулирована, поэтому все объекты, необходимые для получения правильных результатов, сохраняются в рассоле. Однако даже в этом случае вы можете получить «правильные» результаты обратно, но вам просто нужно создать ту же среду, которая была у вас, когда вы выбирали оценщик (т.е. все те же значения, от которых зависит окружающее пространство имен). Вывод должен быть сделан, постарайтесь убедиться, что все переменные, используемые в функции, определены внутри функции. Вот часть класса, который я недавно начал использовать сам (должен быть в следующей версии mystic):

class Estimator(object):
    "a container for a trained estimator and transform (not a pipeline)"
    def __init__(self, estimator, transform):
        """a container for a trained estimator and transform

    Input:
        estimator: a fitted sklearn estimator
        transform: a fitted sklearn transform
        """
        self.estimator = estimator
        self.transform = transform
        self.function = lambda *x: float(self.estimator.predict(self.transform.transform(np.array(x).reshape(1,-1))).reshape(-1))
    def __call__(self, *x):
        "f(*x) for x of xtest and predict on fitted estimator(transform(xtest))"
        import numpy as np
        return self.function(*x)

Обратите внимание, когда вызывается функция, все, что она использует (включая np) определяется в окружающем пространстве имен. Пока оценки pytorch сериализуются, как и ожидалось (без внешних ссылок), с вами все будет в порядке, если вы будете следовать приведенным выше рекомендациям.

0 голосов
/ 30 апреля 2020

Да, я думаю, что безопасно использовать dill, чтобы засекать лямбда-функции и т.д. c. Я использовал torch.save с укропом, чтобы сохранить состояние, и у меня не было проблем с возобновлением тренировок как на GPU, так и на CPU, если только класс модели не был изменен. Даже если класс модели был изменен (добавление / удаление некоторых параметров), я мог бы загрузить dict состояния, изменить его и загрузить в модель.

Кроме того, обычно люди не сохраняют объекты модели, а только определяют состояние, то есть значения параметров, чтобы возобновить обучение вместе с аргументами гиперпараметров / модели, чтобы позже получить тот же объект модели.

Сохранение иногда объект модели может быть проблематичным c, поскольку изменения в классе модели (коде) могут сделать сохраненный объект бесполезным. Если вы вообще не планируете изменять класс / код модели и, следовательно, объект модели не будет изменен, то, возможно, сохранение объектов может работать хорошо, но в целом не рекомендуется выбирать объект модуля.

...