Учебные машины для рисования - PullRequest
0 голосов
/ 21 сентября 2019

хорошо .. так что дело в том, что я пытался решить проблему глубокого обучения, она включает в себя использование базовых штрихов эскизов в геометрическом формате, цель состоит в том, чтобы научить программу самостоятельно создавать линейные эскизы различных элементов элементовв этом диапазоне.поэтому команда специалистов по мозгу Google провела тренинг по чему-то похожему, под названием Google Magenta (научить машины рисовать и делать музыку).

Поэтому в основном я использовал их контрольно-пропускные пункты, просто восстановил их на моей виртуальной машине иrun.

В файлах контрольных точек работали четыре (4) основных разных модели (одинаковый код для запуска четырех), две модели работали успешно, две другие вернули ошибки ..

Первые две модели были опробованы на aaron_sheep и flamingo, модель смогла правильно нарисовать / нарисовать изображения после восстановления файлов контрольных точек, но остальные два изображения, которые были частью обучающих моделей, восстановлены, но не смогли нарисовать изображения

после восстановления файлов контрольных точек с помощью

   tf.restore_checkpoint(pretrained_model_dir)

затем попытался создать реконструкции овец

   z_list = [] #interpolate spherically between z0 and z1
   N = 10
   for t in np.linspace(0, 1, N):
       z_list.append(slerp(z0, z1, t))


   reconstructions = []
   for i in range(N):
        reconstructions.append([decode(z_list[i], draw_mode=False), [0, i]])

   stroke_grid = make_grid_svg(reconstructions)
   draw_line_strokes(stroke_grid)

, поэтому я попытался загрузить файл контрольных точек, затем набросал catbus

   load_checkpoint(sess, model_dir)
   z_0 = np.random.randn(eval_model.hps.z_size)
   _ = decode(z_0)

Я ожидал, что вывод привлечет catbus, но это не удалось

    ValueError           Traceback (most recent call last)
    <ipython-input-73-d2aff98d71b8> in <module>()
    1 z_0 = np.random.randn(eval_model.hps.z_size)
    ----> 2 _ = decode(z_0)

    <ipython-input-13-1c996f5d54f7> in decode(z_input, draw_mode, 
    temperature, factor)
    3     if z_input is not None:
    4         z = [z_input]
    ----> 5     sample_strokes, m = sample(sess, sample_model, 
    seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
    6     strokes = to_normal_strokes(sample_strokes)
    7     if draw_mode:

    ~\Anaconda3\lib\site-packages\magenta\models\sketch_rnn\model.py in 
    sample(sess, model, seq_len, temperature, greedy_mode, z)
    429         model.pi, model.mu1, model.mu2, model.sigma1, 
    model.sigma2, model.corr,
    430         model.pen, model.final_state
    --> 431     ], feed)
    432 
    433     [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, 
    next_state] = params

    ~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in 
    run(self, fetches, feed_dict, options, run_metadata)
    927     try:
    928       result = self._run(None, fetches, feed_dict, options_ptr,
    --> 929                          run_metadata_ptr)
    930       if run_metadata:
    931         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

    ~\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in 
    _run(self, handle, fetches, feed_dict, options, run_metadata)
    1126                              'which has shape %r' %
    1127                              (np_val.shape, subfeed_t.name,
    -> 1128                               str(subfeed_t.get_shape())))
    1129           if not self.graph.is_feedable(subfeed_t):
    1130             raise ValueError('Tensor %s may not be fed.' % 
    subfeed_t)

    ValueError: Cannot feed value of shape (1, 1, 5) for Tensor 
    'vector_rnn_2/strided_slice_1:0', which has shape '(1, 150, 5)'
...