tf.function input_signature для распределенного набора данных в tenorflow 2.0 - PullRequest
0 голосов
/ 21 октября 2019

Я пытаюсь построить распределенный пользовательский цикл обучения в TensorFlow 2.0, но я не могу понять, как аннотировать подпись tf.function автографа, чтобы избежать повторного прохождения.

Я пытался использоватьDatasetSpec и различные комбинации кортежей TensorSpec, но я получаю всевозможные ошибки.

Мой вопрос

Можно ли указать входную подпись tf.function, которая принимает пакетныераспределенные наборы данных?

Минимальный код воспроизведения

import tensorflow as tf
from tensorflow import keras
import numpy as np


class SimpleModel(keras.layers.Layer):
    def __init__(self, name='simple_model', **kwargs):
        super(SimpleModel, self).__init__(name=name, **kwargs)
        self.w = self.add_weight(shape=(1, 1),
                                 initializer=tf.constant_initializer(5.0),
                                 trainable=True,
                                 dtype=np.float32,
                                 name='w')

    def call(self, x):
        return tf.matmul(x, self.w)


class Trainer:
    def __init__(self):
        self.mirrored_strategy = tf.distribute.MirroredStrategy()

        with self.mirrored_strategy.scope():
            self.simple_model = SimpleModel()
            self.optimizer = tf.optimizers.Adam(learning_rate=0.01)

    def train_batches(self, dataset):
        dataset_dist = self.mirrored_strategy.experimental_distribute_dataset(dataset)

        with self.mirrored_strategy.scope():
            loss = self.train_batches_dist(dataset_dist)

        return loss.numpy()

    @tf.function(input_signature=(tf.data.DatasetSpec(element_spec=tf.TensorSpec(shape=(None, 1), dtype=tf.float32)),))
    def train_batches_dist(self, dataset_dist):
        total_loss = 0.0
        for batch in dataset_dist:
            losses = self.mirrored_strategy.experimental_run_v2(
                Trainer.train_batch, args=(self, batch)
            )
            mean_loss = self.mirrored_strategy.reduce(tf.distribute.ReduceOp.MEAN, losses, axis=0)

            total_loss += mean_loss
        return total_loss

    def train_batch(self, batch):
        with tf.GradientTape() as tape:
            losses = tf.square(2 * batch - self.simple_model(batch))

        gradients = tape.gradient(losses, self.simple_model.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.simple_model.trainable_weights))

        return losses


def main():
    values = np.random.sample((100, 1)).astype(np.float32)

    dataset = tf.data.Dataset.from_tensor_slices(values)
    dataset = dataset.batch(10)

    trainer = Trainer()
    for epoch in range(0, 100):
        loss = trainer.train_batches(dataset)
        print(loss / 10.0)


if __name__ == '__main__':
    main()

Сообщение об ошибке

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.distribute.input_lib.DistributedDataset'>

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...