Как исправить ошибку «MXNetError: include/mxnet/operator.h:228» в собственной SSD-модели - PullRequest
0 голосов
/ 29 марта 2019

Я обучаю собственную модель SSD на собственном наборе данных в формате Pascal VOC. Как исправить ошибку, возникающую в MultiBoxTarget?

Раньше я проверял эту программу на встроенном наборе данных (pikachu); при запуске на собственном наборе данных возникает ошибка.

Здесь я вызываю функцию training_targets:

 import time
 from mxnet import autograd as ag
 # Training loop: for every epoch in [start_epoch, epochs) run one pass over
 # train_data, updating the network with the sum of the classification and
 # box-regression losses, and report both metrics periodically and per epoch.
 for epoch in range(start_epoch, epochs):
     # Reset both metrics and start the epoch timer.

     cls_metric.reset()
     box_metric.reset()
     tic = time.time()
     # Iterate through all batches of the training data.
     for i, batch in enumerate(train_data):
         btic = time.time()
         # Record the forward pass so gradients can be computed.
         with ag.record():
             # batch[0] is fed to the network, batch[1] is used as the labels.
             x = batch[0].as_in_context(ctx)
             y = batch[1].as_in_context(ctx)
             default_anchors, class_predictions, box_predictions = net(x)
             # NOTE(review): the reported MXNetError ("Unsupported data type 1")
             # is raised inside training_targets; presumably one of its inputs
             # (most likely y) is not float32 -- confirm y.dtype.
             box_target, box_mask, cls_target = training_targets(default_anchors, class_predictions, y)
             # Classification loss and masked box-regression loss.
             loss1 = cls_loss(class_predictions, cls_target)
             loss2 = box_loss(box_predictions, box_target, box_mask)
             # Total loss is the plain sum of both terms.
             loss = loss1 + loss2
             # Backpropagate through the recorded graph.
             loss.backward()
         # Apply the parameter update for this batch.
         trainer.step(batch_size)
         # Update metrics; class predictions are transposed to (0, 2, 1) and
         # box predictions are multiplied by the mask before comparison.
         cls_metric.update([cls_target], [nd.transpose(class_predictions, (0, 2, 1))])
         box_metric.update([box_target], [box_predictions * box_mask])
         # Log every log_interval batches; speed is batch_size divided by the
         # wall-clock time spent on this batch.
         if (i + 1) % log_interval == 0:
             name1, val1 = cls_metric.get()
             name2, val2 = box_metric.get()
             print('[Epoch %d Batch %d] speed: %f samples/s, training: %s=%f, %s=%f'
                   %(epoch ,i, batch_size/(time.time()-btic), name1, val1, name2, val2))

     # End-of-epoch logging: final metric values and elapsed time for the epoch.
     name1, val1 = cls_metric.get()
     name2, val2 = box_metric.get()
     print('[Epoch %d] training: %s=%f, %s=%f'%(epoch, name1, val1, name2, val2))
     print('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))

Функция training_targets, в которой вызывается MultiBoxTarget и возникает ошибка, определена так:

from mxnet.contrib.ndarray import MultiBoxTarget
def training_targets(default_anchors, class_predicts, labels):
    """Compute SSD training targets for a batch with ``MultiBoxTarget``.

    Args:
        default_anchors: Anchor boxes produced by the network.
        class_predicts: Class predictions produced by the network; they are
            transposed with ``axes=(0, 2, 1)`` before being passed on.
        labels: Ground-truth labels for the batch. They are cast to float32
            here because ``MultiBoxTarget`` rejects other dtypes.

    Returns:
        A ``(box_target, box_mask, cls_target)`` tuple taken from the three
        outputs of ``MultiBoxTarget``.
    """
    # MultiBoxTarget fails with "Check failed: in_type->at(i) ==
    # mshadow::default_type_flag ... Unsupported data type 1" when an input is
    # not float32. Labels loaded from a custom dataset are the likely source
    # (e.g. float64 from NumPy), so normalise their dtype before the call.
    labels = labels.astype('float32')
    class_predicts = nd.transpose(class_predicts, axes=(0, 2, 1))
    targets = MultiBoxTarget(default_anchors, labels, class_predicts)
    box_target = targets[0]  # box offset target for (x, y, width, height)
    box_mask = targets[1]  # mask to ignore box offsets we don't want to penalize, e.g. negative samples
    cls_target = targets[2]  # array of labels for all anchor boxes
    return box_target, box_mask, cls_target

Ожидаемый результат: модель обучается, а параметры сохраняются вызовом net.save_parameters('ssd_%d.params' % epochs).

Фактический результат:

 MXNetError                       Traceback (most recent call last)
 <ipython-input-80-6e8fe42e4df5> in <module>()
 16             default_anchors, class_predictions, box_predictions = net(x)
 17             print(y.shape)
 ---> 18             box_target, box_mask, cls_target = training_targets(default_anchors, class_predictions, y)
 19             # losses
 20             loss1 = cls_loss(class_predictions, cls_target)

 <ipython-input-68-866caabcf8c9> in training_targets(default_anchors, class_predicts, labels)
  2 def training_targets(default_anchors, class_predicts, labels):
  3     class_predicts = nd.transpose(class_predicts, axes=(0, 2, 1))
 ----> 4     z = MultiBoxTarget(*[default_anchors, labels, class_predicts])
  5     box_target = z[0]  # box offset target for (x, y, width, height)
  6     box_mask = z[1]  # mask is used to ignore box offsets we don't want to penalize, e.g. negative samples

 /usr/local/lib/python3.6/dist-packages/mxnet/ndarray/register.py in MultiBoxTarget(anchor, label, cls_pred, overlap_threshold, ignore_label, negative_mining_ratio, negative_mining_thresh, minimum_negative_samples, variances, out, name, **kwargs)

 /usr/local/lib/python3.6/dist-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out)
 90         c_str_array(keys),
 91         c_str_array([str(s) for s in vals]),
 ---> 92         ctypes.byref(out_stypes)))
 93 
 94     if original_output is not None:

 /usr/local/lib/python3.6/dist-packages/mxnet/base.py in check_call(ret)
250     """
251     if ret != 0:
 --> 252         raise MXNetError(py_str(_LIB.MXGetLastError()))
253 
254 

 **MXNetError: [08:50:36] include/mxnet/operator.h:228: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1 Unsupported data type 1**

 Stack trace returned 10 entries:
 [bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x23d55a) [0x7f454cf5555a]
 [bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x23dbc1) [0x7f454cf55bc1]
 [bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x2fb9dd) [0x7f454d0139dd]
 [bt] (3) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x2e0c0b5) [0x7f454fb240b5]
 [bt] (4) /usr/local/lib/python3.6/dist-     packages/mxnet/libmxnet.so(mxnet::imperative::SetShapeType(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, mxnet::DispatchMode*)+0x1274) [0x7f454f936814]
 [bt] (5) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x309) [0x7f454f9400b9]
 [bt] (6) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x2b2d8b9) [0x7f454f8458b9]
 [bt] (7) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7f454f845eaf]
 [bt] (8) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7f45761e6dae]
 [bt] (9) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x22f) [0x7f45761e671f]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...