Keras Двунаправленный пользовательский слой - PullRequest
0 голосов
/ 18 января 2020

Я реализую пользовательский рекуррентный класс, который наследует tf.layers.Layer, при использовании двунаправленной оболочки я получаю ошибку:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-3-7bd5b5269810> in <module>
----> 1 a = TimeDistributed(Bidirectional(char_recurrent_cell))

~/opt/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/wrappers.py in __init__(self, layer, merge_mode, weights, backward_layer, **kwargs)
    434     if backward_layer is None:
    435       self.backward_layer = self._recreate_layer_from_config(
--> 436           layer, go_backwards=True)
    437     else:
    438       self.backward_layer = backward_layer

~/opt/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/wrappers.py in _recreate_layer_from_config(self, layer, go_backwards)
    493     config = layer.get_config()
    494     if go_backwards:
--> 495       config['go_backwards'] = not config['go_backwards']
    496     if 'custom_objects' in tf_inspect.getfullargspec(
    497         layer.__class__.from_config).args:

KeyError: 'go_backwards'

Это код самого слоя:

class RecurrentConfig(BaseLayer):
    '''Basic configurable recurrent layer'''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers: List[layers.Layer] = stack_layers(self.params,
                                                       self.num_layers,
                                                       self.layer_name)

    def call(self, inputs: np.ndarray) -> layers.Layer:
        '''This function is a sequential/functional call to this layers logic
        Args:
            inputs: Array to be processed within this layer
        Returns:
            inputs processed through this layer'''
        processed = inputs
        for layer in self.layers:
            processed = layer(processed)
        return processed

    @staticmethod
    def default_params() -> Dict[Any, Any]:
        return{
            'units': 32,
            'recurrent_initializer': 'glorot_uniform',
            'dropout': 0,
            'recurrent_dropout': 0,
            'activation': None,
            'return_sequences': True
        }

Я попытался добавить go_backwards в конфигурацию, которая получается при вызове get_config (), но это приводит к другой ошибке:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-7bd5b5269810> in <module>
----> 1 a = TimeDistributed(Bidirectional(char_recurrent_cell))

~/opt/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/wrappers.py in __init__(self, layer, merge_mode, weights, backward_layer, **kwargs)
    430     # Recreate the forward layer from the original layer config, so that it will
    431     # not carry over any state from the layer.
--> 432     self.forward_layer = self._recreate_layer_from_config(layer)
    433 
    434     if backward_layer is None:

~/opt/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/wrappers.py in _recreate_layer_from_config(self, layer, go_backwards)
    506       return layer.__class__.from_config(config, custom_objects=custom_objects)
    507     else:
--> 508       return layer.__class__.from_config(config)
    509 
    510   @tf_utils.shape_type_conversion

~/opt/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in from_config(cls, config)
    517         A layer instance.
    518     """
--> 519     return cls(**config)
    520 
    521   def compute_output_shape(self, input_shape):

~/nlpv3-general/nlp-lib/src/main/python/mosaix_py/mosaix_learn/layers/recurrent_layers.py in __init__(self, *args, **kwargs)
     12     '''Basic configurable recurrent layer'''
     13     def __init__(self, *args, **kwargs):
---> 14         super().__init__(*args, **kwargs)
     15         self.layers: List[layers.Layer] = stack_layers(self.params,
     16                                                        self.num_layers,

~/nlpv3-general/nlp-lib/src/main/python/mosaix_py/mosaix_learn/layers/base_layer.py in __init__(self, params, mode, layer_name, num_layers, cust_name, **kwargs)
     17                  cust_name: str = '',
     18                  **kwargs):
---> 19         super().__init__(params, mode, **kwargs)
     20         self.layer_name = layer_name
     21         self.cust_name = cust_name

~/nlpv3-general/nlp-lib/src/main/python/mosaix_py/mosaix_learn/configurable.py in __init__(self, params, mode, **kwargs)
     61 
     62     def __init__(self, params: Dict[AnyStr, Any], mode: ModeKeys, **kwargs):
---> 63         super().__init__(**kwargs) #type: ignore
     64         self._params = _parse_params(params, self.default_params())
     65         self._mode = mode

~/opt/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

~/opt/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __init__(self, trainable, name, dtype, dynamic, **kwargs)
    184     }
    185     # Validate optional keyword arguments.
--> 186     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
    187 
    188     # Mutable properties

~/opt/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in validate_kwargs(kwargs, allowed_kwargs, error_message)
    716   for kwarg in kwargs:
    717     if kwarg not in allowed_kwargs:
--> 718       raise TypeError(error_message, kwarg)

TypeError: ('Keyword argument not understood:', 'go_backwards')

Вот небольшой пример, который будет копировать мой Проблема:

import tensorflow as tf
class DummyLayer(tf.keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.a = tf.keras.layers.LSTM(2)

    def call(inputs):
        return self.a(inputs)

tf.keras.layers.Bidirectional(DummyLayer())

Информация о версии: tf_version: '2.1.0-dev20191125' git_version: 'v1.12.1-19144-gf39f4ea3fa'

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...