Запретить Tensosflow Dataset сбрасывать генератор при нескольких вызовах model.predict - PullRequest
0 голосов
/ 18 июня 2020

Я использую метод tenorflow Dataset from_generator, чтобы делать прогнозы с использованием модели CNN для разных пакетов. Но я хочу добавить несколько дополнительных logi c после каждого пакетного прогноза. В частности, я хочу объединить разные результаты.

Вот моя функция генератора:

def gen_predict(img_no):
  img_data = nib.load('./testing-images/10' + '%02d' %img_no + '_3.nii.gz').get_fdata()
  patch_size = 23
  dist_center = (patch_size - 1) // 2
  l, b, h = img_data.shape
  for zc in range(dist_center, h - dist_center - 1):
    for yc in range(dist_center, b - dist_center - 1):
      for xc in range(dist_center, l - dist_center - 1):    
        print(xc,yc,zc) 
        xl, yl, zl = (xc - dist_center, yc - dist_center, zc - dist_center)
        xr, yr, zr = (xc + dist_center, yc + dist_center, zc + dist_center)
        cartesianCoordinate = np.array([xc, yc, zc])
        spectralCoordinates = np.array([0, 0, 0])
        X = (np.array(img_data[xl:(xr + 1), yl:(yr + 1), zl:(zr + 1)]), np.concatenate((cartesianCoordinate, spectralCoordinates)).reshape((6,1)))
        yield (X,)

Проблема в том, что после каждого вызова прогнозирования генератор сбрасывается, и при следующем вызове прогнозирования он дает прогнозы для того же набора данных. Вот мой код:

dataset_pred = tf.data.Dataset.from_generator(lambda: gen_predict(3), ((tf.float32, tf.float32),), output_shapes=((tf.TensorShape([23,23,23]), tf.TensorShape([6,1])),))
dataset_pred = dataset_pred.batch(BS)
for i in range(num_batches):
  temp_pred = np.array(model.predict(dataset_pred, batch_size=BS, steps=1))
  ## aggregate the temp_pred result ##

Я хочу имитировать c поведение model.predict(dataset_pred, batch_size=BS, steps=num_batches) с помощью этого дополнительного logi c. Кроме того, я не могу сохранить результат этого вызова из-за большого num_batches.

EDIT: Я добавил ответ. Но буду благодарен за любую помощь в повышении эффективности.

1 Ответ

0 голосов
/ 18 июня 2020

Я нашел ответ. По сути, мы можем сохранить соответствующий генератор в переменной, а затем использовать лямбда, чтобы сделать его вызываемым. Это не сбрасывает генератор.

cur_gen = gen_predict(img_no)
dataset_pred = tf.data.Dataset.from_generator(lambda: cur_gen, ((tf.float32, tf.float32),), output_shapes=((tf.TensorShape([23,23,23]), tf.TensorShape([6,1])),))
dataset_pred = dataset_pred.batch(BS)
for i in range(num_batches):
  temp_pred = np.array(model.predict(dataset_pred, batch_size=BS, steps=1))
  ## aggregate the temp_pred result ##
...