Параллельные экземпляры модели TensorFlow 2.0 и проблема Deep-copy - PullRequest
0 голосов
/ 01 ноября 2019

У меня есть модель с тензорным потоком, и я хочу, чтобы несколько ее экземпляров работали параллельно, например, с помощью Joblib, и я замечаю, что у меня возникает следующая ошибка, связанная с травлением. Та же ошибка также, если я использую глубокое копирование:

class neural(Model):
    def __init__(self,
                 output_dim=1, 
                 dtype=tfdtype):
        super(neural, self).__init__()

        self.length_scales  = tf.Variable(
                                      [1.0 for i in range(6)], name='length_scales',
                                      dtype=dtype,
                                      constraint=tf.keras.constraints.NonNeg()
                                     )
    def call(self, x):
        P = tf.linalg.tensor_diag(self.length_scales**(-1))
        x_embedd      = tf.matmul(x, P)

        return x_embedd

m = neural()
m2 = deepcopy(m)

, для которого возвращается следующий трекбек:

----------------------------------------------------------------------
TypeError                            Traceback (most recent call last)
<ipython-input-12-8b3383bcd0cc> in <module>
      1 server = tf.distribute.Server.create_local_server()
----> 2 a = deepcopy(m)

~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in deepcopy(x, memo, _nil)
    178                     y = x
    179                 else:
--> 180                     y = _reconstruct(x, memo, *rv)
    181 
    182     # If is its own copy, don't memoize.

~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    278     if state is not None:
    279         if deep:
--> 280             state = deepcopy(state, memo)
    281         if hasattr(y, '__setstate__'):
    282             y.__setstate__(state)

~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in deepcopy(x, memo, _nil)
    148     copier = _deepcopy_dispatch.get(cls)
    149     if copier:
--> 150         y = copier(x, memo)
    151     else:
    152         try:

~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in _deepcopy_dict(x, memo, deepcopy)
    238     memo[id(x)] = y
    239     for key, value in x.items():
--> 240         y[deepcopy(key, memo)] = deepcopy(value, memo)
    241     return y
    242 d[dict] = _deepcopy_dict

~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in deepcopy(x, memo, _nil)
    167                     reductor = getattr(x, "__reduce_ex__", None)
    168                     if reductor:
--> 169                         rv = reductor(4)
    170                     else:
    171                         reductor = getattr(x, "__reduce__", None)

TypeError: can't pickle _thread._local objects

Есть ли способ обойти эту ошибку?

...