Я использую Estimator API tenorflow и хотел бы создать пользовательские пакеты для обучения.
У меня есть примеры, которые выглядят следующим образом
example1 = {
"num_sentences": 3,
"sentences": [[1, 2], [3, 4], [5, 6]]
}
example2 = {
"num_sentences": 2,
"sentences": [[1, 2], [3, 4]]
}
Так что пример может иметь любойколичество предложений фиксированного размера.Теперь я хотел бы создать пакеты, размер которых зависит от количества предложений в пакете.В противном случае мне придется использовать размер пакета 1, так как в некоторых примерах могут быть предложения «размера пакета», а большой размер пакета не помещается в память GPU.
Например: у меня размер пакета 6 и примерыс количеством предложений [5, 3, 3, 2, 2, 1].Затем я группирую примеры в партии [5], [3, 3] и [2, 2, 1].Обратите внимание, что пример "1" в последнем пакете будет дополнен.
Я написал алгоритм, который группирует примеры в такие партии.Теперь я не могу передать пакеты в tf.data.Dataset.
Я пытался использовать tf.data.Dataset.from_generator
, но метод, похоже, ожидает отдельных примеров, и я получаю сообщение об ошибке, если генератор выдает пакеты, подобные [example1, пример2].
Как я могу кормить набор данных пользовательскими партиями?Есть ли более элегантный способ решения моей проблемы?
Обновление: я предполагаю, что не могу правильно указать параметр выходных фигур.Следующий код работает нормально.
import tensorflow as tf
def gen():
for b in range(3):
#yield [{"num_sentences": 3, "sentences": [[1, 2], [3, 4], [5, 6]]}]
yield {"num_sentences": 3, "sentences": [[1, 2], [3, 4], [5, 6]]}
dataset = tf.data.Dataset.from_generator(generator=gen,
output_types={'num_sentences': tf.int32, 'sentences': tf.int32},
#output_shapes=tf.TensorShape([None, {'num_sentences': tf.TensorShape(None), 'sentences': tf.TensorShape(None)}])
output_shapes={'num_sentences': tf.TensorShape(None), 'sentences': tf.TensorShape(None)}
)
def print_dataset(dataset):
it = dataset.make_one_shot_iterator()
with tf.Session() as sess:
print(dataset.output_shapes)
print(dataset.output_types)
while True:
try:
data = it.get_next()
print("data" + str(sess.run(data)))
except tf.errors.OutOfRangeError:
break
print_dataset(dataset)
Если я вместо этого получу массив и раскомментирую output_shapes, я получу сообщение об ошибке: аргумент int () должен быть строкой, байтовым объектом или числом, а не 'dict '"