Ну, это зависит от того, как реализован ваш Dataset
.Например, в случае torchvision.datasets.MNIST(...)
вы не можете получить имя файла просто потому, что не существует такой вещи, как имя файла отдельной выборки (выборки MNIST загружаются другим способом ).
Поскольку вы не показали свою реализацию Dataset
, я расскажу вам, как это можно сделать с помощью torchvision.datasets.ImageFolder(...)
(или любого torchvision.datasets.DatasetFolder(...)
):
f = open("test_y", "w")
with torch.no_grad():
for i, (images, labels) in enumerate(test_loader, 0):
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
sample_fname, _ = test_loader.dataset.samples[i]
f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()
Вы можете видеть, что путь к файлу извлекается во время __getitem__(self, index)
, в частности здесь .
Если вы реализовали свой собственный Dataset
(и, возможно,хотел бы поддерживать shuffle
и batch_size > 1
), тогда я бы вернул sample_fname
при вызове __getitem__(...)
и сделал бы что-то вроде этого:
for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
# [...]
Таким образом, вам не нужно было бызаботиться о shuffle
.И если batch_size
больше 1, вам нужно изменить содержимое цикла на что-то более общее, например:
f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
outputs = model(images)
pred = torch.max(outputs, 1)[1]
f.write("\n".join([
", ".join(x)
for x in zip(map(str, pred.cpu().tolist()), samples_fname)
]) + "\n")
f.close()