Обновление только последнего слоя классификации модели TFLITE - PullRequest
0 голосов
/ 28 января 2020

В настоящее время я выполняю вывод через tflite на периферийном устройстве, которое требует периодического обновления, так как я накапливаю дополнительные секретные данные для обучения. У пограничного устройства есть только спутниковая связь, поэтому любые данные, отправляемые в / из него, могут быть очень дорогими. Используемая в настоящее время модель представляет собой мобильную архитектуру enet V2 100 объемом 3,5 МБ, что является наименьшим размером, который мне удалось получить при сохранении достаточно высокой точности. Моя текущая идея по обновлению модели состоит в том, чтобы переобучить только последний слой классификации, извлечь веса, отправить новые веса на периферийное устройство через двоичный файл, скопировать содержимое старой модели в новый плоский буфер, перезаписывая веса последнего слоя, затем сохранение.

Мне интересно, это правильный подход? Я что-то пропускаю / пропускаю?

Вот подтверждение концепции, которую я написал на c ++:

    //load flatbbuffer binary
    std::ifstream infile;
    infile.open("model.tflite", std::ios::binary | std::ios::in);
    infile.seekg(0,std::ios::end);
    int length = infile.tellg();
    std::cout << length << std::endl;
    infile.seekg(0,std::ios::beg);
    char *data = new char[length];
    infile.read(data, length);
    infile.close();

    //Unpack flatbuffer
    tflite::ModelT model;
    tflite::GetModel(data)->UnPackTo(&model);

    //load weights. For test model, buffer index of classification layer is 181  
    std::vector<uint8_t> buffer_data = (model.buffers[181]->data);

    //load new weights from binary file sent to device
    std::vector<uint8_t> new_buffer_data = ReadBinary("new_weights_binary");

    //assign new weights to existing ones.
    model.buffers[181]->data = new_buffer_data;

    //Pack and save updated flatbbuffer
    flatbuffers::FlatBufferBuilder fbb;
    auto model2 = tflite::CreateModel(fbb, &model);
    fbb.Finish(model2,reinterpret_cast<const char*>(tflite::ModelIdentifier()));

    std::string filename = "new_model.tflite";
    bool result = SaveFB(filename.c_str(),
                        reinterpret_cast<const char*>(fbb.GetBufferPointer()),
                        (size_t) fbb.GetSize(),
                        true);
...