Подскажите, пожалуйста, как исправить эту ошибку? Может быть, где-то нужно использовать "ndim"? Из сообщения непонятно, о каких осях и о каком массиве идёт речь.
Ошибка:
ValueError Traceback (most recent call last)
<ipython-input-8-2e591d410a95> in <module>()
147
148 start = timer()
--> 149 for t,(input, image_id) in enumerate(loader):
150 print('\r t = 3%d / 3%d %s %s : %s'%(
151 t, len(loader)-1, str(input.shape), image_id[0], time_to_str((timer() - start),'sec'),
4 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
383 def _next_data(self):
384 index = self._next_index() # may raise StopIteration
--> 385 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
386 if self._pin_memory:
387 data = _utils.pin_memory.pin_memory(data)
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
45 else:
46 data = self.dataset[possibly_batched_index]
---> 47 return self.collate_fn(data)
<ipython-input-8-2e591d410a95> in null_collate(batch)
115
116 input = np.stack(input)
--> 117 input = torch.from_numpy(image_to_input(input))
118 return input, image_id
119
<ipython-input-8-2e591d410a95> in image_to_input(image)
76 input = image.astype(np.float32)
77 input = input[...,::-1]/255
---> 78 input = input.transpose(0,3,1,2)
79 input[:,0] = (input[:,0]-IMAGE_RGB_MEAN[0])/IMAGE_RGB_STD[0]
80 input[:,1] = (input[:,1]-IMAGE_RGB_MEAN[1])/IMAGE_RGB_STD[1]
ValueError: axes don't match array
Полный код ниже:
DATA_DIR = '/content/drive/My Drive/Severtal/data/severstal-steel-defect-detection'
DATAA_DIR = '/content/drive/My Drive/Severtal/data/train_images/'
DATAAA_DIR = '/content/drive/My Drive/Severtal/data/'
# NOTE(review): KaggleTestDataset reads test images from
# DATA_DIR + '/test_images/', while training images live in DATAA_DIR
# ('.../data/train_images/'). If test_images sits next to train_images,
# DATA_DIR points at the wrong folder -- confirm the directory layout.
CHECKPOINT_FILE = USER_DATA + '/model_unet_resnet18.pth'

# Per-channel normalisation statistics consumed by image_to_input().
#
# image_to_input() reverses the channel axis of the cv2 image (cv2.imread
# yields BGR, so the result is RGB) and divides by 255 *before* computing
# (x[:, c] - IMAGE_RGB_MEAN[c]) / IMAGE_RGB_STD[c]. The statistics must
# therefore be in that same space: RGB channel order, [0, 1] scale.
# cv2.mean / cv2.meanStdDev report raw 0-255 values in the image's own BGR
# order, so they are reordered and rescaled below.
#
# NOTE(review): statistics taken from a single training image are unlikely to
# match the normalisation the checkpoint was trained with; confirm against the
# training code and prefer its constants if they differ.
imagee = cv2.imread(DATAA_DIR + "0a4ef8ee7.jpg")
if imagee is None:
    # cv2.imread signals a missing/unreadable file by returning None instead
    # of raising; fail here with the path rather than deep inside cv2.mean.
    raise FileNotFoundError('cv2.imread could not read %s' % (DATAA_DIR + "0a4ef8ee7.jpg"))
# cv2.mean returns a 4-element scalar: the per-channel means (B, G, R) padded.
IMAGE_RGB_MEANst = cv2.mean(imagee)
# Take the first three (B, G, R), reverse to (R, G, B), rescale to [0, 1].
IMAGE_RGB_MEAN = [v / 255 for v in IMAGE_RGB_MEANst[2::-1]]
# cv2.meanStdDev returns (mean, stddev), each an array with one row per channel.
IMAGE_RGB_STDst = cv2.meanStdDev(imagee)
# Same reordering/rescaling for the per-channel standard deviations.
IMAGE_RGB_STD = [float(v) / 255 for v in IMAGE_RGB_STDst[1].ravel()[2::-1]]
def time_to_str(t, mode='min'):
    """Format an elapsed duration ``t`` (seconds) as a short human string.

    mode='min' renders whole hours and minutes ('%2d hr %02d min');
    mode='sec' renders whole minutes and seconds ('%2d min %02d sec').
    Fractions of a second are dropped via int(). Any other mode raises
    NotImplementedError.
    """
    if mode == 'min':
        # True division is kept on purpose: the float minute count is split
        # into hours and leftover minutes, and %d truncates the leftover.
        hours, minutes = divmod(int(t) / 60, 60)
        return '%2d hr %02d min' % (hours, minutes)
    if mode == 'sec':
        minutes, seconds = divmod(int(t), 60)
        return '%2d min %02d sec' % (minutes, seconds)
    raise NotImplementedError
# Module-level alias for nn.BatchNorm2d (nn is presumably torch.nn, imported
# outside the visible source); other code can refer to plain `BatchNorm2d`.
BatchNorm2d = nn.BatchNorm2d
def image_to_input(image):
    """Convert a batch of cv2 images into a normalised NCHW float32 array.

    The channel axis is reversed (images loaded with cv2 are BGR, so this
    gives RGB) and pixel values are scaled by 1/255. The array is moved to
    channels-first layout, and each of the first three channels is normalised
    with the module-level IMAGE_RGB_MEAN[c] and IMAGE_RGB_STD[c].

    Args:
        image: array (or nested sequence) of shape
            (batch, height, width, channels) with at least three channels.

    Returns:
        float32 ndarray of shape (batch, channels, height, width).

    Raises:
        ValueError: if ``image`` is not 4-dimensional. The typical cause is
            a batch of None entries (cv2.imread failed for every file), which
            np.stack turns into a 1-D object array. Without this check it
            would surface as the opaque "axes don't match array" from
            transpose.
    """
    image = np.asarray(image)
    if image.ndim != 4:
        raise ValueError(
            'image_to_input expects a 4-D batch (batch, height, width, channels), '
            'got ndim=%d, shape=%s, dtype=%s. A shape of (batch_size,) with dtype '
            'object typically means the batch holds None entries, i.e. cv2.imread '
            'could not read the image files - check the image paths.'
            % (image.ndim, image.shape, image.dtype))

    x = image.astype(np.float32)
    # Reverse the channel axis (BGR -> RGB) and scale to [0, 1].
    x = x[..., ::-1] / 255
    # NHWC -> NCHW (channels-first, as torch conv layers expect).
    x = x.transpose(0, 3, 1, 2)
    for c in range(3):
        x[:, c] = (x[:, c] - IMAGE_RGB_MEAN[c]) / IMAGE_RGB_STD[c]
    return x
class KaggleTestDataset(Dataset):
    """Test-set dataset yielding ``(image, image_id)`` pairs.

    Image ids come from ``DATA_DIR/sample_submission.csv``: the
    ``ImageId_ClassId`` column is cut at the first ``'_'`` and the unique
    ids are kept in order of first appearance. Each image is read with
    ``cv2.imread(..., cv2.IMREAD_COLOR)`` from ``DATA_DIR/test_images/``.
    """

    def __init__(self):
        df = pd.read_csv(DATA_DIR + '/sample_submission.csv')
        # Keep the part before the first '_' as the image id (the column
        # name suggests values of the form '<ImageId>_<ClassId>').
        df['ImageId'] = df['ImageId_ClassId'].apply(lambda x: x.split('_')[0])
        self.uid = df['ImageId'].unique().tolist()

    def __str__(self):
        return '\tlen = %d\n' % len(self)

    def __len__(self):
        return len(self.uid)

    def __getitem__(self, index):
        """Return ``(image, image_id)`` for the ``index``-th unique id.

        Raises:
            FileNotFoundError: if cv2 cannot read the image file.
        """
        image_id = self.uid[index]
        path = DATA_DIR + '/test_images/%s' % (image_id)
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None instead of raising when the file is
            # missing or unreadable. Passing None on lets the failure surface
            # much later, in collate, as "axes don't match array".
            raise FileNotFoundError('cv2.imread could not read test image %r' % path)
        return image, image_id
def null_collate(batch):
    """Collate ``(image, image_id)`` samples into a model-ready batch.

    The images are stacked with np.stack, passed through image_to_input, and
    wrapped as a torch tensor via torch.from_numpy. The ids are returned
    alongside as a plain list, in batch order.
    """
    images = [sample[0] for sample in batch]
    ids = [sample[1] for sample in batch]
    stacked = np.stack(images)
    tensor = torch.from_numpy(image_to_input(stacked))
    return tensor, ids
def run_make_submission_csv():
threshold = 0.5
min_size = 3500
dataset = KaggleTestDataset()
print(dataset)
loader = DataLoader(
dataset,
sampler = SequentialSampler(dataset),
batch_size = 8,
drop_last = False,
num_workers = 0,
pin_memory = True,
collate_fn = null_collate
)
image_id_class_id = []
encoded_pixel = []
start = timer()
for t,(input, image_id) in enumerate(loader):
print('\r t = 3%d / 3%d %s %s : %s'%(
t, len(loader)-1, str(input.shape), image_id[0], time_to_str((timer() - start),'sec'),
),end='', flush=True)
input = input.cuda()
with torch.no_grad():
logit = net(input)
probability= torch.sigmoid(logit)