Нужна помощь в преобразовании DCGAN в Java для Tensorflow для Java - PullRequest
0 голосов
/ 14 января 2020

Я пытаюсь заставить DCGAN (Глубокие сверточные генеративные состязательные сети) работать с тензорным потоком для Java.

Я добавил необходимый код в model.py от DCGAN, как показано ниже, для вывода графика для последующего использования в tenorflow для Java.

//at the beginning to define where the model will be saved
    #
    self.load_dir = load_dir
    self.models_dir = models_dir

    graph = tf.Graph()
    self.graph = graph

    self.graph.as_default()
    #
//near the end where the session is ran in order to build and save the model to be used in tensorflow for java. A model is saved every 200 samples as defined by DCGAN’s default settings.
    #
    steps = "training_steps-" + "{:08d}".format(step)
    set_models_dir = os.path.join(self.models_dir, steps)
    builder = tf.saved_model.builder.SavedModelBuilder(set_models_dir)
    self.builder = builder
    self.builder.add_meta_graph_and_variables(self.sess, [tf.saved_model.tag_constants.SERVING])
    self.builder.save()
    #

Приведенные выше коды выводят график, который загружается следующий Java код

package Main;

import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Random;

import javax.imageio.ImageIO;

import org.tensorflow.Tensor;

public class DCGAN {
    public static void main(String[] args) throws Exception {
        String model_dir = "E:\\AgentWeb\\mnist-steps\\training_steps-00050000";
        //SavedModelBundle model = SavedModelBundle.load(model_dir , "serve");
        //Session sess = model.session();

        Random rand = new Random();
        int sample_num = 64;
        int z_dim = 100;
        float [][] gen_random = new float [64][100];
        for(int i = 0 ; i < sample_num ; i++) {
            for(int j = 0 ; j < z_dim ; j++) {
                gen_random[i][j] = (float)rand.nextGaussian();
            }
        }
        Tensor <Float> sample_z = Tensor.<Float>create(gen_random, Float.class);


        Tensor <Float> sample_inputs = Tensor.<Float>create(placeholder, Float.class);
// placeholder is the tensor which I want to create after solving the problem below.

        //Tensor result = sess.runner().fetch("t_vars").feed("z", sample_z).feed("inputs", sample_inputs).run().get(3);
    }
}

(я оставил несколько комментариев, поскольку использовал их для отладки)

При использовании этого метода я застрял в определенной части перевода кода python до Java для использования в тензорном потоке для Java. В DCGAN model.py , где обрабатываются изображения, есть следующий код:

          get_image(sample_file,
                    input_height=self.input_height,
                    input_width=self.input_width,
                    resize_height=self.output_height,
                    resize_width=self.output_width,
                    crop=self.crop,
                    grayscale=self.grayscale) for sample_file in sample_files]

, который вызывает get_iamge в сохраненный_utils.py следующим образом

def get_image(image_path, input_height, input_width,
              resize_height=64, resize_width=64,
              crop=True, grayscale=False):
  image = imread(image_path, grayscale)
  return transform(image, input_height, input_width,
                   resize_height, resize_width, crop)

, который затем вызывает метод с именем imread следующим образом

def imread(path, grayscale = False):
  if (grayscale):
    return scipy.misc.imread(path, flatten = True).astype(np.float)
  else:
    # Reference: https://github.com/carpedm20/DCGAN-tensorflow/issues/162#issuecomment-315519747
    img_bgr = cv2.imread(path)
    # Reference: https://stackoverflow.com/a/15074748/
    img_rgb = img_bgr[..., ::-1]
    return img_rgb.astype(np.float)

Мой вопрос заключается в том, что я не уверен, что делает часть img_rgb = img_bgr[..., ::-1] и как мне перевести ее для использования в моем файле Java в tenorflow. java.

Я знаком с тем, как массивы срезов python, но я не знаком с тремя точками, используемыми там. Я читал там про ссылку на вопросы stackoverflow и там упоминается, что она похожа на img[:, :, ::-1]. Но я не совсем уверен, что именно он делает.

Любая помощь приветствуется, и спасибо, что нашли время прочитать этот длинный пост.

1 Ответ

0 голосов
/ 16 января 2020

Что в основном означает imread и get_image 1) читает изображение 2) конвертирует его из BGR в RGB 3) конвертирует в плавающие 4) масштабирует изображение

Вы можете сделать это в Java либо с помощью библиотеки изображений, такой как JMagick или AWT, либо с помощью TensorFlow.

Если вы используете TensorFlow, можно выполнить эту предварительную обработку в активном режиме или путем построения и запуска небольшого графика , Например, учитывая tf экземпляр org.tensorflow.op.Ops:

  • tf.image.decode* может считывать содержимое изображения (вы знаете, что знаете тип изображения, хотя выбираете правильную операцию).
  • tf.reverse может обратить значение в измерении вашего канала (от RGB до BGR)
  • tf.dtypes.cast может преобразовать изображение в плавающие
  • tf.image.resizeBilinear может изменить масштаб изображения
...