Невозможно создать CustomDataset LibTorch / PyTorch - PullRequest
2 голосов
/ 24 марта 2020

У меня есть папка с набором изображений, которые я хочу загрузить как набор пользовательских данных в LibTorch.

int main()
{
    ...
    std::string file_location{"dataset/img_align_celeba/*.jpg"};
    auto train_set = CustomDataset(file_location).map(data::transforms::Stack<>());
}

CustomDataset.h

using namespace torch;


torch::Tensor read_data(const std::string& loc);

std::vector<std::string> obtain_data(std::string &address);

class CustomDataset : public data::Dataset<CustomDataset> {
private:
    std::vector<std::string> m_filenames;
public:
    // Constructor
    explicit CustomDataset(std::string &files): m_filenames(obtain_data(files)) {
    };

    // Override get() function to return tensor at location index
    torch::data::Example<> get(size_t index) override {
        torch::Tensor sample_img = read_data(m_filenames[index]);
        torch::Tensor sample_label = torch::full({1}, 1);
        return {sample_img, sample_label};
    };

    // Return the length of data
    torch::optional<size_t> size() const override {
        return m_filenames.size();
    };
};

CustomDataset. cpp

#include "CustomDataset.h"

torch::Tensor read_data(const std::string& loc)
{
    // Read Data here
    // Return tensor form of the image
    cv::Mat img = cv::imread(loc);
    std::vector<cv::Mat> channels(3);
    cv::split(img, channels);

    auto R = torch::from_blob(
            channels[2].ptr(),
            {64, 64},
            torch::kUInt8);
    auto G = torch::from_blob(
            channels[1].ptr(),
            {64, 64},
            torch::kUInt8);
    auto B = torch::from_blob(
            channels[0].ptr(),
            {64, 64},
            torch::kUInt8);

    auto tdata = torch::cat({R, G, B})
            .view({3, 64, 64})
            .to(torch::kFloat);

    return tdata;
}

std::vector<std::string> obtain_data(std::string &address)
{
    std::vector<std::string> filenames;
    cv::glob(address, filenames);
    return filenames;
}

Даже когда я нахожусь внутри класса CustomDataset, я вижу, что переменная-член m_filenames не пуста и что get() предоставляет тензор с изображением и меткой, как только я вернусь к main.cpp Переменная CustomDataset с именем train_set имеет размер 0.

...