My teacher gave me some code to run, but I could not get it to work. It fails with "TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'". How can I solve this? I am not familiar with Python. The traceback points to these lines:

viz = visdom.Visdom(env=DATASET + ' ' + MODEL)
if not viz.check_connection():
print("Visdom is not connected. Did you run 'python -m visdom.server' ?")
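From what I can tell, DATASET and MODEL are filled in from the --dataset and --model command-line options, and both default to None, so the error presumably means I started the script without those options. A minimal sketch that reproduces the same TypeError:

DATASET = None  # what args.dataset holds when --dataset is not passed
MODEL = None    # likewise for args.model
env = DATASET + ' ' + MODEL  # TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'

If that is the cause, launching the script with both options set, for example python main.py --model SVM --dataset PaviaC (main.py is just a placeholder for whatever the file is called), should avoid the error. Below is my full code.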
# Python 2/3 compatibility
from __future__ import print_function
from __future__ import division
# Torch
import torch
import torch.utils.data as data
from torchsummary import summary
# Numpy, scipy, scikit-image, spectral
import numpy as np
import sklearn.svm
import sklearn.model_selection
import sklearn.linear_model
import sklearn.neighbors
import sklearn.preprocessing
import sklearn.utils
from skimage import io
# Visualization
import seaborn as sns
import visdom
import os
from utils import metrics, convert_to_color_, convert_from_color_,\
display_dataset, display_predictions, explore_spectrums, plot_spectrums,\
sample_gt, build_dataset, show_results, compute_imf_weights, get_device
from datasets import get_dataset, HyperX, open_file, DATASETS_CONFIG
from models import get_model, train, test, save_model
import argparse
dataset_names = [v['name'] if 'name' in v.keys() else k for k, v in DATASETS_CONFIG.items()]
# Argument parser for CLI interaction
parser = argparse.ArgumentParser(description="Run deep learning experiments on"
" various hyperspectral datasets")
parser.add_argument('--dataset', type=str, default=None, choices=dataset_names,
help="Dataset to use.")
parser.add_argument('--model', type=str, default=None,
help="Model to train. Available:\n"
"SVM (linear), "
"SVM_grid (grid search on linear, poly and RBF kernels), "
"baseline (fully connected NN), "
"hu (1D CNN), "
"hamida (3D CNN + 1D classifier), "
"lee (3D FCN), "
"chen (3D CNN), "
"li (3D CNN), "
"he (3D CNN), "
"luo (3D CNN), "
"sharma (2D CNN), "
"boulch (1D semi-supervised CNN), "
"liu (3D semi-supervised CNN), "
"mou (1D RNN)")
parser.add_argument('--folder', type=str, help="Folder where to store the "
"datasets (defaults to ./Datasets/).",
default="./Datasets/")
parser.add_argument('--cuda', type=int, default=-1,
help="Specify CUDA device (defaults to -1, which learns on CPU)")
parser.add_argument('--runs', type=int, default=1, help="Number of runs (default: 1)")
parser.add_argument('--restore', type=str, default=None,
help="Weights to use for initialization, e.g. a checkpoint")
# Dataset options
group_dataset = parser.add_argument_group('Dataset')
group_dataset.add_argument('--training_sample', type=float, default=10,
help="Percentage of samples to use for training (default: 10%%)")
group_dataset.add_argument('--sampling_mode', type=str, help="Sampling mode"
" (random sampling or disjoint, default: random)",
default='random')
group_dataset.add_argument('--train_set', type=str, default=None,
help="Path to the train ground truth (optional, this "
"supersedes the --sampling_mode option)")
group_dataset.add_argument('--test_set', type=str, default=None,
help="Path to the test set (optional, by default "
"the test_set is the entire ground truth minus the training)")
# Training options
group_train = parser.add_argument_group('Training')
group_train.add_argument('--epoch', type=int, help="Training epochs (optional, if"
" absent will be set by the model)")
group_train.add_argument('--patch_size', type=int,
help="Size of the spatial neighbourhood (optional, if "
"absent will be set by the model)")
group_train.add_argument('--lr', type=float,
help="Learning rate, set by the model if not specified.")
group_train.add_argument('--class_balancing', action='store_true',
help="Inverse median frequency class balancing (default = False)")
group_train.add_argument('--batch_size', type=int,
help="Batch size (optional, if absent will be set by the model")
group_train.add_argument('--test_stride', type=int, default=1,
help="Sliding window step stride during inference (default = 1)")
# Data augmentation parameters
group_da = parser.add_argument_group('Data augmentation')
group_da.add_argument('--flip_augmentation', action='store_true',
help="Random flips (if patch_size > 1)")
group_da.add_argument('--radiation_augmentation', action='store_true',
help="Random radiation noise (illumination)")
group_da.add_argument('--mixture_augmentation', action='store_true',
help="Random mixes between spectra")
parser.add_argument('--with_exploration', action='store_true',
help="See data exploration visualization")
parser.add_argument('--download', type=str, default=None, nargs='+',
choices=dataset_names,
help="Download the specified datasets and quits.")
args = parser.parse_args()
CUDA_DEVICE = get_device(args.cuda)
# % of training samples
SAMPLE_PERCENTAGE = args.training_sample
# Data augmentation?
FLIP_AUGMENTATION = args.flip_augmentation
RADIATION_AUGMENTATION = args.radiation_augmentation
MIXTURE_AUGMENTATION = args.mixture_augmentation
# Dataset name
DATASET = args.dataset
# Model name
MODEL = args.model
# Number of runs (for cross-validation)
N_RUNS = args.runs
# Spatial context size (number of neighbours in each spatial direction)
PATCH_SIZE = args.patch_size
# Add some visualization of the spectra?
DATAVIZ = args.with_exploration
# Target folder to store/download/load the datasets
FOLDER = args.folder
# Number of epochs to run
EPOCH = args.epoch
# Sampling mode, e.g. random sampling
SAMPLING_MODE = args.sampling_mode
# Pre-computed weights to restore
CHECKPOINT = args.restore
# Learning rate for the SGD
LEARNING_RATE = args.lr
# Automated class balancing
CLASS_BALANCING = args.class_balancing
# Training ground truth file
TRAIN_GT = args.train_set
# Testing ground truth file
TEST_GT = args.test_set
TEST_STRIDE = args.test_stride
if args.download is not None and len(args.download) > 0:
for dataset in args.download:
get_dataset(dataset, target_folder=FOLDER)
quit()
viz = visdom.Visdom(env=DATASET + ' ' + MODEL)
if not viz.check_connection():
print("Visdom is not connected. Did you run 'python -m visdom.server' ?")
hyperparams = vars(args)
# Load the dataset
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET,
FOLDER)
# Number of classes
N_CLASSES = len(LABEL_VALUES)
# Number of bands (last dimension of the image tensor)
N_BANDS = img.shape[-1]
# Parameters for the SVM grid search
SVM_GRID_PARAMS = [{'kernel': ['rbf'], 'gamma': [1e-1, 1e-2, 1e-3],
'C': [1, 10, 100, 1000]},
{'kernel': ['linear'], 'C': [0.1, 1, 10, 100, 1000]},
{'kernel': ['poly'], 'degree': [3], 'gamma': [1e-1, 1e-2, 1e-3]}]
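# Each dict above is one parameter grid that sklearn.model_selection.GridSearchCV
# explores in the 'SVM_grid' branch below.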
if palette is None:
# Generate color palette
palette = {0: (0, 0, 0)}
for k, color in enumerate(sns.color_palette("hls", len(LABEL_VALUES) - 1)):
palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype='uint8'))
invert_palette = {v: k for k, v in palette.items()}
def convert_to_color(x):
return convert_to_color_(x, palette=palette)
def convert_from_color(x):
return convert_from_color_(x, palette=invert_palette)
# Instantiate the experiment based on predefined networks
hyperparams.update({'n_classes': N_CLASSES, 'n_bands': N_BANDS, 'ignored_labels': IGNORED_LABELS, 'device': CUDA_DEVICE})
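# Drop the options left unset (None) so that get_model() can fill in each model's own defaults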
hyperparams = dict((k, v) for k, v in hyperparams.items() if v is not None)
# Show the image and the ground truth
display_dataset(img, gt, RGB_BANDS, LABEL_VALUES, palette, viz)
color_gt = convert_to_color(gt)
if DATAVIZ:
# Data exploration: compute and show the mean spectra
mean_spectrums = explore_spectrums(img, gt, LABEL_VALUES, viz,
ignored_labels=IGNORED_LABELS)
plot_spectrums(mean_spectrums, viz, title='Mean spectrum/class')
results = []
# run the experiment several times
for run in range(N_RUNS):
if TRAIN_GT is not None and TEST_GT is not None:
train_gt = open_file(TRAIN_GT)
test_gt = open_file(TEST_GT)
elif TRAIN_GT is not None:
train_gt = open_file(TRAIN_GT)
test_gt = np.copy(gt)
w, h = test_gt.shape
test_gt[(train_gt > 0)[:w,:h]] = 0
elif TEST_GT is not None:
test_gt = open_file(TEST_GT)
else:
# Sample random training spectra
train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE)
print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
np.count_nonzero(gt)))
print("Running an experiment with the {} model".format(MODEL),
"run {}/{}".format(run + 1, N_RUNS))
display_predictions(convert_to_color(train_gt), viz, caption="Train ground truth")
display_predictions(convert_to_color(test_gt), viz, caption="Test ground truth")
if MODEL == 'SVM_grid':
print("Running a grid search SVM")
# Grid search SVM (linear and RBF)
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.svm.SVC(class_weight=class_weight)
clf = sklearn.model_selection.GridSearchCV(clf, SVM_GRID_PARAMS, verbose=5, n_jobs=4)
clf.fit(X_train, y_train)
print("SVM best parameters : {}".format(clf.best_params_))
prediction = clf.predict(img.reshape(-1, N_BANDS))
save_model(clf, MODEL, DATASET)
prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'SVM':
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.svm.SVC(class_weight=class_weight)
clf.fit(X_train, y_train)
save_model(clf, MODEL, DATASET)
prediction = clf.predict(img.reshape(-1, N_BANDS))
prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'SGD':
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
scaler = sklearn.preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.linear_model.SGDClassifier(class_weight=class_weight, learning_rate='optimal', tol=1e-3, average=10)
clf.fit(X_train, y_train)
save_model(clf, MODEL, DATASET)
prediction = clf.predict(scaler.transform(img.reshape(-1, N_BANDS)))
prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'nearest':
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.neighbors.KNeighborsClassifier(weights='distance')
clf = sklearn.model_selection.GridSearchCV(clf, {'n_neighbors': [1, 3, 5, 10, 20]}, verbose=5, n_jobs=4)
clf.fit(X_train, y_train)
save_model(clf, MODEL, DATASET)
prediction = clf.predict(img.reshape(-1, N_BANDS))
prediction = prediction.reshape(img.shape[:2])
else:
# Neural network
model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
if CLASS_BALANCING:
weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
hyperparams['weights'] = torch.from_numpy(weights)
# Split train set in train/val
train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
# Generate the dataset
train_dataset = HyperX(img, train_gt, **hyperparams)
train_loader = data.DataLoader(train_dataset,
batch_size=hyperparams['batch_size'],
#pin_memory=hyperparams['device'],
shuffle=True)
val_dataset = HyperX(img, val_gt, **hyperparams)
val_loader = data.DataLoader(val_dataset,
#pin_memory=hyperparams['device'],
batch_size=hyperparams['batch_size'])
print(hyperparams)
print("Network :")
with torch.no_grad():
for input, _ in train_loader:
break
summary(model.to(hyperparams['device']), input.size()[1:])
# We would like to use device=hyperparams['device'] although we have
# to wait for torchsummary to be fixed first.
if CHECKPOINT is not None:
model.load_state_dict(torch.load(CHECKPOINT))
try:
train(model, optimizer, loss, train_loader, hyperparams['epoch'],
scheduler=hyperparams['scheduler'], device=hyperparams['device'],
supervision=hyperparams['supervision'], val_loader=val_loader,
display=viz)
except KeyboardInterrupt:
# Allow the user to stop the training
pass
probabilities = test(model, img, hyperparams)
prediction = np.argmax(probabilities, axis=-1)
run_results = metrics(prediction, test_gt, ignored_labels=hyperparams['ignored_labels'], n_classes=N_CLASSES)
mask = np.zeros(gt.shape, dtype='bool')
for l in IGNORED_LABELS:
mask[gt == l] = True
prediction[mask] = 0
color_prediction = convert_to_color(prediction)
display_predictions(color_prediction, viz, gt=convert_to_color(test_gt), caption="Prediction vs. test ground truth")
results.append(run_results)
show_results(run_results, viz, label_values=LABEL_VALUES)
if N_RUNS > 1:
show_results(results, viz, label_values=LABEL_VALUES, agregated=True)
Below is datasets.py.
# -*- coding: utf-8 -*-
"""
This file contains the PyTorch dataset for hyperspectral images and
related helpers.
"""
import spectral
import numpy as np
import torch
import torch.utils
import torch.utils.data
import os
from tqdm import tqdm
try:
# Python 3
from urllib.request import urlretrieve
except ImportError:
# Python 2
from urllib import urlretrieve
from utils import open_file
DATASETS_CONFIG = {
'PaviaC': {
'urls': ['http://www.ehu.eus/ccwintco/uploads/e/e3/Pavia.mat',
'http://www.ehu.eus/ccwintco/uploads/5/53/Pavia_gt.mat'],
'img': 'Pavia.mat',
'gt': 'Pavia_gt.mat'
},
'Salinas': {
'urls': ['http://www.ehu.eus/ccwintco/uploads/a/a3/Salinas_corrected.mat',
'http://www.ehu.eus/ccwintco/uploads/f/fa/Salinas_gt.mat'],
'img': 'Salinas_corrected.mat',
'gt': 'Salinas_gt.mat'
},
'PaviaU': {
'urls': ['http://www.ehu.eus/ccwintco/uploads/e/ee/PaviaU.mat',
'http://www.ehu.eus/ccwintco/uploads/5/50/PaviaU_gt.mat'],
'img': 'PaviaU.mat',
'gt': 'PaviaU_gt.mat'
},
'KSC': {
'urls': ['http://www.ehu.es/ccwintco/uploads/2/26/KSC.mat',
'http://www.ehu.es/ccwintco/uploads/a/a6/KSC_gt.mat'],
'img': 'KSC.mat',
'gt': 'KSC_gt.mat'
},
'IndianPines': {
'urls': ['http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat',
'http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat'],
'img': 'Indian_pines_corrected.mat',
'gt': 'Indian_pines_gt.mat'
},
'Botswana': {
'urls': ['http://www.ehu.es/ccwintco/uploads/7/72/Botswana.mat',
'http://www.ehu.es/ccwintco/uploads/5/58/Botswana_gt.mat'],
'img': 'Botswana.mat',
'gt': 'Botswana_gt.mat',
}
}
try:
from custom_datasets import CUSTOM_DATASETS_CONFIG
DATASETS_CONFIG.update(CUSTOM_DATASETS_CONFIG)
except ImportError:
pass
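# A custom entry presumably mirrors the built-in ones above, plus a 'loader'
# callable used by the custom branch of get_dataset below. Sketch only;
# 'MyDataset' and my_loader are hypothetical:
# CUSTOM_DATASETS_CONFIG = {
#     'MyDataset': {
#         'download': False,    # read via dataset.get('download', True)
#         'folder': 'MyDataset/',
#         'loader': my_loader   # called as loader(folder) and expected to return
#     }                         # (img, gt, rgb_bands, ignored_labels, label_values, palette)
# }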
class TqdmUpTo(tqdm):
"""Provides `update_to(n)` which uses `tqdm.update(delta_n)`."""
def update_to(self, b=1, bsize=1, tsize=None):
"""
b : int, optional
Number of blocks transferred so far [default: 1].
bsize : int, optional
Size of each block (in tqdm units) [default: 1].
tsize : int, optional
Total size (in tqdm units). If None [default], the total stays unchanged.
"""
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n) # will also set self.n = b * bsize
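# Usage sketch, mirroring the download loop in get_dataset below:
# with TqdmUpTo(unit='B', unit_scale=True, miniters=1, desc=filename) as t:
#     urlretrieve(url, filename=folder + filename, reporthook=t.update_to)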
def get_dataset(dataset_name, target_folder="./", datasets=DATASETS_CONFIG):
""" Gets the dataset specified by name and return the related components.
Args:
dataset_name: string with the name of the dataset
target_folder (optional): folder to store the datasets, defaults to ./
datasets (optional): dataset configuration dictionary, defaults to prebuilt one
Returns:
img: 3D hyperspectral image (WxHxB)
gt: 2D int array of labels
label_values: list of class names
ignored_labels: list of int classes to ignore
rgb_bands: int tuple that corresponds to the red, green and blue bands
palette: color palette as a dict of {label: (R, G, B)}, or None
"""
palette = None
if dataset_name not in datasets.keys():
raise ValueError("{} dataset is unknown.".format(dataset_name))
dataset = datasets[dataset_name]
folder = target_folder + datasets[dataset_name].get('folder', dataset_name + '/')
if dataset.get('download', True):
# Download the dataset if it is not present
if not os.path.isdir(folder):
os.mkdir(folder)
for url in datasets[dataset_name]['urls']:
# download the files
filename = url.split('/')[-1]
if not os.path.exists(folder + filename):
with TqdmUpTo(unit='B', unit_scale=True, miniters=1,
desc="Downloading {}".format(filename)) as t:
urlretrieve(url, filename=folder + filename,
reporthook=t.update_to)
elif not os.path.isdir(folder):
print("WARNING: {} is not downloadable.".format(dataset_name))
if dataset_name == 'PaviaC':
# Load the image
img = open_file(folder + 'Pavia.mat')['pavia']
rgb_bands = (55, 41, 12)
gt = open_file(folder + 'Pavia_gt.mat')['pavia_gt']
label_values = ["Undefined", "Water", "Trees", "Asphalt",
"Self-Blocking Bricks", "Bitumen", "Tiles", "Shadows",
"Meadows", "Bare Soil"]
ignored_labels = [0]
elif dataset_name == 'PaviaU':
# Load the image
img = open_file(folder + 'PaviaU.mat')['paviaU']
rgb_bands = (55, 41, 12)
gt = open_file(folder + 'PaviaU_gt.mat')['paviaU_gt']
label_values = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees',
'Painted metal sheets', 'Bare Soil', 'Bitumen',
'Self-Blocking Bricks', 'Shadows']
ignored_labels = [0]
elif dataset_name == 'Salinas':
img = open_file(folder + 'Salinas_corrected.mat')['salinas_corrected']
rgb_bands = (43, 21, 11)  # AVIRIS sensor
gt = open_file(folder + 'Salinas_gt.mat')['salinas_gt']
label_values = ['Undefined', 'Brocoli_green_weeds_1', 'Brocoli_green_weeds_2',
'Fallow', 'Fallow_rough_plow', 'Fallow_smooth', 'Stubble',
'Celery', 'Grapes_untrained', 'Soil_vinyard_develop',
'Corn_senesced_green_weeds', 'Lettuce_romaine_4wk',
'Lettuce_romaine_5wk', 'Lettuce_romaine_6wk',
'Lettuce_romaine_7wk', 'Vinyard_untrained',
'Vinyard_vertical_trellis']
ignored_labels = [0]
elif dataset_name == 'IndianPines':
# Load the image
img = open_file(folder + 'Indian_pines_corrected.mat')
img = img['indian_pines_corrected']
rgb_bands = (43, 21, 11) # AVIRIS sensor
gt = open_file(folder + 'Indian_pines_gt.mat')['indian_pines_gt']
label_values = ["Undefined", "Alfalfa", "Corn-notill", "Corn-mintill",
"Corn", "Grass-pasture", "Grass-trees",
"Grass-pasture-mowed", "Hay-windrowed", "Oats",
"Soybean-notill", "Soybean-mintill", "Soybean-clean",
"Wheat", "Woods", "Buildings-Grass-Trees-Drives",
"Stone-Steel-Towers"]
ignored_labels = [0]
elif dataset_name == 'Botswana':
# Load the image
img = open_file(folder + 'Botswana.mat')['Botswana']
rgb_bands = (75, 33, 15)
gt = open_file(folder + 'Botswana_gt.mat')['Botswana_gt']
label_values = ["Undefined", "Water", "Hippo grass",
"Floodplain grasses 1", "Floodplain grasses 2",
"Reeds", "Riparian", "Firescar", "Island interior",
"Acacia woodlands", "Acacia shrublands",
"Acacia grasslands", "Short mopane", "Mixed mopane",
"Exposed soils"]
ignored_labels = [0]
elif dataset_name == 'KSC':
# Load the image
img = open_file(folder + 'KSC.mat')['KSC']
rgb_bands = (43, 21, 11) # AVIRIS sensor
gt = open_file(folder + 'KSC_gt.mat')['KSC_gt']
label_values = ["Undefined", "Scrub", "Willow swamp",
"Cabbage palm hammock", "Cabbage palm/oak hammock",
"Slash pine", "Oak/broadleaf hammock",
"Hardwood swamp", "Graminoid marsh", "Spartina marsh",
"Cattail marsh", "Salt marsh", "Mud flats", "Wate"]
ignored_labels = [0]
else:
# Custom dataset
img, gt, rgb_bands, ignored_labels, label_values, palette = CUSTOM_DATASETS_CONFIG[dataset_name]['loader'](folder)
# Filter NaN out
nan_mask = np.isnan(img.sum(axis=-1))
if np.count_nonzero(nan_mask) > 0:
print("Warning: NaN have been found in the data. It is preferable to remove them beforehand. Learning on NaN data is disabled.")
img[nan_mask] = 0
gt[nan_mask] = 0
ignored_labels.append(0)
ignored_labels = list(set(ignored_labels))
# Normalization
img = np.asarray(img, dtype='float32')
img = (img - np.min(img)) / (np.max(img) - np.min(img))
return img, gt, label_values, ignored_labels, rgb_bands, palette
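# Usage sketch, mirroring the call at the top of the main script:
# img, gt, label_values, ignored_labels, rgb_bands, palette = get_dataset('PaviaU', target_folder='./Datasets/')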