Если бы мы могли собрать признаки соседей каждого узла в форму, которую ожидают слои Conv, мы могли бы использовать обычные свёртки.
Такой сбор можно выполнить с помощью следующего слоя Keras, который использует операцию `tf.gather_nd` из TensorFlow.
class GatherFromIndices(Layer):
    """
    Gather the features of each node's neighbours, so that a graph convolution
    (over a kernel of fixed size, i.e. a fixed node degree) can be computed by
    a plain Conv1D/Conv2D afterwards (or even TimeDistributed(Dense())),
    depending on the output data format requested.

    Inputs: ``[inp, inds]`` where
      * ``inp``  is ``[..., batch, sequence_len, features]``;
      * ``inds`` is ``[..., batch, sequence_ind_len, neighbours]``, integer
        indices into the ``sequence_len`` axis of ``inp``. All axes except the
        last two must match those of ``inp``.

    Output: ``[..., batch, sequence_ind_len, neighbours (+1), features]``, or
    ``[..., batch, sequence_ind_len, (neighbours (+1)) * features]`` when
    ``flatten_indices_features`` is set.

    Indices must be integers. Negative indices are masked: their gathered
    features equal ``mask_value``. Indices >= sequence_len wrap around
    (modulo sequence_len).

    Args:
        mask_value: value used for masked (negative) indices. It is either a
            scalar or broadcastable to ``[features]``, and is cast to the
            dtype of ``inp``.
        include_self: if True, append each position's own features from
            ``inp`` after its neighbours. This requires
            sequence_ind_len == sequence_len.
        flatten_indices_features: if True, merge the neighbours and features
            axes into one.
    """

    def __init__(self, mask_value=0, include_self=True, flatten_indices_features=False, **kwargs):
        super(GatherFromIndices, self).__init__(**kwargs)
        self.mask_value = mask_value
        self.include_self = include_self
        self.flatten_indices_features = flatten_indices_features

    def get_config(self):
        """Return the constructor arguments merged into the base Layer config (for serialization)."""
        config = {'mask_value': self.mask_value,
                  'include_self': self.include_self,
                  'flatten_indices_features': self.flatten_indices_features,
                  }
        base_config = super(GatherFromIndices, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Return the static output shape for ``input_shape == [inp_shape, inds_shape]``.

        Unknown (None) neighbour or feature sizes propagate as None instead of raising.
        """
        inp_shape, inds_shape = input_shape
        indices = inds_shape[-1]
        features = inp_shape[-1]
        if indices is not None and self.include_self:
            indices += 1  # room for the position's own features
        leading = list(inds_shape[:-1])
        if self.flatten_indices_features:
            if indices is None or features is None:
                return tuple(leading + [None])
            return tuple(leading + [indices * features])
        return tuple(leading + [indices, features])

    def call(self, inputs, training=None):
        """Gather the neighbour features (see the class docstring for shapes). ``training`` is unused."""
        inp, inds = inputs
        # All axes except the last two must match between inp and inds,
        # and every axis of inds must be non-empty.
        assert_shapes = tf.Assert(tf.reduce_all(tf.equal(tf.shape(inp)[:-2], tf.shape(inds)[:-2])), [inp])
        assert_positive_ins_shape = tf.Assert(tf.reduce_all(tf.greater(tf.shape(inds), 0)), [inds])
        with tf.control_dependencies([assert_shapes, assert_positive_ins_shape]):
            inp_shape = tf.shape(inp)
            inds_shape = tf.shape(inds)
            features_dim = -1
            # TODO: allow an axis other than the last one to be treated as features.

            # Flatten all leading axes: inp_p is [rows, features], ins_p is [rows_ind, neighbours].
            inp_p = tf.reshape(inp, [-1, inp_shape[features_dim]])
            ins_p = tf.reshape(inds, [-1, inds_shape[features_dim]])

            # Flattening loses the batch axes, so turn per-sequence indices into row indices of
            # inp_p (this lets gather_nd run unbatched). Row r of ins_p belongs to sequence
            # r // sequence_ind_len, whose rows start at (r // sequence_ind_len) * sequence_len
            # in inp_p. Integer floor division is exact; a float32 reciprocal-multiply is not
            # once the row count exceeds 2**24.
            row_ids = tf.range(tf.shape(ins_p)[0])
            seq_offsets = tf.expand_dims((row_ids // inds_shape[-2]) * inp_shape[-2], -1)

            # Negative indices are masked. Indices >= sequence_len wrap (floor modulo
            # sequence_len) before being shifted into their own sequence's rows.
            mask = tf.greater_equal(ins_p, 0)
            offset_ins_p = (ins_p % inp_shape[-2]) + seq_offsets
            # Masked entries point one past the last row of inp_p, where the mask-value row
            # is appended below. Using index -1 instead would make gather_nd fail on GPU
            # ("flat indices = [-1] does not index into param").
            mask_row_index = tf.shape(inp_p)[0] * tf.ones_like(mask, dtype=tf.int32)
            offset_ins_p = tf.where(mask, offset_ins_p, mask_row_index)

            # Mask row in inp's dtype; mask_value is a scalar or broadcastable to [features].
            mask_value_row = tf.zeros((inp_shape[-1],), dtype=inp.dtype) + tf.cast(self.mask_value, inp.dtype)
            inp_p = tf.concat([inp_p, tf.expand_dims(mask_value_row, 0)], axis=0)

            # expand_dims so that each index selects a whole row (a slice), not a single element.
            neighb_p = tf.gather_nd(inp_p, tf.expand_dims(offset_ins_p, -1))  # [rows_ind, neighbours, features]
            out_shape = tf.concat([inds_shape, inp_shape[features_dim:]], axis=-1)
            neighb = tf.reshape(neighb_p, out_shape)
            # ^^ [..., batch, sequence_ind_len, neighbours, features]

            if self.include_self:
                # Append each position's own features AFTER its neighbours (last slot);
                # the concat requires sequence_ind_len == sequence_len.
                self_originals = tf.expand_dims(inp, axis=features_dim - 1)
                # ^^ [..., batch, sequence_len, 1, features]
                neighb = tf.concat([neighb, self_originals], axis=features_dim - 1)
            if self.flatten_indices_features:
                neighb = tf.reshape(neighb, tf.concat([inds_shape[:-1], [-1]], axis=-1))
            return neighb
Вместе с интерактивным тестом, удобным для отладки:
def allow_tf_debug(func):
    """
    Decorator for tests that use TensorFlow. It runs the test inside a
    ``tf.InteractiveSession`` so that tensors can be ``.eval()``-ed directly,
    e.g. at a breakpoint.

    The session is closed even if the test raises. The test's name and
    docstring are preserved, and any arguments are forwarded to the test.
    """
    import functools

    @functools.wraps(func)
    def interactive_wrapper(*args, **kwargs):
        sess = tf.InteractiveSession()
        try:
            return func(*args, **kwargs)
        finally:
            # Always release the session so a failing test does not leak it.
            sess.close()
    return interactive_wrapper
@allow_tf_debug
def test_gather_from_indices():
    """
    Compare GatherFromIndices against a plain NumPy reference gather. Both
    3-D indices (batch, sequence, neighbours) and 4-D indices (one extra
    leading batch axis) are covered.

    Not covered: include_self=True, flatten_indices_features=True and a
    non-zero mask_value.
    """
    gat = GatherFromIndices(include_self=False, flatten_indices_features=False)
    seq = [  # batch of sequences
        # sequences of 2d features
        [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8]],
        [[10, 1], [11, 2], [12, 3], [13, 4], [14, 5], [15, 6], [16, 7], [17, 8]],
    ]
    ids = [  # batch of sequences
        # sequences of 3 ids for each item in the sequence
        [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [5, 5, 5], [6, 6, 6], [7, 7, 7]],
        # minus one means masking
        [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [5, 6, 7], [6, 7, 0], [7, 0, -1]],
    ]

    def reference_gather(seq_arr, ids_arr):
        """NumPy reference for any batch rank.

        out[..., i, c, :] = seq_arr[..., ids_arr[..., i, c], :], or 0 where
        the id is negative.
        """
        result = np.zeros(ids_arr.shape + seq_arr.shape[-1:], dtype=seq_arr.dtype)
        for pos in np.ndindex(*ids_arr.shape):
            node_id = ids_arr[pos]
            if node_id >= 0:
                # pos[:-2] is the batch prefix; select node_id within that sequence.
                result[pos] = seq_arr[pos[:-2] + (node_id,)]
        return result

    def compute_assert_2ways_gathers(seq, ids):
        seq = np.array(seq, dtype=np.float32)
        ids = np.array(ids, dtype=np.int32)
        result_np = reference_gather(seq, ids)

        output_shape_kerascomputed = gat.compute_output_shape([seq.shape, ids.shape])
        assert isinstance(output_shape_kerascomputed, tuple)
        assert list(output_shape_kerascomputed) == list(result_np.shape)

        sess = tf.get_default_session()
        gat.build(seq.shape)
        result = gat.call([tf.constant(seq), tf.constant(ids)])
        tf_result = sess.run(result)
        assert list(tf_result.shape) == list(output_shape_kerascomputed)
        np.testing.assert_array_equal(tf_result, result_np)

    compute_assert_2ways_gathers(seq, ids)
    # List repetition, not scaling: a batch of 10 sequences.
    compute_assert_2ways_gathers(seq * 5, ids * 5)
    # Extra leading batch axis (4-D indices).
    compute_assert_2ways_gathers([seq] * 3, [ids] * 3)
И пример использования для 5 соседей на узел:
fields_input = Input(shape=(None, 10), name='nodedata')
neighbours_ids_input = Input(shape=(None, 5), name='nodes_neighbours_ids', dtype='int32')
# Output shape: (batch, sequence_len, (5 neighbours + self) * 10 features) = (batch, sequence_len, 60).
# include_self=True requires the ids sequence to have the same length as the fields sequence.
fields_input_with_neighbours = GatherFromIndices(
    mask_value=0, include_self=True, flatten_indices_features=True,
)([fields_input, neighbours_ids_input])
# NOTE(review): each position already holds its whole flattened neighbourhood, so kernel_size=1
# would make this a pure graph convolution. kernel_size=5 additionally mixes positions that are
# adjacent in sequence order -- confirm this is intended.
fields = Conv1D(128, kernel_size=5, padding='same',
                activation='relu')(fields_input_with_neighbours)  # data_format="channels_last"