Аргумент «Состояния» отсутствует в пользовательской модели с использованием пользовательского слоя RNN - PullRequest
1 голос
/ 25 февраля 2020

Я создаю свой собственный слой в Tensorflow 2.1 и использую его в пользовательской модели. В приведенном ниже примере я скопировал код MinimalRNNCell с веб-сайта tenorflow (https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN), и я пытаюсь использовать этот слой в своей модели.

Однако при попытке подгонять модель я Я получаю сообщение о том, что метод вызова ячейки требует аргумента "состояния", а я его не предоставляю.

Как мне исправить мою модель, указав этот аргумент?

Мой код:

import tensorflow as tf 
from tensorflow.keras.layers import Layer
from tensorflow.keras import Model
import numpy as np

class MinimalRNNCell(Layer):

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(MinimalRNNCell, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        prev_output = states[0]
        h = K.dot(inputs, self.kernel)
        output = h + K.dot(prev_output, self.recurrent_kernel)
        return output, [output]


class RNNXModel(Model):
    def __init__(self, size):
        super(RNNXModel, self).__init__()
        self.minimalrnn=MinimalRNNCell(size)

    def call(self, inputs):
        out=self.minimalrnn(inputs)
        return out


x=np.array([[[1,2,3],[4,5,6],[7,8,9]],[[10,11,12],[13,14,15],[16,17,18]]])
y=np.array([[1,2,3],[10,11,12]])

model=RNNXModel(3)
model.compile(optimizer='sgd', loss='mse')
model.fit(x,y,epochs=10, batch_size=1)

Ошибка, которую я получаю:

Traceback (most recent call last):
  File "/home/.../test.py", line 64, in <module>
    model.fit(x,y,epochs=10, batch_size=1)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 819, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 235, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 593, in _process_training_inputs
    use_multiprocessing=use_multiprocessing)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 646, in _process_inputs
    x, y, sample_weight=sample_weights)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 2346, in _standardize_user_data
    all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 2572, in _build_model_with_inputs
    self._set_inputs(cast_inputs)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 2659, in _set_inputs
    outputs = self(inputs, **kwargs)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 773, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/autograph/impl/api.py", line 237, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in converted code:

    /home/.../test.py:36 call  *
        out=self.minimalrnn(inputs)
    /home/.../.venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py:773 __call__
        outputs = call_fn(cast_inputs, *args, **kwargs)

    TypeError: tf__call() missing 1 required positional argument: 'states'

1 Ответ

0 голосов
/ 28 февраля 2020

Благодаря Susmit Agrawal я пришел с этим, и он работает:

class MinimalRNNCell(AbstractRNNCell):

    def __init__(self, units, **kwargs):
      self.units = units
      super(MinimalRNNCell, self).__init__(**kwargs)

    @property
    def state_size(self):
      return self.units

    def build(self, input_shape):
      self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                    initializer='uniform',
                                    name='kernel')
      self.recurrent_kernel = self.add_weight(
          shape=(self.units, self.units),
          initializer='uniform',
          name='recurrent_kernel')
      self.built = True

    def call(self, inputs, states):
      prev_output = states[0]
      h = K.dot(inputs, self.kernel)
      output = h + K.dot(prev_output, self.recurrent_kernel)
      return output, output


class RNNXModel(Model):
    def __init__(self, size):
        super(RNNXModel, self).__init__()
        self.minimalrnn=RNN(MinimalRNNCell(size))

    def call(self, inputs):
        out=self.minimalrnn(inputs)
        return out


x=np.array([[[1,2,3],[4,5,6],[7,8,9]],[[10,11,12],[13,14,15],[16,17,18]]])
y=np.array([[1,2,3],[10,11,12]])

model=RNNXModel(3)
model.compile(optimizer='sgd', loss='mse')
model.fit(x,y,epochs=10, batch_size=1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...