BERT от tfhub МЕДЛЕННО и не использует GPU - PullRequest
1 голос
/ 08 января 2020

Я пробую эту модель BERT от TFHub . К сожалению, он работает очень медленно и использует только 1-2% GPU в соответствии с Windows Task Manager.

Что я могу сделать, чтобы ускорить это?

import tensorflow as tf
import tensorflow_hub as hub

tf.test.is_gpu_available(True) # returns True

max_seq_length = 128

input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name="segment_ids")

bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1", trainable=False)
pooled_out, seq_out = bert_layer([input_word_ids, input_mask, segment_ids])
model = tf.keras.Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=[pooled_out, seq_out])

t = tf.random.uniform((1000, max_seq_length), maxval=1, dtype=tf.int32)
outputs = model.predict([t, t, t]) # super slow...

1 Ответ

0 голосов
/ 08 января 2020

Это случилось со мной. Как уже упоминалось Ashwin Geet D'SA, убедитесь, что у вас установлено tensorflow-gpu (не tensorflow). Запустите этот тест, чтобы убедиться, что графический процессор доступен:

device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))
...