Как использовать tf.keras с bfloat16 - PullRequest
11 голосов
/ 13 мая 2019

Я пытаюсь заставить модель tf.keras работать на TPU, используя смешанную точность. Мне было интересно, как построить модель keras, используя смешанную точность bfloat16. Это что-то вроде этого?

with tf.contrib.tpu.bfloat16_scope():
    inputs = tf.keras.layers.Input(shape=(2,), dtype=tf.bfloat16)
    logits = tf.keras.layers.Dense(2)(inputs)

logits = tf.cast(logits, tf.float32)
model = tf.keras.models.Model(inputs=inputs, outputs=logits)
model.compile(optimizer=tf.keras.optimizers.Adam(.001),
              loss='mean_absolute_error', metrics=[])

tpu_model = tf.contrib.tpu.keras_to_tpu_model(
        model,
        strategy=tf.contrib.tpu.TPUDistributionStrategy(
            tf.contrib.cluster_resolver.TPUClusterResolver(tpu='my_tpu_name')
        )
    )
...