Я запрограммировал классификатор изображений с помощью PyTorch. Когда я тренирую net с несколькими параметрами, он либо случайно прогрессирует и учится, либо ничего не узнает. Я запускал код несколько раз, и для тех же параметров у меня разные результаты.
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
self.conv2 = nn.Conv2d(6, 12, kernel_size=5,)
self.conv3 = nn.Conv2d(12, 24, kernel_size=3 )
self.conv4 = nn.Conv2d(24, 24, kernel_size=3 )
self.fc1 = nn.Linear(864, 300)
self.fc2 = nn.Linear(300,2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x,2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x,2)
x = F.relu(self.conv3(x))
x = F.max_pool2d(x,2)
x = F.relu(self.conv4(x))
x = F.max_pool2d(x,2)
x = F.max_pool2d(x,2)
#print(x.shape)
x = x.reshape(-1, 864)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = CatDogs('train/',transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize(
mean = [0.485,0.456,0.406],
std= [0.229,0.224,0.225])
]))
train_set, test_set = torch.utils.data.random_split(data, [20000, 5000])
params = OrderedDict(
batch_size = [10,20,30,40,50,60,70,80,90,100,110,120]
,lr = [.005]
,shuffle = [True]
,pin_memory = [True]
,num_workers = [2]
)
m = RunManager()
for run in RunBuilder.get_runs(params):
network = Network()
network = network.to(device)
loader = DataLoader(train_set,
pin_memory=run.pin_memory,
batch_size=run.batch_size,
shuffle=run.shuffle,
num_workers=run.num_workers)
optimizer = optim.Adam(network.parameters(), lr=run.lr)
network.train()
m.begin_run(run, network, loader)
for epoch in range(10):
m.begin_epoch()
for batch in loader: # Get Batch
images, labels = batch
images = images.to(device)
labels = labels.to(device)
preds = network(images) # Pass Batch
loss = F.cross_entropy(preds, labels) # Calculate Loss
optimizer.zero_grad()
loss.backward() # Calculate Gradients
optimizer.step() # Update Weights
m.track_loss(loss, images)
m.track_num_correct(preds, labels)
m.end_epoch()
m.end_run()
m.save('results_batchsize_3')
Пример моего вывода:
Class RunManager отслеживает несколько статистических данных, таких как потери, точность и время, но они не влияют на обучение.
В результате сеть возвращает положительные результаты, и потери уменьшаются, хотя иногда потери остаются почти такими же.
Если у вас есть предположения, что это вызвало, дайте мне знать. Большое спасибо за помощь!