класс глубокой копии, содержащий модели кераса - PullRequest
0 голосов
/ 06 октября 2018

В моем скрипте на Python я создал класс, который, среди прочего, содержит keras модели, подобные следующим:

from keras.layers import Input, Activation, Dense
from keras.models import Model


class Klass:

    def __init__(self, input_dims, output_dims, hidden_dims, optimizer, a, b):

        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.optimizer = optimizer
        self.a = a
        self.b = b

        self.__build_nn()

    def __build_nn(self):

        inputs = Input(shape=(self.input_dims,))
        net = inputs
        for h_dim in self.hidden_dims:
            net = Dense(h_dim, kernel_initializer='he_uniform')(net)
            net = Activation("relu")(net)

        outputs = Dense(self.output_dims)(net)
        outputs = Activation("linear")(outputs)
        self.nn1 = Model(inputs=inputs, outputs=outputs)
        self.nn2 = Model(inputs=inputs, outputs=outputs)
        self.nn1.compile(optimizer=self.optimizer, loss='mean_squared_error')
        self.nn2.compile(optimizer=self.optimizer, loss='mean_squared_error')

После создания экземпляра Klass я хотел бы создатьглубокая копия этого:

import copy
obj = Klass(10, 10, (20, 20), Adam(), 1, 2)
obj_dc = copy.deepcopy(obj)

Однако, это бросает TypeError: can't pickle _thread.RLock objects.Я почти уверен, что ошибка связана с keras моделями в объекте класса, поскольку я смог получить глубокую копию аналогичного класса без keras моделей.

К сожалению, мне не удалосьчтобы найти решение этой проблемы в Интернете, так как большинство вопросов, касающихся глубокого копирования модели keras, пытались клонировать модель keras, например здесь .

Итак, как я могуполучить глубокую копию класса, содержащего keras моделей?

РЕДАКТИРОВАТЬ

Эти три вопроса ( 1 , 2 , 3 ) упоминают аналогичную ошибку при разных обстоятельствах.Тем не менее, предлагаемые там решения не применимы в моем случае.

РЕДАКТИРОВАТЬ 2

Как предлагается в комментариях, я добавил метод copy вучебный класс.Будет ли это жизнеспособным решением?

class Klass:

    def __init__(self, input_dims, output_dims, hidden_dims, optimizer, a, b):

        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.optimizer = optimizer
        self.a = a
        self.b = b

        self.__build_nn()

    # [...]

    def copy(self):

        new = Klass(self.input_dims, self.output_dims, self.hidden_dims,
                    self.optimizer, self.a, self.b)
        new.nn1.set_weights(self.nn1.get_weights())
        new.nn2.set_weights(self.nn2.get_weights())

        return new

1 Ответ

0 голосов
/ 06 октября 2018

Решено в комментариях: добавлен метод copy для Klass, который копирует веса из старого экземпляра Klass во вновь созданный.

...