Набор данных Tensorflow 2.0 и загрузчик данных - PullRequest
2 голосов
/ 22 октября 2019

Я пользователь Pytorch, и я привык к API data.dataset и data.dataloader в pytorch. Я пытаюсь построить ту же модель с tenorflow 2.0, и мне интересно, есть ли API, который работает аналогично с этими API в Pytorch.

Если нет таких API, кто-нибудь из вас может сказать мне, как людиобычно делать, чтобы реализовать часть загрузки данных в тензор потока? Я использовал тензор потока 1, но никогда не имел опыта работы с API набора данных. Я жестко закодировал раньше. Я надеюсь, что есть что-то вроде переопределения getitem только с индексом в качестве входа.

Большое спасибо заранее.

Ответы [ 2 ]

2 голосов
/ 22 октября 2019

При использовании tf.data API вы также обычно будете использовать функцию <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map" rel="nofollow noreferrer">map</a>.

В PyTorch ваш вызов __getItem__ в основном выбирает элемент из вашей структуры данных, заданной в __init__ и преобразует его при необходимости.

В TF2.0 вы делаете то же самое, инициализируя <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset" rel="nofollow noreferrer">Dataset</a>, используя одну из функций Dataset.from_... (см. <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator" rel="nofollow noreferrer">from_generator</a>, <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices" rel="nofollow noreferrer">from_tensor_slices</a>, <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors" rel="nofollow noreferrer">from_tensors</a>);по сути это __init__ часть PyTorch Dataset. Затем вы можете вызвать map для выполнения поэлементных манипуляций, которые были бы у вас в __getItem__.

. Наборы данных Tensorflow в значительной степени причудливые итераторы, так что по своей структуре вы не получаете доступ к их элементам, используя индексы,а точнее, обходя их.

Руководство на tf.data очень полезно и предоставляет широкий спектр примеров.

1 голос
/ 22 октября 2019

Я не знаком с Pytorch, но Tensorflow реализует API Keras, который имеет класс Sequence:

Базовый объект для подгонки к последовательности данных, такой как набор данных

https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence

Этот класс содержит getitem для индекса.

...