Я хотел бы использовать пакетную версию tf.math.unsorted_segment_sum, но эта операция имеет только 1D версию. Чтобы выполнить unsorted_segment_sum для пакетных данных (форма данных, например [batch_size, 200, 32], форма segment_ids, такая как [batch_size, 200]), мне нужно сделать что-то вроде ниже
def unsorted_segment_sum(data, segment_ids, num_segments):
num_rows = tf.shape(segment_ids)[0]
rows_idx = tf.range(num_rows)
rows_idx = tf.cast(rows_idx, segment_ids.dtype)
segment_ids_per_row = segment_ids + num_segments * tf.expand_dims(rows_idx, axis=1)
num_segments_ = tf.cast(num_segments * num_rows, segment_ids.dtype)
seg_sums = tf.math.unsorted_segment_sum(data, segment_ids_per_row, num_segments_)
result = tf.reshape(seg_sums, [-1, num_segments, get_shape(data, -1)])
return result
Но это выполняется медленно поскольку его сложность пропорциональна размеру партии. Для torch мы можем использовать scatter_add, который намного быстрее для больших пакетов.
def unsorted_segment_sum(data, segment_ids, num_segments):
segment_ids = torch.repeat_interleave(segment_ids.unsqueeze(-1), repeats=data.shape[-1], dim=-1)
shape = [data.shape[0], num_segments] + list(data.shape[2:])
tensor = torch.zeros(*shape).scatter_add(1, segment_ids, data)
return tensor
Итак, можем ли мы использовать аналогичную операцию, например torch.scatter_add в tf?