Я хочу добавить пакет текста одинаковой длины, сгенерировать идентификатор сегмента, вектор маски, а затем передать их в модель bert. В pytorch я могу использовать collate_fn
, как показано ниже.
def collate_fn(self, batch):
rows = self.df.iloc[batch] # take a batch of data
ids, seg_ids = self.get_ids_segs(rows) # process data
attention_mask = (ids > 0)
return ids, seg_ids,attention_mask
Но в тензорном потоке данные передаются кортежем матрицы, поэтому весь текст дополняется до максимальной длины 512.
# ids.shape = seg_ids = attention_mask = (data_number, max_seq_len)
xs = (ids, seg_ids, attention_mask)
model.fit(xs,, ys, batch_size=batch_size)
Я обнаружил, tf.data.dataset
имеет функцию padded_batch . Но он может заполнить только один вход, у меня есть 3 входных данных, ids
, seq_ids
, attn_mask
.