Является ли Tensorflow Dataset.from_generator устаревшим в tenorflow 2.0?Выдает ошибку устаревания tf.py_func - PullRequest
0 голосов
/ 08 мая 2019

Когда я создаю набор данных tf из генератора и пытаюсь запустить код tf2.0, он предупреждает меня сообщением об исключении.

Код:

import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model


def my_function():
    import numpy as np
    for i in range(1000):
        yield np.random.random(size=(28, 28, 1)), [1.0]


train_ds = tf.data.Dataset.from_generator(my_function, output_types=(tf.float32, tf.float32)).batch(32)


class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

    # def __call__(self, *args, **kwargs):
    #     return super().__call(*args,**kwargs)


model = MyModel()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

optimizer = tf.keras.optimizers.Adam()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')


@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)


EPOCHS = 5

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100))

Предупреждающее сообщение:

........
Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2. ........

Я хотел бы передать данные в модель из потокового ввода, используя API набора данных (с предварительной выборкой).Даже если это все еще возможно в текущей альфа-версии, будет ли она удалена позже?

Заменит ли тензорный поток tf.py_func, использованный в наборе данных генератора, на что-то новое или будет удален весь API-генератор генератора данных dataset_from?

1 Ответ

1 голос
/ 08 мая 2019

Нет, tf.data.Dataset.from_generator не рекомендуется использовать в TensorFlow 2.0. То, что вы видите, является предупреждением, оно используется для информирования пользователей о будущих изменениях. В случае, если вам нужно использовать py_func напрямую, самый простой способ - использовать tf.compat.v1.py_func. TF2.0 имеет свою собственную оболочку, которая называется tf.py_function.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...