Модель Tensorflow 2.0, использующая функцию tf.function, очень медленная и перекомпилируется каждый раз, когда изменяется число поездов.Eager работает примерно в 4 раза быстрее - PullRequest
8 голосов
/ 16 апреля 2019

У меня есть модели, построенные из не скомпилированного кода keras, и я пытаюсь запустить их через пользовательский цикл обучения.

Стремительный код TF 2.0 (по умолчанию) выполняется около 30 секунд на процессоре (ноутбуке).Когда я создаю модель keras с обернутыми методами вызова tf.function, она работает намного, намного медленнее и, похоже, запускается очень долго, особенно в «первый» раз.

Например, вФункциональный код: начальный набор из 10 выборок занимает 40 с, а следующий за 10 выборками - 2 с.

На 20 выборках начальный этап занимает 50 с, а за последующие 4 с.

Первый поезд на 1 выборке занимает 2 с, а наблюдение занимает 200 мс.

Похоже, что каждый вызов поезда создает новый график , где сложность зависит от количества поездов!?

Я просто делаю что-то вроде этого:

@tf.function
def train(n=10):
    step = 0
    loss = 0.0
    accuracy = 0.0
    for i in range(n):
        step += 1
        d, dd, l = train_one_step(model, opt, data)
        tf.print(dd)
        with tf.name_scope('train'):
            for k in dd:
                tf.summary.scalar(k, dd[k], step=step)
        if tf.equal(step % 10, 0):
            tf.print(dd)
    d.update(dd)
    return d

Где модель keras.model.Model с методом @tf.function декорирование call в соответствии с примерами.

1 Ответ

9 голосов
/ 16 апреля 2019

Я проанализировал это поведение @tf.function здесь Использование нативного типа Python .

Вкратце: дизайн tf.function не позволяет автоматически связывать нативные типы Python с tf.Tensor объектами с четко определенным dtype.

Если ваша функция принимает объект tf.Tensor, при первом вызове функция анализируется, график строится и связывается с этой функцией. В каждом не-первом вызове, если dtype объекта tf.Tensor совпадает, график используется повторно.

Но в случае использования нативного типа Python график создается каждый раз, когда функция вызывается с другим значением .

Вкратце: спроектируйте ваш код так, чтобы он везде использовал tf.Tensor вместо переменных Python, если вы планируете использовать @tf.function.

tf.function - это не оболочка, которая магически ускоряет функцию, которая хорошо работает в активном режиме; является оберткой, которая требует разработки энергичной функции (тело, входные параметры, типы), понимающей, что произойдет после создания графика, чтобы получить реальные ускорения.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...