Как использовать Keras'valu_generator (), когда "Ось должна быть указана ..." - PullRequest
0 голосов
/ 25 ноября 2018

Почему Keras говорит мне, что вес и размеры партий различны?Как я могу это исправить?(добавление next(...) здесь не помогает).Заранее спасибо;есть кое-что, что я просто не получаю здесь.

Ошибка ная :я_производителя (): TypeError: Axis must be specified when shapes of a and weights differ.

from sklearn.utils import shuffle as identical_shuffle

SAMPLES_PER_BATCH=1
BATCHES_PER_EPOCH=1
BATCHES_PER_VALIDATION=1
EPOCHS_PER_SIMULATION=2

def generate_training(batch_size=64):
    validation_length = (int(len(data.train)*0.25) // batch_size) * batch_size
    while True:
        for i in range(validation_length,len(data.train),batch_size):
            x,y,n = identical_shuffle(data.train[i:i+batch_size],data.target[i:i+batch_size],data.context[i:i+batch_size])
            yield {'input':x,'target':n}, y

def generate_validation(batch_size=64):
    validation_length = (int(len(data.train)*0.25) // batch_size) * batch_size
    while True:
        for i in range(0,validation_length,batch_size):
            x,y,n = identical_shuffle(data.train[i:i+batch_size],data.target[i:i+batch_size],data.context[i:i+batch_size])
            yield {'input':x,'target':n}, y

for epoch in range(EPOCHS_PER_SIMULATION):
  for batch in range(BATCHES_PER_EPOCH):
    result_training = model.train_on_batch( *next(generate_training(batch_size=SAMPLES_PER_BATCH)) )
    # <redacted operations on result_training>
  result_validation = model.evaluate_generator( generate_validation(batch_size=SAMPLES_PER_BATCH), steps=BATCHES_PER_VALIDATION )

Для сравнения, нижеприведенный код работает с тем же генератором.

model.fit_generator(generator=generate_training(batch_size=SAMPLES_PER_BATCH),
                    validation_data=next(generate_validation(batch_size=SAMPLES_PER_BATCH)),
                    validation_steps=1,
                    steps_per_epoch=1,
                    epochs=1)

Полная трассировка

TypeError                                 Traceback (most recent call last)
<ipython-input-37-962a23d537c3> in <module>()
     79     print(result_training)
     80   model.reset_states()
---> 81   result_validation = model.evaluate_generator( generate_validation(batch_size=SAMPLES_PER_BATCH), steps=BATCHES_PER_VALIDATION )
     82   #result_validation = model.test_on_batch( *next(generate_validation(batch_size=SAMPLES_PER_BATCH)) )
     83   print("VALIDATION OF EPOCH "+str(epoch))

/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in evaluate_generator(self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose)
   1470             workers=workers,
   1471             use_multiprocessing=use_multiprocessing,
-> 1472             verbose=verbose)
   1473 
   1474     @interfaces.legacy_generator_methods_support

/usr/local/lib/python3.6/dist-packages/keras/engine/training_generator.py in evaluate_generator(model, generator, steps, max_queue_size, workers, use_multiprocessing, verbose)
    375         if i not in stateful_metric_indices:
    376             averages.append(np.average([out[i] for out in outs_per_batch],
--> 377                                        weights=batch_sizes))
    378         else:
    379             averages.append(np.float64(outs_per_batch[-1][i]))

/usr/local/lib/python3.6/dist-packages/numpy/lib/function_base.py in average(a, axis, weights, returned)
   1140             if axis is None:
   1141                 raise TypeError(
-> 1142                     "Axis must be specified when shapes of a and weights "
   1143                     "differ.")
   1144             if wgt.ndim != 1:

TypeError: Axis must be specified when shapes of a and weights differ.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...