Мое текущее решение - добавить дополнительный канал, который представляет битовый флаг.После извлечения фрагментов изображения битовый флаг равен 0
для канала с заполнением и 1
для канала без заполнения.
Полное решение:
input_tensor = tf.random.normal([10, 28, 28, 1])
window_shape, strides, padding = (4, 4), (2, 2), 'SAME'
# ----------------------------
bits = tf.ones([tf.shape(input_tensor)[0], input_tensor.shape[1], input_tensor.shape[2], 1])
input_for_patching = tf.concat([input_tensor, bits], axis=-1)
patches = tf.extract_image_patches(input_for_patching, ksizes=(1, *window_shape, 1), strides=(1, *strides, 1), rates=(1, 1, 1, 1), padding=padding)
patches_shape = patches.shape
patches = tf.reshape(patches, [-1, *window_shape, input_tensor.shape[3] + 1])
padding_mask = tf.to_float(tf.reduce_all(tf.equal(patches[:, :, :, -1:], 1.0), [1, 2, 3]))
patches = tf.reshape(patches[:, :, :, :-1], [-1, patches_shape[1], patches_shape[2], window_shape[0] * window_shape[1] * input_tensor.shape[3]])
padding_mask
из приведенного выше кода это то, что мне нужно.
Если у кого-то есть более короткая, более элегантная и / или более интегрированная версия, пожалуйста, не стесняйтесь поделиться.