Keras: пользовательский слой с умножением матрицы дает ошибку измерения, когда она должна быть правильной - PullRequest
0 голосов
/ 11 февраля 2019

У меня есть пользовательские слои keras, которые принимают несколько векторов одинакового размера (например, список из 3 входных векторов, каждый длиной 10. В keras форма каждого входного вектора будет (?, 10).)

В пользовательском слое под секцией вызова я сначала складываю 3 вектора для формирования фигуры (?, 3, 10), где каждый вектор становится вектором строки, а 3 вектора объединяются в матрицу, x (исключая размер партии).

Затем x умножается на весовую матрицу w, которая имеет размер (3,3) без размера партии.Весовая матрица определена в части построения пользовательского слоя.

Результат y переставлен, чтобы снова сделать размерность пакета первым измерением.

Наконец, слой должен вывести 3 векторатой же длины, что и исходный ввод.Поэтому я нарезаю вдоль оси = 1, чтобы получить 3 тензора, каждый из которых имеет одинаковый размер (?, 10).

Я опробовал тестовый пример, и, похоже, он работает.Но когда я вызываю модель и получаю строку для model.summary (), она выдает следующую ошибку: ValueError: Размеры должны быть равны, но равны 3 и 0 для «add» (оп: «Add») с входными фигурами:[3,3], [0].

Я пробовал различные решения, включая K.batch_dot (), но для batch_dot () я не смог заставить его работать из-за ошибок размеров ...

Спасибо за вашу помощь!

Решено

Заменить

self.trainable_weights = self._w

на

self.trainable_weights.append(self._w)

Фу

# Test Case
import keras.backend as K
a = K.variable(np.array([[1,2,3],[4,5,6],[7,8,9]]))
b = K.variable(np.repeat(np.array([[1,1,10,1,1],[2,2,20,2,2],[3,3,30,3,3]])[np.newaxis,:],repeats=10,axis=0))
c = K.dot(a,b)
c = K.permute_dimensions(c,pattern=(1,0,2))
y = K.eval(c)
print(y)
print(c.shape)  # (10, 3, 5)

# Custom layer build part
def build(self, input_shape):
    # input_shape should be a list, since cross stitch must take in inputs from all the individual tasks.
    self._input_count = len(input_shape)
    w = np.identity(self._input_count) * 0.9
    inverse_diag_mask = np.invert(np.identity(self._input_count, dtype=np.bool))
    off_value = 0.1 / (self._input_count - 1)
    w[inverse_diag_mask] = off_value
    self._w = K.variable(np.array(w))
    self.trainable_weights = self._w

    super(CrossStitchLayer, self).build(input_shape)

# Custom layer call part
def call(self, x, **kwargs):
    temp = x  # to show shape
    x = K.stack(x, axis=1)
    y1 = K.dot(self._w, x)
    y = K.permute_dimensions(y1, pattern=(1, 0, 2))
    results = []
    for idx in range(self._input_count):
        results.append(y[:, idx, :])
    return results

Полное сообщение об ошибке:

Трассировка (последний последний вызов): файл "C: \ Users \ limka \ Anaconda3 \ envs \ my-rdkit-env \ lib \ site-packages \ tenorflow \ python \ framework\ ops.py ", строка 1628, в _create_c_op c_op = c_api.TF_FinishOperation (op_desc) tenorflow.python.framework.errors_impl.InvalidArgumentError: Размеры должны быть равными, но равны 3 и 0 для« добавления »(op:« Добавить »)с формами ввода: [3,3], [0].

Во время обработки вышеупомянутого исключения произошло другое исключение:

Traceback (most recent call last):
  File "C:/Users/limka/Desktop/Python/strain_sensor/run_cross_validation.py", line 10, in <module>
    k_folds=10, k_shuffle=True, save_model=False, save_model_name=None, save_model_dir='./save/models/')
  File "C:\Users\limka\Desktop\Python\strain_sensor\own_package\cross_validation.py", line 57, in run_skf
    model = MTmodel(fl=ss_fl, mode=model_mode, hparams=hparams, labels_norm=True)
  File "C:\Users\limka\Desktop\Python\strain_sensor\own_package\models.py", line 217, in __init__
    cs_model = cross_stitch(self.features_dim, self.labels_dim, self.hparams)
  File "C:\Users\limka\Desktop\Python\strain_sensor\own_package\models.py", line 193, in cross_stitch
    model.summary()
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\keras\engine\network.py", line 1260, in summary
    print_fn=print_fn)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\keras\utils\layer_utils.py", line 166, in print_summary
    print_layer_summary_with_connections(layers[i])
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\keras\utils\layer_utils.py", line 153, in print_layer_summary_with_connections
    layer.count_params(),
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\keras\engine\base_layer.py", line 1129, in count_params
    return count_params(self.weights)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\keras\engine\base_layer.py", line 1022, in weights
    return self.trainable_weights + self.non_trainable_weights
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\tensorflow\python\ops\variables.py", line 856, in _run_op
    return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\tensorflow\python\ops\math_ops.py", line 878, in binary_op_wrapper
    return func(x, y, name=name)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 300, in add
    "Add", x=x, y=y, name=name)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\tensorflow\python\framework\ops.py", line 3274, in create_op
    op_def=op_def)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\tensorflow\python\framework\ops.py", line 1792, in __init__
    control_input_ops)
  File "C:\Users\limka\Anaconda3\envs\my-rdkit-env\lib\site-packages\tensorflow\python\framework\ops.py", line 1631, in _create_c_op
    raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 3 and 0 for 'add' (op: 'Add') with input shapes: [3,3], [0].
...