Я обучил модель MNIST, используя L enet CNN, и теперь я пытаюсь сравнить некоторые входные изображения с обученной сетью.
Обучение и оценка были хорошими, так как они достигли точности 0,963241 после 100 эпох.
[03:20:24] /home/greg/dev/matchbox/src/Lenet.hpp:246: EPOCH [99] Val Accuracy: 0.963241
[03:20:24] /home/greg/dev/matchbox/src/Lenet.hpp:247: EPOCH [99] Val LogLoss: 0.147613
Теперь я сравниваю входное изображение с единичным ди git, но мой прогноз неверен, даже если они имеют высокую точность.
The model predicts the input image to be a [8 ] with Accuracy = 0.99845
Я подозреваю, что проблема заключается в том, когда я загружаю их изображение NDArray image_data = LoadInputImage(image_file)
, а форма NDARRAY неверна.
NDArray Predictor::LoadInputImage(const std::string &image_file) {
if (!FileExists(image_file)) {
LG << "Image file " << image_file << " does not exist";
throw std::runtime_error("Image file does not exist");
}
LG << "Loading the image " << image_file << std::endl;
cv::Mat mat = cv::imread(image_file, cv::IMREAD_GRAYSCALE);
mat.convertTo(mat, CV_32F);
/*resize pictures to (28, 28) according to the pretrained model*/
int channels = input_shape_[1];
int height = input_shape_[2];
int width = input_shape_[3];
cv::resize(mat, mat, cv::Size(width, height));
std::vector<float> array((float *) mat.data, (float *) mat.data + mat.rows * mat.cols);
std::cout << mat;
NDArray image_data = NDArray(input_shape_, global_ctx_, false);
image_data.SyncCopyFromCPU(array.data(), input_shape_.Size());
NDArray::WaitAll();
return image_data;
}
Вот как выглядит изображение, когда оно выводится на консоль std::cout << mat
. В выводе ниже я заменил 255
на ___
для удобства чтения.
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 249, 226, ___, 142, 100, 113, 198, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 187, 102, 50, 139, ___, 175, 133, 111, 71, 92, 238, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 190, 125, 197, 251, ___, ___, ___, ___, ___, 47, 125, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 201, 235, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 97, 130, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 64, 158, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 106, 68, 252, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 236, 150, 98, 194, 195, 245, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 117, 10, 0, 109, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 238, 244, 200, 48, 105, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 244, 160, 173, 236, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 46, 136, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 167, 115, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 244, 197, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 170, 60, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 211, 235, ___, ___, ___, ___, ___, ___, ___, 250, 31, 118, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 150, 31, 251, ___, ___, ___, ___, ___, ___, 223, 149, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, 119, 52, 227, ___, ___, 235, 177, 44, 186, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 121, 190, 214, 76, 56, 70, 162, 252, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 224, 177, 220, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___
Форма данных выглядит правильно при 28x28
Метод прогнозирования
void Predictor::Score(const std::string &image_file) {
// Load the input image
NDArray image_data = LoadInputImage(image_file);
LG << "Running the forward pass on model to predict the image";
/*
* The executor->arg_arrays represent the arguments to the model.
*
* Copying the image_data that contains the NDArray of input image
* to the arg map of the executor. The input is stored with the key "data" in the map.
*/
double ms = ms_now();
image_data.CopyTo(&args_map_["data"]);
NDArray::WaitAll();
// Run the forward pass.
executor_->Forward(false);
NDArray::WaitAll();
auto array = executor_->outputs[0].Copy(global_ctx_);
/*
* Find out the maximum accuracy and the index associated with that accuracy.
* This is done by using the argmax operator on NDArray.
*/
auto predicted = array.ArgmaxChannel();
/*
* Wait until all the previous write operations on the 'predicted'
* NDArray to be complete before we read it.
* This method guarantees that all previous write operations that pushed into the backend engine
* for execution are actually finished.
*/
predicted.WaitToRead();
NDArray::WaitAll();
auto best_idx = predicted.At(0);
auto best_accuracy = array.At(0, best_idx);
LG << "best_idx, best_accuracy = " << best_idx << " : " << best_accuracy;
if (output_labels.empty()) {
LG << "The model predicts the highest accuracy of " << best_accuracy << " at index "
<< best_idx;
} else {
LG << "The model predicts the input image to be a [" << output_labels[best_idx]
<< " ] with Accuracy = " << best_accuracy << std::endl;
}
mx_uint len = output_labels.size();
std::vector<mx_float> pred_data(len);
std::vector<mx_float> label_data(len);
predicted.SyncCopyToCPU(&pred_data, len);
// Display all candidates
for (mx_uint i = 0; i < len; ++i) {
auto val = pred_data[i]; // predicted
auto label = label_data[i]; // expected
auto best_idx = predicted.At(i);
auto best_accuracy = array.At(0, best_idx);
LG << "best_idx, best_accuracy = " << best_idx << " : " << best_accuracy;
auto accuracy = array.At(0, i);
LG << "Found, Expected, Accuracy :: " << i << " : " << val << " = " << label << " : " << accuracy << " == "
<< best_accuracy;
}
ms = ms_now() - ms;
auto args_name = net_.ListArguments();
LG << "INFO:" << "label_name = " << args_name[args_name.size() - 1];
LG << "INFO:" << "rgb_mean: " << "(" << rgb_mean_[0] << ", " << rgb_mean_[1]
<< ", " << rgb_mean_[2] << ")";
LG << "INFO:" << "rgb_std: " << "(" << rgb_std_[0] << ", " << rgb_std_[1]
<< ", " << rgb_std_[2] << ")";
LG << "INFO:" << "Image shape: " << "(" << input_shape_[1] << ", "
<< input_shape_[2] << ", " << input_shape_[3] << ")";
LG << "INFO:" << "Batch size = " << input_shape_[0] << " for inference";
LG << "INFO:" << "Throughput: " << (1000.0 * input_shape_[0] / ms)
<< " images per second";
}
Входное изображение было цветным изображением ди git '3' 28x28, я подозреваю, что проблема здесь LoadInputImage(const std::string &image_file)
, но пока не могу ее точно определить.
Любые мысли будут полезны.