Я пытаюсь реализовать maxpooling1D с маскированием для данных последовательности (текста). Например, если у нас максимальная длина 5, маска может быть [1 1 1 0 0] с 1/0 = с / без токена.
Я смотрю на исходный код среднего пула, который является включено с TF 2.0 (https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/layers/pooling.py#L604 -L651 ). К сожалению, для MaxPool1D нет маскировки, поэтому я пытаюсь реализовать маску на основе AvgPool1D. Я застрял с проблемой во время обучения, и я не знаю, что происходит. Если я изменю свой globalmaxpooling1DM, задаваемый регулярным globalmaxpooling, у меня не будет ошибки. Так что проблема исходит от моего текущего класса. Я изменяю значение 0 в маске на -inf, потому что у меня может быть отрицательное значение в моем тензоре. Ввод слоя может иметь форму: (Размер партии, номер последовательности, скрытый размер), а маска: (Размер партии, номер последовательности)
Здесь код:
import tensorflow as tf
from sklearn.metrics import roc_auc_score
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.python.ops import math_ops, array_ops
from tensorflow.python.keras import backend
import numpy as np
class GlobalMaxPooling1DMasked(GlobalMaxPool1D):
"""Global average pooling operation for temporal data.
Arguments:
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, steps, features)` while `channels_first`
corresponds to inputs with shape
`(batch, features, steps)`.
Call arguments:
inputs: A 3D tensor.
mask: Binary tensor of shape `(batch_size, steps)` indicating whether
a given step should be masked (excluded from the average).
Input shape:
- If `data_format='channels_last'`:
3D tensor with shape:
`(batch_size, steps, features)`
- If `data_format='channels_first'`:
3D tensor with shape:
`(batch_size, features, steps)`
Output shape:
2D tensor with shape `(batch_size, features)`.
"""
def __init__(self, data_format='channels_last', **kwargs):
super(GlobalMaxPooling1DMasked, self).__init__(data_format=data_format,
**kwargs)
self.supports_masking = True
def call(self, inputs, mask=None):
steps_axis = 1 if self.data_format == 'channels_last' else 2
if mask is not None:
mask = math_ops.cast(mask, backend.floatx())
mask = array_ops.expand_dims(
mask, 2 if self.data_format == 'channels_last' else 1)
#mask[cond] = -inf
mask = K.log(mask)
inputs = inputs + mask
return backend.max(inputs, axis=steps_axis)
class Custom(tf.keras.Model):
def __init__(self, num_classes=1):
super(Custom, self).__init__()
self.block_1 = Conv1D(32, kernel_size=5, padding="same")
self.dropout = Dropout(0.5)
self.classifier = Dense(num_classes)
def call(self, inputs, training=False):
print(inputs[0].shape, inputs[1].shape)
sequence_output = self.block_1(inputs[0], training=training)
print(sequence_output.shape)
avgpool = GlobalAvgPool1D()(sequence_output, mask=inputs[1])
maxpool = GlobalMaxPooling1DMasked()(sequence_output, mask=inputs[1])
#avg_pool = K.sum(sequence_output*K.expand_dims(K.cast(inputs[1], dtype='float32'), axis=-1), axis=1)/K.sum(K.cast(inputs[1], dtype='float32'), axis=1)
#print(sequence_output.shape, pooled_output.shape)
#x = Concatenate(axis=-1)([pooled_output, avg_pool])
#x = concatenate([avg_pool, pooled_output])
#print("max", maxpool)
#print("avg", avgpool)
#print("out", pooled_output)
print(maxpool.shape, avgpool.shape)
print(maxpool, avgpool)
x = Concatenate()([maxpool, avgpool])
x = self.dropout(x)
print(x.shape)
return self.classifier(x)
X_train = np.random.uniform(0,1, size=(2, 10, 16)).astype(np.float32)
X_mask = np.array([[1,1,1,1,1,1,1,0,0,0],[1,1,1,1,1,1,1,1,1,0]])
print(X_train.shape, X_mask.shape)
model = Custom(1)
preds = model((X_train, X_mask))
Сообщение об ошибке:
InvalidArgumentError: cannot compute ConcatV2 as input #1(zero-based) was expected to be a int32 tensor but is a bool tensor [Op:ConcatV2] name: concat
Как решить эту проблему?