Обновление tf.train.string_input_producer до tf.data.Dataset.from_tensor_slices создает ошибку - PullRequest
0 голосов
/ 05 июля 2019

Как говорится на официальном сайте TensorFlow:

tf.train.string_input_producer is deprecated and will be removed in a future version * * 1004

Я пытался заменить его предложенным методом tf.data.Dataset.from_tensor_slices().

Но после обновления я получаю следующую ошибку:

TypeError: Tensors in list passed to 'input' of 'PyFunc' Op have types [<NOT CONVERTIBLE TO TENSOR>] that are invalid. Tensors: [<DatasetV1Adapter shapes: (), types: tf.string>]

Код выглядит следующим образом:

with tf.device("/cpu:0"), tf.name_scope(scope):
    '''This is the correct but deprecated version'''
    input_ops['id'] = tf.train.string_input_producer(
       tf.convert_to_tensor(data_id), capacity=128
    ).dequeue(name='input_ids_dequeue')

    ''' The following replaced code creates an error

    input_ops['id'] = tf.data.Dataset.from_tensor_slices(
        tf.convert_to_tensor(data_id)
    ).shuffle(128)

    '''

    img, q, a = dataset.get_data(data_id[0])

    def load_fn(id):
        # image [n, n], q: [m], a: [l]
        img, q, a = dataset.get_data(id)
        return (id, img.astype(np.float32), q.astype(np.float32),
                a.astype(np.float32))

    input_ops['id'], input_ops['img'], input_ops['q'], input_ops['a'] = \
        tf.py_func(
            load_fn,
            inp=[input_ops['id']],
            Tout=[tf.string, tf.float32, tf.float32, tf.float32],
            name='func'
    )

Здесь data_id - это список некоторых n чисел. Может ли кто-нибудь помочь мне с этим?

Спасибо и всего наилучшего.

...