Как определить ошибку пользовательской целевой функции TFLearn с помощью логической маски? - PullRequest
0 голосов
/ 09 ноября 2018

Я пытаюсь определить пользовательскую функцию потерь

def my_objective(y_pred, y_true):
  pred_slice = tf.slice(y_pred, [0,0], [-1,1])
  true_slice = tf.slice(y_true, [0,0], [-1,1])
  mask = np.array(d['mask_'+str(0)], dtype=bool)
  masked_pred = tf.boolean_mask(pred_slice, mask)
  masked_true = tf.boolean_mask(true_slice, mask)
  return tf.reduce_mean(tf.square(masked_pred - masked_true))

...
net = tflearn.regression(net, optimizer='adam', loss=my_objective)
model = tflearn.DNN(net)
model.fit(Xtrain, ytrain)

, но я получаю следующую ошибку при попытке тренироваться:

--------------------------------- Идентификатор запуска: H57ZXZ Каталог журнала: / tmp / tflearn_logs /

--------------------------------- Учебные образцы: 13136 Валидационные образцы: 0

--------------------------------------------------------------------------- InvalidArgumentError Traceback (последний вызов был последним) /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py в _do_call (self, fn, * args) 1333 try: -> 1334 return fn (* args) 1335 за исключением ошибок. Ошибка как e:

/ opt / conda / lib / python3.6 / site-packages / tenorflowflow / python/client/session.py в _run_fn (feed_dict, fetch_list, target_list, options, run_metadata) 1318 возвращает self._call_tf_sessionrun (-> 1319 параметров, feed_dict, fetch_list, target_list, run_metadata) 1320

optlib / python3.6 / site-packages / tenorflow / python / client / session.py в _call_tf_sessionrun (self, options, feed_dict, fetch_list, target_list, run_metadata) 1406 self._session, options, feed_dict, fetch_list, target_list, -> 1407 run_metadata) 1408

InvalidArgumentError: индексы отсутствуют в [0] = 66[0, 64) [[{{node boolean_mask_1 / GatherV2}} = GatherV2 [Taxis = DT_INT32, Tindices = DT_INT64, Tparams = DT_FLOAT, _device = "/ job: localhost / replica: 0 / task: 0 / device: CPU:0 "] (boolean_mask_1 / Reshape, boolean_mask / Squeeze, boolean_mask / concat / axis)]]

Во время обработки вышеуказанного исключения произошло другое исключение:

InvalidArgumentError Traceback (последний вызов был последним) в 1 модели = tflearn.DNN (нетто) ----> 2 model.fit (Xtrain, ytrain)

/ opt / conda / lib / python3.6 / site-packages / tflearn / models /dnn.py в соответствии (self, X_inputs, Y_targets, n_epoch, validation_set, show_metric, batch_size, shuffle, snapshot_epoch, snapshot_step, excl_trainops, validation_batch_size, run_id, обратные вызовы) 214 213 excl_trainops = excl_trainops = excl_trainops =acks = обратные вызовы) 217 ​​218 def fit_batch (self, X_inputs, Y_targets):

/ opt / conda / lib / python3.6 / site-packages / tflearn / helpers / trainer.py в fit (self, feed_dicts), n_epoch, val_feed_dicts, show_metric, snapshot_step, snapshot_epoch, shuffle_all, dprep_dict, daug_dict, excl_trainops, run_id, обратные вызовы) 337 (bool (self.best_checkpoint_path) |snapshot_epoch), 338 snapshot_step, -> 339 show_metric) 340 341 # Обновить состояние обучения

/ opt / conda / lib / python3.6 / site-packages / tflearn / helpers / trainer.py в _train (self, training_step, snapshot_epoch, snapshot_step, show_metric) 816 tflearn.is_training (True, сессия = self.session) 817 _, train_summ_str = self.session.run ([self.train, self.summ_op], -> 818 feed_batch) 819820 # Извлечь значение потерь из итоговой строки

/ opt / conda / lib / python3.6 / site-packages / tenorflow / python / client / session.py в ходе выполнения (self, fetches, feed_dict, options, run_metadata)) 927 try: 928 result = self._run (Нет, выборки, feed_dict, options_ptr, -> 929 run_metadata_ptr) 930, если run_metadata: 931 proto_data = tf_session.TF_GetBuffer (run_metadata_ptr) * 1029 /

/python3.6/site-packages/tensorflow/python/client/session.pyв _run (self, handle, fetches, feed_dict, options, run_metadata)
1150 если final_fetches или final_targets или (обрабатывать и feed_dict_tensor): 1151 результатов = self._do_run (дескриптор, final_targets, final_fetches, -> 1152 feed_dict_tensor, options, run_metadata) 1153 else: 1154 результаты = []

/ Opt / Конда / Библиотека / python3.6 / сайт-пакеты / tensorflow / питон / клиент / session.py в _do_run (self, handle, target_list, fetch_list, feed_dict, параметры, run_metadata) 1326, если дескриптор None: 1327 return self._do_call (_run_fn, каналы, выборки, цели, параметры, -> 1328 run_metadata) 1329 остальное: 1330 вернуть self._do_call (_prun_fn, дескриптор, фиды, выборки)

/ Opt / Конда / Библиотека / python3.6 / сайт-пакеты / tensorflow / питон / клиент / session.py в _do_call (self, fn, * args) 1346 проход 1347
message = error_interpolation.interpolate (сообщение, self._graph) -> 1348 тип повышения (e) (node_def, op, message) 1349 1350 def _extend_graph (self):

InvalidArgumentError: indices [0] = 66 отсутствует в [0, 64) [[узел boolean_mask_1 / GatherV2 (определено в: 6) = GatherV2 [Taxis = DT_INT32, Tindices = DT_INT64, Tparams = DT_FLOAT, _device = "/ job: localhost / replica: 0 / task: 0 / device: CPU: 0"] (boolean_mask_1 / Reshape, boolean_mask / Squeeze, boolean_mask / concat / axis)]]

Вызывается операцией 'boolean_mask_1 / GatherV2', определенной в: File "/opt/conda/lib/python3.6/runpy.py", строка 193, в _run_module_as_main Файл " main ", mod_spec) "/opt/conda/lib/python3.6/runpy.py", строка 85, в _run_code Файл exec (code, run_globals) "/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py", строка 16, в Файл app.launch_new_instance () "/opt/conda/lib/python3.6/site-packages/traitlets/config/application.py", строка 658, в launch_instance Файл app.start () "/opt/conda/lib/python3.6/site-packages/ipykernel/kernelapp.py", строка 505, в начале Файл self.io_loop.start () "/opt/conda/lib/python3.6/site-packages/tornado/platform/asyncio.py", строка 132, в начале Файл self.asyncio_loop.run_forever () "/opt/conda/lib/python3.6/asyncio/base_events.py", строка 422, в run_forever Файл self._run_once () "/opt/conda/lib/python3.6/asyncio/base_events.py", строка 1434, в _run_once handle._run () Файл "/opt/conda/lib/python3.6/asyncio/events.py", строка 145, в _run self._callback (* self._args) Файл "/opt/conda/lib/python3.6/site-packages/tornado/ioloop.py", строка 758, в _run_callback ret = callback () Файл "/opt/conda/lib/python3.6/site-packages/tornado/stack_context.py", строка 300, в null_wrapper return fn (* args, ** kwargs) Файл "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", строка 1233, в внутренний Файл self.run () "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", строка 1147, в бежать yielded = self.gen.send (значение) Файл "/opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py", строка 357, в process_one yield gen.maybe_future (dispatch (* args)) Файл "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", строка 326, в обертка yielded = следующий (результат) файл "/opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py", строка 267, в отправке yield gen.maybe_future (обработчик (поток, идентификаторы, сообщения)) Файл «/opt/conda/lib/python3.6/site-packages/tornado/gen.py», строка 326 обертка yielded = следующий (результат) файл "/opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py", строка 534, в execute_request user_expressions, allow_stdin, файл "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", строка 326, в обертка yielded = следующий (результат) файл "/opt/conda/lib/python3.6/site-packages/ipykernel/ipkernel.py", строка 294, в do_execute res = shell.run_cell (код, store_history = store_history, тихий = без звука) Файл"/opt/conda/lib/python3.6/site-packages/ipykernel/zmqshell.py", строка 536, в run_cell вернуть супер (ZMQInteractiveShell, self) .run_cell (* args, ** kwargs) файл "/Opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", строка 2817, в run_cell raw_cell, store_history, silent, shell_futures) Файл "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", строка 2843, в _run_cell вернуть бегуна (coro) Файл "/opt/conda/lib/python3.6/site-packages/IPython/core/async_helpers.py", строка 67, в _pseudo_sync_runner coro.send (Нет) Файл "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", строка 3018, в run_cell_async интерактивность = интерактивность, компилятор = компилятор, результат = результат) Файл "/Opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", строка 3183, в run_ast_nodes if (yield from self.run_code (code, result)): файл "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", строка 3265, в run_code exec (code_obj, self.user_global_ns, self.user_ns) Файл "", строка 7, в net = tflearn.regression (net, оптимизатор = 'adam', loss = my_objective) Файл "/Opt/conda/lib/python3.6/site-packages/tflearn/layers/estimator.py", строка 178, в регрессии потеря = потеря (входящий, заполнитель) Файл "", строка 6, в my_objective файл masked_true = tf.boolean_mask (true_slice, mask) "/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", строка 1204, в логической маске вернуть файл _apply_mask_1d (тензор, маска, ось) "/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", строка 1174, в _apply_mask_1d возвратный сбор (reshaped_tensor, indexes, axis = axis) файл "/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", линия 2675, в сборе вернуть файл gen_array_ops.gather_v2 (параметры, индексы, ось, имя = имя) Файл "/Opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", строка 3332, в collect_v2 "GatherV2", params = params, индексы = индексы, ось = ось, имя = имя) Файл "/Opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", строка 787, в _apply_op_helper op_def = op_def) Файл "/opt/conda/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", строка 488, в new_func вернуть func (* args, ** kwargs) Файл "/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", строка 3274, в create_op op_def = op_def) Файл "/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", строка 1770, в init self._traceback = tf_stack.extract_stack ()

InvalidArgumentError (см. Выше для отслеживания): индексы [0] = 66 не в [0, 64) [[узел boolean_mask_1 / GatherV2 (определен в : 6) = GatherV2 [Такси = DT_INT32, Tindices = DT_INT64, Tparams = DT_FLOAT, _device = "/ работа: локальный / реплика: 0 / задача: 0 / Устройство: ЦП: 0"] (boolean_mask_1 / Reshape, boolean_mask / Squeeze, boolean_mask / concat / axis)]]

...