Сбой model.fit_generator с неподдерживаемыми типами операндов для *: 'Dimension' и 'float' - PullRequest
0 голосов
/ 29 марта 2019

Меня беспокоит странная проблема при использовании кераса.Я успешно использовал model.fit () для триана моей модели.Затем я хочу ускорить его и, таким образом, я использую model.fit_generator ().Тем не менее, я потерпел неудачу, и ошибка отображается следующим образом:

Traceback (most recent call last):
  File "E:/deepmodel/main.py", line 100, in <module>
    model.fit_generator(generator=next_batch(), steps_per_epoch=50, verbose=1, epochs=10)
  File "C:\Users\17210\AppData\Local\conda\conda\envs\GPUDL\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2177, in fit_generator
    initial_epoch=initial_epoch)
  File "C:\Users\17210\AppData\Local\conda\conda\envs\GPUDL\lib\site-packages\tensorflow\python\keras\engine\training_generator.py", line 183, in fit_generator
    callbacks.on_batch_end(batch_index, batch_logs)
  File "C:\Users\17210\AppData\Local\conda\conda\envs\GPUDL\lib\site-packages\tensorflow\python\keras\callbacks.py", line 249, in on_batch_end
    callback.on_batch_end(batch, logs)
  File "C:\Users\17210\AppData\Local\conda\conda\envs\GPUDL\lib\site-packages\tensorflow\python\keras\callbacks.py", line 375, in on_batch_end
    self.totals[k] = v * batch_size
  File "C:\Users\17210\AppData\Local\conda\conda\envs\GPUDL\lib\site-packages\tensorflow\python\framework\tensor_shape.py", line 260, in __rmul__
    return self * other
TypeError: unsupported operand type(s) for *: 'Dimension' and 'float'

Я действительно озадачен, и когда я иду на callbacks.py, на самом деле

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time, we should account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    self.seen += batch_size * num_steps

    for k, v in logs.items():
      if k in self.stateful_metrics:
        self.totals[k] = v
      else:
        if k in self.totals:
          self.totals[k] += v * batch_size
        else:
          self.totals[k] = v * batch_size

Проблема заключается в том, что batch_sizeРазмер (2048) (как в генераторе я использую batch_size 2048).Я думаю, что это приводит к TypeError, однако я не знаю, что его вызывает и как его избежать.

Кроме того, я также отлаживаю с использованием model.fit (), а затем batch_size в функцииon_batch_end равно 1, поэтому ничего неожиданного не произошло.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...