Маленькая модель LSTM в керасе не подходит моему GPU - PullRequest
0 голосов
/ 25 апреля 2019

Я программирую относительно небольшую модель LSTM в Google Collab.

Для справки я использую TensorFlow 1.13 для построения модели, используя tenorflow.keras для API-интерфейса keras.

seq_len = 20000; n_classes = 4
inputs = ll.Input(shape=(seq_len,))
x = ll.Embedding(len(word_index), 1000)(inputs)
x = ll.LSTM(units=100, activation='relu', return_sequences=True)(x)
outputs = ll.Dense(units = n_classes, activation='softmax')(x)
model = Model(inputs, outputs)
model.summary()

Я проверил , что у меня доступно 15 ГБ ОЗУ графического процессора, и согласно моим оценкам модель с размером пакета 32 должна умещаться в 3 ГБ ОЗУ.

Однако, когда я запускаю тренинг, серверу не хватает памяти.

Если честно, я использую очень длинные последовательности данных (20000 - максимальная длина последовательности), но я ожидаю, что модель развернется символически в памяти и просто уместится.

Уменьшение размера пакета до 1 также не помогает.

Что происходит? Как сделать так, чтобы эта модель поместилась в памяти?

РЕДАКТИРОВАТЬ: я пытался уменьшить длину последовательности до 2, и это действительно делает его в памяти. Но мне нужно, чтобы длина последовательности оставалась высокой. Как я могу сказать Tensorflow не развертывать сеть в любой момент? (Я подозреваю, что это происходит за кулисами, как я могу проверить, так ли это на самом деле?)

EDIT: если я удаляю слой Softmax, тогда использование памяти снова падает до нормального диапазона. Я думаю, что слой Softmax заставляет Tensorflow развернуть сеть. Время Распределение Softmax не помогает, хотя.

1 Ответ

1 голос
/ 26 апреля 2019

Изменение слоя LSTM для слоя CuDNNLSTM сделало свое дело!

inputs = ll.Input(shape=(seq_len,))
x = ll.Embedding(len(word_index), 1024)(inputs)
x = ll.CuDNNLSTM(units=100, return_sequences=True)(x)
x = ll.Dense(units = n_classes, activation='softmax')(x)
outputs = x
model = Model(inputs, outputs)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...