Есть ли у тензорного потока аналогичный оператор, такой как torch.scatter_add, чтобы мы могли выполнять пакетную обработку unsorted_segment_sum? - PullRequest
0 голосов
/ 12 июля 2020

Я хотел бы использовать пакетную версию tf.math.unsorted_segment_sum, но эта операция имеет только 1D версию. Чтобы выполнить unsorted_segment_sum для пакетных данных (форма данных, например [batch_size, 200, 32], форма segment_ids, такая как [batch_size, 200]), мне нужно сделать что-то вроде ниже

def unsorted_segment_sum(data, segment_ids, num_segments):  
  num_rows = tf.shape(segment_ids)[0]  
  rows_idx = tf.range(num_rows)  
  rows_idx = tf.cast(rows_idx, segment_ids.dtype)  
  segment_ids_per_row = segment_ids + num_segments * tf.expand_dims(rows_idx, axis=1)  
  num_segments_ = tf.cast(num_segments * num_rows, segment_ids.dtype)  
  seg_sums = tf.math.unsorted_segment_sum(data, segment_ids_per_row, num_segments_)  
  result = tf.reshape(seg_sums, [-1, num_segments, get_shape(data, -1)])  
  return result  

Но это выполняется медленно поскольку его сложность пропорциональна размеру партии. Для torch мы можем использовать scatter_add, который намного быстрее для больших пакетов.

def unsorted_segment_sum(data, segment_ids, num_segments):  
  segment_ids = torch.repeat_interleave(segment_ids.unsqueeze(-1), repeats=data.shape[-1], dim=-1)   
  shape = [data.shape[0], num_segments] + list(data.shape[2:])  
  tensor = torch.zeros(*shape).scatter_add(1, segment_ids, data)  
  return tensor  

Итак, можем ли мы использовать аналогичную операцию, например torch.scatter_add в tf?

...