Мы с другом пытались запустить модель в браузере, построенном с Keras, используя Tensorflow JS. Мы следовали официальному учебнику , но прогнозы в браузере отличаются от Python.
В частности, модель представляет собой 1D CNN, и ее целью является классификация одномерных массивов (сегментов) записей напряжения по классу захвата или отсутствия захвата. В браузере прогнозные оценки модели постоянно центрированы на уровне 0,93 для класса отсутствия захвата во всех сегментах напряжения, что неверно.
Мы использовали следующие методы для экспорта модели из Keras в совместимый с Tensorflow JS формат:
С помощью инструмента CLI tenorflowjs_converter:
tensorflowjs_converter --input_format keras \
path/to/my_model.h5 \
path/to/tfjs_target_dir
Непосредственно из Python:
tfjs.converters.save_keras_model(model, tfjs_target_dir)
Модель 1D CNN требует трехмерных данных, поэтому мы должны изменить их на лету.
В Python мы просто используем:
data = data.reshape(data.shape[0], data.shape[1], 1) # data = tables instance
prediction = model.predict(data)
Это дает правильные прогнозы, как и ожидалось.
Вот как мы загружаем модель, изменяем форму данных и запускаем модель в браузере:
let toPredict = tf.tensor2d(inputData, [inputData.length, inputData[0].length])
toPredict = toPredict.reshape([inputData.length, inputData[0].length, 1])
const model = await tf.loadLayersModel('/model/model.json');
predictions = await model.predict(toPredict).array();
Это не возвращает никакой ошибки, но дает неверные прогнозы, как описано выше.
Подходы, которые мы пытались устранить, это:
1) Проверено, что данные были правильными загружается в соответствующем формате. Экспорт этих данных из браузера и передача их в модель Python keras дали правильные прогнозы.
2) Проверено, что вес модели и архитектура совпадают в Python и в браузере.
3) Отключено WebGL.
4) Преобразовал модель с использованием двух различных методов (см. Выше).
Однако у нас ничего не получалось. У кого-нибудь есть какие-либо советы или рекомендации, как решить эту проблему?
Стек интерфейса:
- Последний Google Chrome / последний Firefox в Ма c ОС
- Tensorflow JS 1.7.2
Заранее спасибо,