Как использовать увеличение размерности для набора данных MNIST? - PullRequest
0 голосов
/ 30 октября 2019

Я пытаюсь использовать набор данных MNIST для Alexnet с Keras, поэтому я должен изменить измерение (поскольку MNIST имеет полутоновую шкалу, Alexnet должен быть RGB, а также 227 * 227). Теперь я получаю некоторые результаты, numpy_imgs=(10,227,227,1), но я должен сделать это как (10,227,227,3), вы можете увидеть, что я делал раньше, в моем коде, спасибо.

  import tensorflow as tf
  import numpy as np
  from tensorflow.examples.tutorials.mnist import input_data

  mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
  batch=mnist.train.next_batch(10)
  X_batch = batch[0]
  batch_tensor = tf.reshape(X_batch, [10, 28, 28, 1])
  resized_images = tf.image.resize_images(batch_tensor, [227,227])
  with tf.Session() as sess:
      numpy_imgs = resized_images.eval(session=sess) # mnist images converted to numpy array

  r2=[]
  t=list(numpy_imgs)
  dim = np.zeros((227,227))
  for i in range(0,10):
      R=np.stack((t[i],dim,dim),axis=2)
      R=list(R)
      r2.append(R)
  y3=np.asarray(r2)

Я попробовал что-то ниже, но получил ошибку вроде«ValueError: все входные массивы должны иметь одинаковую форму», как я могу это исправить?

1 Ответ

0 голосов
/ 30 октября 2019

Взгляните на tf.tile, который повторяет тензор по одному из его измерений:

y3 = tf.tile(numpy_imgs, (1, 1, 1, 3))

Если вы хотите дополнить его нулевыми тензорами, вы должны использовать tf.concat (или np.concatenate вместо stack.

dim = np.zeros((227, 227, 2))
for i in range(0, 10):
    R = np.concatenate((t[i], dim), axis=2)
    ...

Вы можете даже сделать это более кратко, обрабатывая все партии сразу:

dim = np.zeros((10, 227, 227, 2))
y3 = np.concatenate((numpy_imgs, dim), axis=3

Вот более общий пример:

import numpy as np

def main():
    i = np.random.random((10, 227, 227, 1))
    dim = np.zeros((10, 227, 227, 2))
    print(i.shape)
    print(dim.shape)
    print(np.concatenate((i, dim), axis=3).shape)

if __name__ == '__main__':
    main()
(10, 227, 227, 1)
(10, 227, 227, 2)
(10, 227, 227, 3)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...