Env: Python 3.5, Ubuntu 16.04, tensorflow-gpu==1.13.1. When I load the model and run the code, I get the following error:
Using latest checkpoint at saved_model/ckpt-26
WARNING:tensorflow:From /home/frank/Desktop/mesh-py3/my_venv/lib/python3.5/site-packages/tensorflow/python/ops/resource_variable_ops.py:642: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Traceback (most recent call last):
File "/home/frank/PycharmProjects/MultiGarmentNetwork/test_network.py", line 189, in <module>
pred = get_results(m, dat)
File "/home/frank/PycharmProjects/MultiGarmentNetwork/test_network.py", line 54, in get_results
out = m([images, vertex_label, J_2d])
File "/home/frank/Desktop/mesh-py3/my_venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 592, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/home/frank/PycharmProjects/MultiGarmentNetwork/network/base_network.py", line 338, in call
garm_model_outputs = [fe(latent_code_offset_ShapeMerged) for fe in self.garmentModels]
File "/home/frank/PycharmProjects/MultiGarmentNetwork/network/base_network.py", line 338, in <listcomp>
garm_model_outputs = [fe(latent_code_offset_ShapeMerged) for fe in self.garmentModels]
File "/home/frank/Desktop/mesh-py3/my_venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 592, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/home/frank/PycharmProjects/MultiGarmentNetwork/network/base_network.py", line 66, in call
x = self.PCA_(pca_comp)
File "/home/frank/Desktop/mesh-py3/my_venv/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 592, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/home/frank/PycharmProjects/MultiGarmentNetwork/network/custom_layers.py", line 33, in call
return tf.reshape(tf.matmul(x, self.components) + self.mean, (-1, K.int_shape(self.mean)[0] / 3, 3))
File "/home/frank/Desktop/mesh-py3/my_venv/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 7161, in reshape
tensor, shape, name=name, ctx=_ctx)
File "/home/frank/Desktop/mesh-py3/my_venv/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 7206, in reshape_eager_fallback
ctx=_ctx, name=name)
File "/home/frank/Desktop/mesh-py3/my_venv/lib/python3.5/site-packages/tensorflow/python/eager/execute.py", line 66, in quick_execute
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Value for attr 'Tshape' of float is not in the list of allowed values: int32, int64
; NodeDef: {{node Reshape}}; Op<name=Reshape; signature=tensor:T, shape:Tshape -> output:T; attr=T:type; attr=Tshape:type,default=DT_INT32,allowed=[DT_INT32, DT_INT64]> [Op:Reshape]
Process finished with exit code 1
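The last frame of the traceback points at custom_layers.py, line 33, where the target shape is built as K.int_shape(self.mean)[0] / 3. In Python 3, "/" is true division and returns a float even when the result is a whole number, and the error message says the Reshape shape must be int32 or int64. The sketch below is a NumPy analog of that layer's forward pass (not the project's TensorFlow code). It shows the same kind of rejection for a float shape and the integer-division variant. The function name pca_forward and all array sizes are hypothetical.

```python
import numpy as np

def pca_forward(x, components, mean, use_floor_div=True):
    """NumPy analog of the PCA layer: x @ components + mean, reshaped to (batch, V, 3)."""
    flat_len = mean.shape[0]  # plays the role of K.int_shape(self.mean)[0]
    # "/" gives a float (e.g. 4.0), "//" gives an int (e.g. 4)
    n_vertices = flat_len // 3 if use_floor_div else flat_len / 3
    return np.reshape(x @ components + mean, (-1, n_vertices, 3))

# Hypothetical sizes: 2 samples, 5 PCA coefficients, 4 vertices * 3 coords = 12
x = np.ones((2, 5), dtype=np.float32)
components = np.zeros((5, 12), dtype=np.float32)
mean = np.arange(12, dtype=np.float32)

print(pca_forward(x, components, mean).shape)  # (2, 4, 3)

try:
    pca_forward(x, components, mean, use_floor_div=False)
except TypeError as e:
    # NumPy, like TF's Reshape, refuses a float dimension
    print("float shape rejected:", e)
```

Under that reading, the corresponding change in the project would be replacing "/" with "//" on line 33 of custom_layers.py. The code base appears to have been written for Python 2, where "/" between two ints already produced an int.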
The problem starts here:
import pickle as pkl

model_dir = 'saved_model/'
## Load model (it has to be downloaded into saved_model/ first)
m = load_model(model_dir)
## Load test data
with open('assets/test_data.pkl', 'rb') as f:
    dat = pkl.load(f, encoding='latin1')
## Get results before optimization
pred = get_results(m, dat)
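The encoding='latin1' argument is what lets Python 3 unpickle files written by Python 2 that contain NumPy arrays. Since the question suspects the data format, a small check that the loaded dict actually has the keys get_results reads can rule that out quickly. This is a sketch: load_test_data, missing_keys and the num_views value are hypothetical names, and the demo writes a throwaway file instead of using assets/test_data.pkl.

```python
import os
import pickle as pkl
import tempfile

def load_test_data(path):
    # encoding='latin1' allows Python 3 to read Python 2 pickles containing NumPy arrays
    with open(path, 'rb') as f:
        return pkl.load(f, encoding='latin1')

def missing_keys(data, num_views):
    """Return the keys get_results() reads that the loaded dict lacks."""
    expected = ['vertexlabel']
    for i in range(num_views):
        expected += ['image_{}'.format(i), 'J_2d_{}'.format(i)]
    return [k for k in expected if k not in data]

# Demo with a throwaway pickle that only has data for one view
tmp = os.path.join(tempfile.mkdtemp(), 'demo.pkl')
with open(tmp, 'wb') as f:
    pkl.dump({'vertexlabel': [0, 1], 'image_0': [], 'J_2d_0': []}, f)

data = load_test_data(tmp)
print(missing_keys(data, num_views=2))  # ['image_1', 'J_2d_1']
```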
And here is the get_results function:
def get_results(m, inp, with_pose=False):
    # NUM is defined elsewhere in test_network.py
    images = [inp['image_{}'.format(i)].astype('float32') for i in range(NUM)]
    J_2d = [inp['J_2d_{}'.format(i)].astype('float32') for i in range(NUM)]
    vertex_label = inp['vertexlabel'].astype('int64')
    out = m([images, vertex_label, J_2d])  # line 54: the error is raised here
    # (rest of the function omitted)
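Because get_results casts every input with .astype(...) before calling the model, the dtypes reaching m(...) are fixed (float32 images and joints, int64 vertex labels) regardless of how they were stored in the pickle. The failing shape argument on line 33 of custom_layers.py is also computed from self.mean, a layer weight, rather than from these inputs. The sketch below reproduces the same preparation on dummy arrays; prepare_inputs, NUM = 2 and all array shapes are hypothetical placeholders, not the project's real values.

```python
import numpy as np

NUM = 2  # hypothetical number of views; the real constant lives in the project

def prepare_inputs(inp, num=NUM):
    # Same casts as get_results(), returned instead of fed to the model
    images = [inp['image_{}'.format(i)].astype('float32') for i in range(num)]
    J_2d = [inp['J_2d_{}'.format(i)].astype('float32') for i in range(num)]
    vertex_label = inp['vertexlabel'].astype('int64')
    return images, vertex_label, J_2d

# Dummy data deliberately stored with other dtypes
dummy = {'vertexlabel': np.zeros((1, 10), dtype=np.float64)}
for i in range(NUM):
    dummy['image_{}'.format(i)] = np.zeros((1, 8, 8, 3), dtype=np.uint8)
    dummy['J_2d_{}'.format(i)] = np.zeros((1, 25, 2), dtype=np.float64)

images, vertex_label, J_2d = prepare_inputs(dummy)
print(images[0].dtype, J_2d[0].dtype, vertex_label.dtype)  # float32 float32 int64
```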
The code then fails at out = m([images, vertex_label, J_2d]) (test_network.py, line 54), which calls into Keras's base_layer.py __call__; the full traceback is above.
I think this may be caused by the data format or by the TF version, but I don't know how to fix it. Can anyone help?
Thanks in advance!