У меня есть набор изображений из 250 изображений формы (3, 320, 240) и 250 файлов аннотаций.Я использую ChainerCV для обнаружения и распознавания двух классов на изображении: мяч и игрок.Здесь мы используем модель SSD300, предварительно обученную на наборе данных ImageNet.
РЕДАКТИРОВАТЬ: КЛАСС ДЛЯ СОЗДАНИЯ ОБЪЕКТА ДАННЫХ
bball_labels = ('ball','player')
class BBall_dataset(VOCBboxDataset):
def _get_annotations(self, i):
id_ = self.ids[i]
anno = ET.parse(os.path.join(self.data_dir, 'Annotations', id_ +
'.xml'))
bbox = []
label = []
difficult = []
for obj in anno.findall('object'):
bndbox_anno = obj.find('bndbox')
bbox.append([int(bndbox_anno.find(tag).text) - 1 for tag in ('ymin',
'xmin', 'ymax', 'xmax')])
name = obj.find('name').text.lower().strip()
label.append(bball_labels.index(name))
bbox = np.stack(bbox).astype(np.float32)
label = np.stack(label).astype(np.int32)
difficult = np.array(difficult, dtype=np.bool)
return bbox, label, difficult
СКАЧАТЬ ПОДГОТОВЛЕННУЮ МОДЕЛЬ
import chainer
from chainercv.links import SSD300
from chainercv.links.model.ssd import multibox_loss
class MultiboxTrainChain(chainer.Chain):
def __init__(self, model, alpha=1, k=3):
super(MultiboxTrainChain, self).__init__()
with self.init_scope():
self.model = model
self.alpha = alpha
self.k = k
def forward(self, imgs, gt_mb_locs, gt_mb_labels):
mb_locs, mb_confs = self.model(imgs)
loc_loss, conf_loss = multibox_loss(
mb_locs, mb_confs, gt_mb_locs, gt_mb_labels, self.k)
loss = loc_loss * self.alpha + conf_loss
chainer.reporter.report(
{'loss': loss, 'loss/loc': loc_loss, 'loss/conf': conf_loss},
self)
return loss
model = SSD300(n_fg_class=len(bball_labels), pretrained_model='imagenet')
train_chain = MultiboxTrainChain(model)
ПРЕОБРАЗОВАНИЕDATASET импортирует необходимые библиотеки
class Transform(object):
def __init__(self, coder, size, mean):
self.coder = copy.copy(coder)
self.coder.to_cpu()
self.size = size
self.mean = mean
def __call__(self, in_data):
img, bbox, label = in_data
img = random_distort(img)
if np.random.randint(2):
img, param = transforms.random_expand(img, fill=self.mean,
return_param=True)
bbox = transforms.translate_bbox(bbox, y_offset=param['y_offset'],
x_offset=param['x_offset'])
img, param = random_crop_with_bbox_constraints(img, bbox,
return_param=True)
bbox, param = transforms.crop_bbox(bbox, y_slice=param['y_slice'],
x_slice=param['x_slice'],allow_outside_center=False, return_param=True)
label = label[param['index']]
_, H, W = img.shape
img = resize_with_random_interpolation(img, (self.size, self.size))
bbox = transforms.resize_bbox(bbox, (H, W), (self.size, self.size))
img, params = transforms.random_flip(img, x_random=True,
return_param=True)
bbox = transforms.flip_bbox(bbox, (self.size, self.size),
x_flip=params['x_flip'])
img -= self.mean
mb_loc, mb_label = self.coder.encode(bbox, label)
return img, mb_loc, mb_label
transformed_train_dataset = TransformDataset(train_dataset,
Transform(model.coder, model.insize, model.mean))
train_iter =
chainer.iterators.MultiprocessIterator(transformed_train_dataset,
batchsize)
valid_iter = chainer.iterators.SerialIterator(valid_dataset,
batchsize,
repeat=False, shuffle=False)
Во время обучения выдает следующую ошибку:
Exception in thread Thread-4:
Traceback (most recent call last):
File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
self.run()
File "/usr/lib/python3.6/threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.6/dist-
packages/chainer/iterators/multiprocess_iterator.py", line 401, in
fetch_batch
batch_ret[0] = [self.dataset[idx] for idx in indices]
File "/usr/local/lib/python3.6/dist-
........................................................................
packages/chainer/iterators/multiprocess_iterator.py", line 401, in
<listcomp>
batch_ret[0] = [self.dataset[idx] for idx in indices]
File "/usr/local/lib/python3.6/dist-
packages/chainer/dataset/dataset_mixin.py", line 67, in __getitem__
return self.get_example(index)
File "/usr/local/lib/python3.6/dist-
packages/chainer/datasets/transform_dataset.py", line 51, in get_example
in_data = self._dataset[i]
File "/usr/local/lib/python3.6/dist-
packages/chainer/dataset/dataset_mixin.py", line 67, in __getitem__
return self.get_example(index)
File "/usr/local/lib/python3.6/dist--
packages/chainercv/utils/image/read_image.py", line 120, in read_image
return _read_image_cv2(path, dtype, color, alpha)
File "/usr/local/lib/python3.6/dist-
packages/chainercv/utils/image/read_image.py", line 49, in _read_image_cv2
if img.ndim == 2:
AttributeError: 'NoneType' object has no attribute 'ndim'
TypeError: 'NoneType' object is not iterable
Я хочу знать, что является причиной этого.В этом случае неправильный формат входных данных?И как разрешить эту ситуацию.