тензор потока, как дополнить пакетный текст, как 'collate_fn' Pytorch? - PullRequest
0 голосов
/ 16 января 2020

Я хочу добавить пакет текста одинаковой длины, сгенерировать идентификатор сегмента, вектор маски, а затем передать их в модель bert. В pytorch я могу использовать collate_fn, как показано ниже.

def collate_fn(self, batch):
    rows = self.df.iloc[batch] # take a batch of data
    ids, seg_ids = self.get_ids_segs(rows) # process data
    attention_mask = (ids > 0)
    return ids, seg_ids,attention_mask

Но в тензорном потоке данные передаются кортежем матрицы, поэтому весь текст дополняется до максимальной длины 512.

# ids.shape = seg_ids = attention_mask = (data_number, max_seq_len) 
xs = (ids, seg_ids, attention_mask)

model.fit(xs,, ys, batch_size=batch_size)

Я обнаружил, tf.data.dataset имеет функцию padded_batch . Но он может заполнить только один вход, у меня есть 3 входных данных, ids, seq_ids, attn_mask.

...