M XNET Оценка прогнозов по набору данных MNIST - PullRequest
0 голосов
/ 17 апреля 2020

Я обучил модель MNIST, используя L enet CNN, и теперь я пытаюсь сравнить некоторые входные изображения с обученной сетью.

Обучение и оценка были хорошими, так как они достигли точности 0,963241 после 100 эпох.

[03:20:24] /home/greg/dev/matchbox/src/Lenet.hpp:246: EPOCH [99] Val Accuracy: 0.963241
[03:20:24] /home/greg/dev/matchbox/src/Lenet.hpp:247: EPOCH [99] Val LogLoss: 0.147613

Теперь я сравниваю входное изображение с единичным ди git, но мой прогноз неверен, даже если они имеют высокую точность.

The model predicts the input image to be a [8 ] with Accuracy = 0.99845

Я подозреваю, что проблема заключается в том, когда я загружаю их изображение NDArray image_data = LoadInputImage(image_file), а форма NDARRAY неверна.

NDArray Predictor::LoadInputImage(const std::string &image_file) {
    if (!FileExists(image_file)) {
        LG << "Image file " << image_file << " does not exist";
        throw std::runtime_error("Image file does not exist");
    }
    LG << "Loading the image " << image_file << std::endl;

    cv::Mat mat = cv::imread(image_file, cv::IMREAD_GRAYSCALE);
    mat.convertTo(mat, CV_32F);

    /*resize pictures to (28, 28) according to the pretrained model*/
    int channels = input_shape_[1];
    int height = input_shape_[2];
    int width = input_shape_[3];

    cv::resize(mat, mat, cv::Size(width, height));
    std::vector<float> array((float *) mat.data, (float *) mat.data + mat.rows * mat.cols);

    std::cout << mat;

    NDArray image_data = NDArray(input_shape_, global_ctx_, false);
    image_data.SyncCopyFromCPU(array.data(), input_shape_.Size());
    NDArray::WaitAll();
    return image_data;
}

Вот как выглядит изображение, когда оно выводится на консоль std::cout << mat. В выводе ниже я заменил 255 на ___ для удобства чтения.

___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 249, 226, ___, 142, 100, 113, 198, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 187, 102, 50, 139, ___, 175, 133, 111, 71, 92, 238, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 190, 125, 197, 251, ___, ___, ___, ___, ___, 47, 125, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 201, 235, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 97, 130, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 64, 158, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 106, 68, 252, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 236, 150, 98, 194, 195, 245, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 117, 10, 0, 109, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 238, 244, 200, 48, 105, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 244, 160, 173, 236, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 46, 136, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 167, 115, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 244, 197, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 170, 60, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 211, 235, ___, ___, ___, ___, ___, ___, ___, 250, 31, 118, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 150, 31, 251, ___, ___, ___, ___, ___, ___, 223, 149, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, 119, 52, 227, ___, ___, 235, 177, 44, 186, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 121, 190, 214, 76, 56, 70, 162, 252, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 224, 177, 220, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___

Форма данных выглядит правильно при 28x28

Метод прогнозирования

void Predictor::Score(const std::string &image_file) {
    // Load the input image
    NDArray image_data = LoadInputImage(image_file);
    LG << "Running the forward pass on model to predict the image";

    /*
     * The executor->arg_arrays represent the arguments to the model.
     *
     * Copying the image_data that contains the NDArray of input image
     * to the arg map of the executor. The input is stored with the key "data" in the map.
     */
    double ms = ms_now();

    image_data.CopyTo(&args_map_["data"]);
    NDArray::WaitAll();

    // Run the forward pass.
    executor_->Forward(false);
    NDArray::WaitAll();
    auto array = executor_->outputs[0].Copy(global_ctx_);

    /*
    * Find out the maximum accuracy and the index associated with that accuracy.
    * This is done by using the argmax operator on NDArray.
    */
    auto predicted = array.ArgmaxChannel();

    /*
     * Wait until all the previous write operations on the 'predicted'
     * NDArray to be complete before we read it.
     * This method guarantees that all previous write operations that pushed into the backend engine
     * for execution are actually finished.
     */
    predicted.WaitToRead();
    NDArray::WaitAll();

    auto best_idx = predicted.At(0);
    auto best_accuracy = array.At(0, best_idx);
    LG << "best_idx, best_accuracy = " << best_idx << " : " << best_accuracy;

    if (output_labels.empty()) {
        LG << "The model predicts the highest accuracy of " << best_accuracy << " at index "
           << best_idx;
    } else {
        LG << "The model predicts the input image to be a [" << output_labels[best_idx]
           << " ] with Accuracy = " << best_accuracy << std::endl;
    }

    mx_uint len = output_labels.size();
    std::vector<mx_float> pred_data(len);
    std::vector<mx_float> label_data(len);

    predicted.SyncCopyToCPU(&pred_data, len);

    // Display all candidates
    for (mx_uint i = 0; i < len; ++i) {
        auto val = pred_data[i];  // predicted
        auto label = label_data[i]; // expected

        auto best_idx = predicted.At(i);
        auto best_accuracy = array.At(0, best_idx);
        LG << "best_idx, best_accuracy = " << best_idx << " : " << best_accuracy;
        auto accuracy = array.At(0, i);
        LG << "Found, Expected, Accuracy  :: " << i << " : " << val << " = " << label << " : " << accuracy << " == "
           << best_accuracy;
    }

    ms = ms_now() - ms;

    auto args_name = net_.ListArguments();
    LG << "INFO:" << "label_name = " << args_name[args_name.size() - 1];
    LG << "INFO:" << "rgb_mean: " << "(" << rgb_mean_[0] << ", " << rgb_mean_[1]
       << ", " << rgb_mean_[2] << ")";
    LG << "INFO:" << "rgb_std: " << "(" << rgb_std_[0] << ", " << rgb_std_[1]
       << ", " << rgb_std_[2] << ")";
    LG << "INFO:" << "Image shape: " << "(" << input_shape_[1] << ", "
       << input_shape_[2] << ", " << input_shape_[3] << ")";
    LG << "INFO:" << "Batch size = " << input_shape_[0] << " for inference";
    LG << "INFO:" << "Throughput: " << (1000.0 * input_shape_[0] / ms)
       << " images per second";
}

Входное изображение было цветным изображением ди git '3' 28x28, я подозреваю, что проблема здесь LoadInputImage(const std::string &image_file), но пока не могу ее точно определить.

Любые мысли будут полезны.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...