У меня есть модель с тензорным потоком, и я хочу, чтобы несколько ее экземпляров работали параллельно, например, с помощью Joblib, и я замечаю, что у меня возникает следующая ошибка, связанная с травлением. Та же ошибка также, если я использую глубокое копирование:
class neural(Model):
def __init__(self,
output_dim=1,
dtype=tfdtype):
super(neural, self).__init__()
self.length_scales = tf.Variable(
[1.0 for i in range(6)], name='length_scales',
dtype=dtype,
constraint=tf.keras.constraints.NonNeg()
)
def call(self, x):
P = tf.linalg.tensor_diag(self.length_scales**(-1))
x_embedd = tf.matmul(x, P)
return x_embedd
m = neural()
m2 = deepcopy(m)
, для которого возвращается следующий трекбек:
----------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-12-8b3383bcd0cc> in <module>
1 server = tf.distribute.Server.create_local_server()
----> 2 a = deepcopy(m)
~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in deepcopy(x, memo, _nil)
178 y = x
179 else:
--> 180 y = _reconstruct(x, memo, *rv)
181
182 # If is its own copy, don't memoize.
~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
278 if state is not None:
279 if deep:
--> 280 state = deepcopy(state, memo)
281 if hasattr(y, '__setstate__'):
282 y.__setstate__(state)
~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in deepcopy(x, memo, _nil)
148 copier = _deepcopy_dispatch.get(cls)
149 if copier:
--> 150 y = copier(x, memo)
151 else:
152 try:
~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in _deepcopy_dict(x, memo, deepcopy)
238 memo[id(x)] = y
239 for key, value in x.items():
--> 240 y[deepcopy(key, memo)] = deepcopy(value, memo)
241 return y
242 d[dict] = _deepcopy_dict
~/.local/share/virtualenvs/tf-tRAPLeXL/lib/python3.6/copy.py in deepcopy(x, memo, _nil)
167 reductor = getattr(x, "__reduce_ex__", None)
168 if reductor:
--> 169 rv = reductor(4)
170 else:
171 reductor = getattr(x, "__reduce__", None)
TypeError: can't pickle _thread._local objects
Есть ли способ обойти эту ошибку?