Я использую Unet с pytorch для сегментирования судов, когда я тренирую свои сети, он рушится в 5-ю эпоху, где должна происходить валидация
def image_concatenate(image, crop_num1, crop_num2, dim1, dim2):
"""concatenate images
Args :
image : output images (should be square)
crop_num2 (int) : number of crop in horizontal way (2)
crop_num1 (int) : number of crop in vertical way (2)
dim1(int) : vertical size of output (512)
dim2(int) : horizontal size_of_output (512)
Return :
div_array : numpy arrays of numbers of 1,2,4
"""
crop_size = image.shape[1] # size of crop
empty_array = np.zeros([dim1, dim2]).astype("float64") # to make sure no overflow
dim1_stride = stride_size(dim1, crop_num1, crop_size) # vertical stride
dim2_stride = stride_size(dim2, crop_num2, crop_size) # horizontal stride
index = 0
for i in range(crop_num1):
for j in range(crop_num2):
# add image to empty_array at specific position
empty_array[dim1_stride*i:dim1_stride*i+ crop_size,
dim2_stride*j:dim2_stride*j+ crop_size] += image[index]
index += 1
return empty_array
, а когда происходит обучение, это ошибка
Epoch 5 Train loss: 0.012208586327390733 Train acc 0.9599569108751085
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-69-d8ccfac27aaf> in <module>()
43 # Validation every 5 epoch
44 if (i+1) % 5 == 0:
---> 45 val_acc, val_loss = validate_model(model, SEM_val_load, criterion, i+1, True, image_save_path)
46 print('Val loss:', val_loss, "val acc:", val_acc)
47 values = [i+1, train_loss, train_acc, val_loss, val_acc]
2 frames
<ipython-input-61-7f3e547fab95> in image_concatenate(image, crop_num1, crop_num2, dim1, dim2)
330 # add image to empty_array at specific position
331 empty_array[dim1_stride*i:dim1_stride*i+ crop_size,
--> 332 dim2_stride*j:dim2_stride*j+ crop_size] += image[index]
333 index += 1
334 return empty_array
IndexError: index 1 is out of bounds for axis 0 with size 1