I am trying to add a custom layer to a Keras model, so I wrote my code following the TensorFlow guide https://www.tensorflow.org/guide/keras/custom_layers_and_models#layers_are_recursively_composable.
But I ran into this error: AttributeError: 'tuple' object has no attribute 'layer'.
Here are my custom layers (I would like them to be recursively composable), together with the imports they rely on:

import tensorflow as tf
import tensorflow.contrib as tf_contrib
import keras
from keras import backend as K
from keras.layers import Input, BatchNormalization, Convolution2D
from keras.models import Model
class Hw_Flatten(keras.layers.Layer):
    def __init__(self):
        super(Hw_Flatten, self).__init__()

    def call(self, inputs, **kwargs):
        return tf.reshape(inputs, shape=[inputs.shape[0], -1, inputs.shape[-1]])
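(As a side note, the batch dimension is usually None at graph-construction time, so indexing inputs.shape[0] can be fragile. A minimal sketch of a variant that sidesteps this, assuming h, w and c are statically known, as they are in this model:)

class Hw_Flatten(keras.layers.Layer):
    # Sketch: flatten [bs, h, w, c] -> [bs, h*w, c] without touching the
    # (possibly None) batch dimension; assumes h, w, c are static.
    def __init__(self):
        super(Hw_Flatten, self).__init__()

    def call(self, inputs, **kwargs):
        shape = inputs.shape.as_list()  # [bs, h, w, c]; bs may be None
        return tf.reshape(inputs, [-1, shape[1] * shape[2], shape[3]])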
class Max_Pooling(keras.layers.Layer):
    def __init__(self):
        super(Max_Pooling, self).__init__()

    def call(self, inputs, **kwargs):
        return tf.layers.max_pooling2d(inputs, pool_size=2, strides=2, padding='SAME')
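(tf.layers.max_pooling2d is deprecated, as the warning in my log says about conv2d, and it creates a layer from TensorFlow's internal Keras rather than the standalone keras package the model is built with. A sketch of what a purely Keras-composed version might look like:)

class Max_Pooling(keras.layers.Layer):
    def __init__(self):
        super(Max_Pooling, self).__init__()
        # compose a Keras layer instead of calling tf.layers at call time
        self.pool = keras.layers.MaxPooling2D(pool_size=2, strides=2, padding='same')

    def call(self, inputs, **kwargs):
        return self.pool(inputs)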
class Convolution(keras.layers.Layer):
    def __init__(self, use_bias=True):
        super(Convolution, self).__init__()
        self.use_bias = use_bias
        self.weight_init = tf_contrib.layers.xavier_initializer()
        self.weight_regularizer = None
        self.weight_regularizer_fully = None

    def call(self, inputs, channels, kernel=4, stride=2, **kwargs):
        return tf.layers.conv2d(inputs=inputs, filters=channels,
                                kernel_size=kernel, kernel_initializer=self.weight_init,
                                kernel_regularizer=self.weight_regularizer,
                                strides=stride, use_bias=self.use_bias)
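(The same concern applies here: tf.layers.conv2d instantiates a fresh tf.keras layer on every call. A sketch of a Keras-only wrapper, assuming the filter count is fixed at construction time; 'glorot_uniform', the Keras default, is the Xavier initializer. Since weights are then fixed per instance, the attention layer below would need one Convolution per projection (f, g, h, o) instead of reusing a single self.conv:)

class Convolution(keras.layers.Layer):
    def __init__(self, channels, kernel=4, stride=2, use_bias=True):
        super(Convolution, self).__init__()
        # 'glorot_uniform' (the Keras default) is the Xavier initializer
        self.conv = keras.layers.Conv2D(filters=channels, kernel_size=kernel,
                                        strides=stride, use_bias=use_bias,
                                        kernel_initializer='glorot_uniform')

    def call(self, inputs, **kwargs):
        return self.conv(inputs)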
class google_attention(keras.layers.Layer):
    def __init__(self, output_dim, **kwargs):
        super(google_attention, self).__init__(**kwargs)
        self.shape = (64, 64, 128)
        self.channels = 1024
        self.name = 'attention'
        self.output_dim = output_dim
        self.conv = Convolution()
        self.max_pooling = Max_Pooling()
        self.hw_flatten = Hw_Flatten()

    def build(self, input_shape):  # add weight
        super(google_attention, self).build(input_shape)
        self.gamma = K.variable([0.0])  # tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
        self.trainable_weights = [self.gamma]

    def call(self, inputs, **kwargs):
        f = self.conv(inputs, channels=self.channels // 8, kernel=1, stride=1)  # [bs, h, w, c']
        f = self.max_pooling(f)
        g = self.conv(inputs, channels=self.channels // 8, kernel=1, stride=1)  # [bs, h, w, c']
        h = self.conv(inputs, channels=self.channels // 2, kernel=1, stride=1)  # [bs, h, w, c]
        h = self.max_pooling(h)
        # N = h * w
        s = tf.matmul(self.hw_flatten(g), self.hw_flatten(f), transpose_b=True)  # [bs, N, N]
        beta = tf.nn.softmax(s)  # attention map
        o = tf.matmul(beta, self.hw_flatten(h))  # [bs, N, C]
        o = tf.reshape(o, shape=[-1, self.shape[0], self.shape[1], self.shape[2] // 2])  # [bs, h, w, C]
        o = self.conv(o, channels=self.output_dim, kernel=1, stride=1)
        x = self.gamma * o + inputs
        return x
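(Two smaller things I am unsure about in this class: assigning self.name directly, since the base Layer manages name and passing name='attention' to super().__init__ seems safer, and assigning to self.trainable_weights, which is a read-only property in some Keras versions. A sketch of build using add_weight, which registers gamma with Keras automatically:)

    def build(self, input_shape):
        # Sketch: let Keras create and track gamma instead of K.variable +
        # a manual trainable_weights assignment
        self.gamma = self.add_weight(name='gamma', shape=(1,),
                                     initializer='zeros', trainable=True)
        super(google_attention, self).build(input_shape)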
inputs = Input((img_cols, img_rows, IN_CH))
e1 = BatchNormalization()(inputs)
e1 = Convolution2D(64, 4, 4, subsample=(2, 2), activation='relu', init='uniform', border_mode='same')(e1)
e1 = BatchNormalization()(e1)
e2 = Convolution2D(128, 4, 4, subsample=(2, 2), activation='relu', init='uniform', border_mode='same')(e1)
e2 = BatchNormalization()(e2)
atten = google_attention(128)
e2 = atten(e2)
model = Model(input=inputs, output=e2)
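For what it's worth, the traceback below interleaves site-packages\keras\engine\base_layer.py with tensorflow\python\keras\engine\base_layer.py, so two different Keras implementations seem to be involved: tf.layers.conv2d builds a tf.keras layer internally while the surrounding model is standalone keras, and the two disagree on the format of the _keras_history metadata. A sketch of building everything from a single namespace (assuming tf.keras throughout, including inside the custom layers):

from tensorflow import keras

inputs = keras.layers.Input((img_cols, img_rows, IN_CH))
e1 = keras.layers.BatchNormalization()(inputs)
e1 = keras.layers.Conv2D(64, 4, strides=2, activation='relu', padding='same')(e1)
e1 = keras.layers.BatchNormalization()(e1)
e2 = keras.layers.Conv2D(128, 4, strides=2, activation='relu', padding='same')(e1)
e2 = keras.layers.BatchNormalization()(e2)
e2 = google_attention(128)(e2)  # assumes the layer also uses tf.keras internally
model = keras.Model(inputs=inputs, outputs=e2)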
The full error is:
tracking <tf.Variable 'attention/Variable:0' shape=(1,) dtype=float32> gamma
WARNING:tensorflow:From E:\elts\cgan\custom_layer.py:110: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.
Traceback (most recent call last):
  File "E:/elts/cgan/custome_class.py", line 376, in <module>
    train(BATCH_SIZE)
  File "E:/elts/cgan/custome_class.py", line 277, in train
    generator = generator_model()
  File "E:/elts/cgan/custome_class.py", line 144, in generator_model
    e2 = atten(e2)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\keras\engine\base_layer.py", line 489, in __call__
    output = self.call(inputs, **kwargs)
  File "E:\elts\cgan\custom_layer.py", line 130, in call
    f = self.conv(inputs, channels=self.channels // 8, kernel=1, stride=1) # [bs, h, w, c']
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\keras\engine\base_layer.py", line 489, in __call__
    output = self.call(inputs, **kwargs)
  File "E:\elts\cgan\custom_layer.py", line 110, in call
    strides=stride, use_bias=self.use_bias)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\util\deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\layers\convolutional.py", line 424, in conv2d
    return layer.apply(inputs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1479, in apply
    return self.__call__(inputs, *args, **kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\layers\base.py", line 537, in __call__
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 663, in __call__
    inputs, outputs, args, kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1708, in _set_connectivity_metadata_
    input_tensors=inputs, output_tensors=outputs, arguments=kwargs)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1795, in _add_inbound_node
    input_tensors)
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\util\nest.py", line 515, in map_structure
    structure[0], [func(*x) for x in entries],
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\util\nest.py", line 515, in <listcomp>
    structure[0], [func(*x) for x in entries],
  File "E:\elts\Self-Attention-GAN-Tensorflow-master\venv\Include\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1794, in <lambda>
    inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
AttributeError: 'tuple' object has no attribute 'layer'
Any comment or suggestion is highly appreciated. Thanks!