Уменьшение памяти Tensorflow TPU v2 / v3 bfloat16 - PullRequest
0 голосов
/ 24 ноября 2018

Моя модель слишком большая, чтобы получить пакет> 64 с обычными устройствами TPU v2.На сайте поиск и устранение неисправностей упоминается, что в следующих версиях tenorflow будет поддерживаться bfloat16.Могут ли недавно поддерживаемые версии 1.9-1.12 tf использовать bfloat16 и, если да, есть ли ограниченный набор оптимизаторов, которые я могу использовать?Я не нашел никакой дополнительной документации по этому вопросу, но видел использование bfloat16 в модели тензорной линии, поэтому я думаю, что должен быть способ.

Кроме того, я читал, что TPU v3 также поддерживает большие модели но модель нуждается в минимальных изменениях, но я не нашел никакой документации, что нужно изменить.

Я уже использую Adafactor и пытался уменьшить свои слои,если у вас есть какие-либо дополнительные советы по сокращению, это тоже было бы здорово.Я использую матрицы изображений и векторы слов (float32 на данный момент) в качестве входных данных.

1 Ответ

0 голосов
/ 18 декабря 2018

Вы можете использовать bfloat16 с TPU.Необходимо сделать две основные вещи:

  1. Привести входные данные к bfloat16 в вашем входном конвейере
  2. Окружить вашу сеть внутри области bfloat16 и привести выходные данные как F32 для дальнейших вычислений.

Вот фрагмент кода, который иллюстрирует необходимые изменения:

def input_fn():

  def dataset_parser(self, value):
    """Parse an ImageNet record from a serialized string Tensor."""
    image = self.image_preprocessing_fn(
        image_bytes=image_bytes,
        is_training=self.is_training,
    )

    if self.use_bfloat16:
      image = tf.cast(image, tf.bfloat16)

    return image, label


def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator."""

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=LABEL_CLASSES,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    with bfloat16.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

Вы также можете увидеть второе условие, показанное в этой модели TPU .

...