Вы можете использовать bfloat16
с TPU.Необходимо сделать две основные вещи:
- Привести входные данные к bfloat16 в вашем входном конвейере
- Окружить вашу сеть внутри области bfloat16 и привести выходные данные как F32 для дальнейших вычислений.
Вот фрагмент кода, который иллюстрирует необходимые изменения:
def input_fn():
def dataset_parser(self, value):
"""Parse an ImageNet record from a serialized string Tensor."""
image = self.image_preprocessing_fn(
image_bytes=image_bytes,
is_training=self.is_training,
)
if self.use_bfloat16:
image = tf.cast(image, tf.bfloat16)
return image, label
def resnet_model_fn(features, labels, mode, params):
"""The model_fn for ResNet to be used with TPUEstimator."""
# This nested function allows us to avoid duplicating the logic which
# builds the network, for different values of --precision.
def build_network():
network = resnet_model.resnet_v1(
resnet_depth=FLAGS.resnet_depth,
num_classes=LABEL_CLASSES,
data_format=FLAGS.data_format)
return network(
inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
if FLAGS.precision == 'bfloat16':
with bfloat16.bfloat16_scope():
logits = build_network()
logits = tf.cast(logits, tf.float32)
elif FLAGS.precision == 'float32':
logits = build_network()
Вы также можете увидеть второе условие, показанное в этой модели TPU .