train.py
import tensorflow as tf
import os, time
import random
import numpy as np
from data_init import get_data
from model import UNET as G
import config
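# config.py is not shown in this listing; judging from the attributes referenced in
# these two files it is assumed to expose at least: epoch, batch_size, hw (patch
# height/width), gpus (list of GPU ids) and batch_per_gpu.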
# from tensorflow.python.profiler import model_analyzer
# from tensorflow.python.profiler import option_builder
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'
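# expose three physical GPUs; inside TensorFlow they are addressed as /gpu:0../gpu:2,
# so the number of visible devices should stay consistent with len(config.gpus)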
# param
epochs = config.epoch
def train():
data_train_batch, label_train_batch, data_test_batch, label_test_batch = get_data()
data_len = float(len(data_train_batch))
g = G()
open('logdir/readme', 'w').write(
g.readme + '\n\n{}'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
# clear old checkpoints/event files, then (re)launch TensorBoard on the log directory
os.system('pkill tensorboard')
os.system('rm /home/zhuqingjie/prj/tunet/logdir/checkpoint')
os.system('rm /home/zhuqingjie/prj/tunet/logdir/event*')
os.system('rm /home/zhuqingjie/prj/tunet/logdir/model_*')
os.system('rm /home/zhuqingjie/prj/tunet/logdir/v/event*')
time.sleep(1)
os.system('nohup /home/zhuqingjie/env/py3_tf_low/bin/tensorboard --logdir=/home/zhuqingjie/prj/tunet/logdir &')
# back up a copy of model.py alongside the logs
os.system("cp /home/zhuqingjie/prj/tunet/model.py /home/zhuqingjie/prj/tunet/logdir/")
# train
# enter the model graph so that the Saver and the init op below are created inside it
# (tf.Session(graph=...) alone does not change the default graph)
with g.graph.as_default(), tf.Session(graph=g.graph) as sess:
saver = tf.train.Saver(max_to_keep=1)
summary_writer = tf.summary.FileWriter(logdir='logdir', graph=g.graph)
summary_writer_v = tf.summary.FileWriter(logdir='logdir/v', graph=g.graph)
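# two writers so training ('logdir') and validation ('logdir/v') summaries show up
# as separate runs in TensorBoard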
sess.run(tf.global_variables_initializer())
sess.run([g.train_iterator.initializer, g.test_iterator.initializer],
feed_dict={g.train_features_placeholder: data_train_batch,
g.train_labels_placeholder: label_train_batch,
g.test_features_placeholder: data_test_batch,
g.test_labels_placeholder: label_test_batch})
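# the full numpy arrays are fed once here; both datasets use .repeat(), so the
# iterators can be pulled from indefinitely during training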
time_use = []
while True:
time_start = time.time()
_, _, summary, loss, abs_error, gs = sess.run(
[g.train_op_D, g.train_op_G, g.mergeall, g.loss_mse, g.abs_error, g.global_step],
feed_dict={g.train_flag: True})
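# a single call runs the D and G updates together and fetches the merged summary,
# the current loss values and the global step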
# _ = sess.run([g.train_op_G2], feed_dict={g.train_flag: True})
# _ = sess.run([g.train_op_G2], feed_dict={g.train_flag: True})
# _ = sess.run([g.train_op_G2], feed_dict={g.train_flag: True})
# _ = sess.run([g.train_op_G2], feed_dict={g.train_flag: True})
time_end = time.time()
time_use.append(time_end - time_start)
summary_writer.add_summary(summary, gs)
# val
if gs % 10 == 0:
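# train_flag=False makes get_xy() read from the test iterator (see model.py)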
_, _, summary, gs = sess.run([g.loss_mse, g.abs_error, g.mergeall, g.global_step],
feed_dict={g.train_flag: False})
summary_writer_v.add_summary(summary, gs)
print('---------------avg_time_use:{}'.format(np.mean(np.array(time_use))))
print('gs / data_len, gs, loss, abs_error')
time_use = []
# save
if gs % 100 == 0:
# overwrite the previous checkpoint (the Saver keeps only the latest one)
saver.save(sess, 'logdir/model_{}'.format(gs))
# every 1000 steps, back up the current model to a separate directory
if gs % 1000 == 0:
iter_dir = '/home/zhuqingjie/prj/tunet/backup/iter/{}/'.format(gs)
os.system('mkdir {}'.format(iter_dir))
os.system('cp /home/zhuqingjie/prj/tunet/logdir/model* {}'.format(iter_dir))
os.system('cp /home/zhuqingjie/prj/tunet/logdir/last_random_ind.txt {}'.format(iter_dir))
print('{:.2f} -- {} -- {:.4f} -- {:.4f}'.format(gs / data_len, gs, loss, abs_error))
if gs / data_len > epochs: break
if __name__ == '__main__':
print('\n' * 5)
print('=' * 150)
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
print('=' * 150)
train()
print('ok')
model.py
import tensorflow as tf
import tensorflow.layers as L
from vgg16 import Vgg16 as VGG
import os
import config
# param
batch_size = config.batch_size
hw = config.hw
class UNET():
def __init__(sf, predict_flag=False):
print('loading UNET..')
sf.readme = '''
Data: {}*{}
G: unet
D: cnn
Opt: adam, GAN loss + perc loss + MSE loss
Train: train_op_G2 is run 2 extra times per step
'''.format(config.hw, config.hw)
sf.graph = tf.Graph()
sf.istraining = not predict_flag
# dataset_train, dataset_test = load_data_2()
with sf.graph.as_default():
with tf.device('/cpu:0'):
# init dataset
sf.train_flag = tf.placeholder(dtype=tf.bool)
# train
sf.train_features_placeholder = tf.placeholder(tf.float32, [None, batch_size, hw, hw, 1])
sf.train_labels_placeholder = tf.placeholder(tf.float32, [None, batch_size, hw, hw, 1])
train_dataset = tf.data.Dataset.from_tensor_slices(
(sf.train_features_placeholder, sf.train_labels_placeholder)).repeat()
sf.train_iterator = train_dataset.make_initializable_iterator()
# test
sf.test_features_placeholder = tf.placeholder(tf.float32, [None, batch_size, hw, hw, 1])
sf.test_labels_placeholder = tf.placeholder(tf.float32, [None, batch_size, hw, hw, 1])
test_dataset = tf.data.Dataset.from_tensor_slices(
(sf.test_features_placeholder, sf.test_labels_placeholder)).repeat()
sf.test_iterator = test_dataset.make_initializable_iterator()
if not predict_flag:
sf.x, sf.y = sf.get_xy(sf.train_flag)
else:
sf.x = tf.placeholder(tf.float32, [batch_size, hw, hw, 1])
sf.y = tf.placeholder(tf.float32, [batch_size, hw, hw, 1])
# Multi GPU
sf.opt = tf.train.AdamOptimizer(0.0001)
sf.global_step = tf.Variable(0, trainable=False)
tower_grads_G = []
tower_grads_D = []
for gpu_i in range(len(config.gpus)):
with tf.device('/gpu:{}'.format(gpu_i)):
with tf.name_scope('tower_{}'.format(gpu_i)):
# split batch
batch_per_gpu = config.batch_per_gpu
x_ = sf.x[gpu_i * batch_per_gpu:(gpu_i + 1) * batch_per_gpu]
y_ = sf.y[gpu_i * batch_per_gpu:(gpu_i + 1) * batch_per_gpu]
prd = sf.gen(x_)
d_prd = sf.disc(prd)
ls_d_prd = tf.reduce_mean(d_prd)
loss_d_prd = tf.reduce_mean(tf.log(tf.clip_by_value(d_prd, 1e-10, 1.0)))
d_y = sf.disc(y_)
ls_d_y = tf.reduce_mean(d_y)
loss_d_y = tf.reduce_mean(tf.log(tf.clip_by_value(d_y, 1e-10, 1.0)))
abs_error = tf.reduce_mean(tf.abs(y_ - prd)) * 255
# MSE Loss
yr = tf.reshape(y_, [y_.shape[0], -1])
prdr = tf.reshape(prd, [prd.shape[0], -1])
loss_mse = tf.losses.mean_squared_error(yr, prdr)
# Perceptual Loss
vgg = VGG()
y_perc = vgg.forward(y_)
prd_perc = vgg.forward(prd)
yr_perc = tf.reshape(y_perc, [y_perc.shape[0], -1])
prdr_perc = tf.reshape(prd_perc, [prd_perc.shape[0], -1])
loss_perc = tf.losses.mean_squared_error(yr_perc, prdr_perc)
# Adversarial Loss
loss_G = loss_mse + loss_perc - loss_d_prd
loss_D = loss_d_prd - loss_d_y
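# loss_G maximizes log D(G(x)) (non-saturating generator loss) on top of the MSE and
# perceptual terms; loss_D pushes D(G(x)) towards 0 and D(y) towards 1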
# gradients: compute per-tower gradients for G and D separately
var_list_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
var_list_D = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
grads_G = sf.opt.compute_gradients(loss_G, var_list=var_list_G)
grads_D = sf.opt.compute_gradients(loss_D, var_list=var_list_D)
tower_grads_G.append(grads_G)
tower_grads_D.append(grads_D)
# keep the first tower's tensors for the summaries below and for outside access
if gpu_i == 0:
sf.loss_mse = loss_mse
sf.loss_perc = loss_perc
sf.abs_error = abs_error
sf.loss_G = loss_G
sf.loss_D = loss_D
sf.ls_d_prd = ls_d_prd
sf.ls_d_y = ls_d_y
sf.loss_d_prd = loss_d_prd
sf.loss_d_y = loss_d_y
sf.x_ = x_
sf.y_ = y_
sf.prd = prd
# summary
tf.summary.scalar('loss1/loss', sf.loss_mse)
tf.summary.scalar('loss1/loss_perc', sf.loss_perc)
tf.summary.scalar('loss1/abs_error', sf.abs_error)
tf.summary.scalar('DG_loss/G', sf.loss_G)
tf.summary.scalar('DG_loss/D', sf.loss_D)
tf.summary.scalar('Dloss/ls_d_prd', sf.ls_d_prd)
tf.summary.scalar('Dloss/ls_d_y', sf.ls_d_y)
tf.summary.scalar('Dloss/loss_d_prd', sf.loss_d_prd)
tf.summary.scalar('Dloss/loss_d_y', sf.loss_d_y)
tf.summary.image('img/x', sf.x_[0:1], max_outputs=1)
tf.summary.image('img/y', sf.y_[0:1], max_outputs=1)
tf.summary.image('img/prd', sf.prd[0:1], max_outputs=1)
sf.mergeall = tf.summary.merge_all()
avg_grads_G = sf.average_gradients(tower_grads_G)
avg_grads_D = sf.average_gradients(tower_grads_D)
sf.train_op_G = sf.opt.apply_gradients(avg_grads_G, global_step=sf.global_step)
sf.train_op_D = sf.opt.apply_gradients(avg_grads_D)
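# only train_op_G advances global_step, so gs in train.py counts generator updates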
# generator
def gen(sf, x):
with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
c1 = sf.conv(x, 64, 'c1')
c2 = sf.conv(c1, 64, 'c2')
p1 = L.max_pooling2d(c2, 2, 2, name='p1')
c3 = sf.conv(p1, 128, 'c3')
c4 = sf.conv(c3, 128, 'c4')
p2 = L.max_pooling2d(c4, 2, 2, name='p2')
c5 = sf.conv(p2, 256, 'c5')
c6 = sf.conv(c5, 256, 'c6')
p3 = L.max_pooling2d(c6, 2, 2, name='p3')
c7 = sf.conv(p3, 512, 'c7')
c8 = sf.conv(c7, 512, 'c8')
d1 = L.dropout(c8, 0.5, training=sf.istraining, name='d1')
p4 = L.max_pooling2d(d1, 2, 2, name='p4')
c9 = sf.conv(p4, 1024, 'c9')
c10 = sf.conv(c9, 1024, 'c10')
d2 = L.dropout(c10, 0.5, training=sf.istraining, name='d2')
u1 = tf.keras.layers.UpSampling2D()(d2)
# u1 = L.conv2d_transpose(d2, 1024, 3, 2, padding='same')
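# nearest-neighbour upsampling followed by a 2x2 conv is used instead of the
# transposed convolution left commented out above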
uc1 = sf.conv(u1, 512, 'uc1', ker_size=2)
mg1 = tf.concat([d1, uc1], axis=3, name='mg1')
c11 = sf.conv(mg1, 512, 'c11')
c12 = sf.conv(c11, 512, 'c12')
u2 = tf.keras.layers.UpSampling2D()(c12)
# u2 = L.conv2d_transpose(c12, 512, 3, 2, padding='same')
uc2 = sf.conv(u2, 256, 'uc2', ker_size=2)
mg2 = tf.concat([c6, uc2], axis=3, name='mg2')
c13 = sf.conv(mg2, 256, 'c13')
c14 = sf.conv(c13, 256, 'c14')
u3 = tf.keras.layers.UpSampling2D()(c14)
# u3 = L.conv2d_transpose(c14, 256, 3, 2, padding='same')
uc3 = sf.conv(u3, 128, 'uc3', ker_size=2)
mg3 = tf.concat([c4, uc3], axis=3, name='mg3')
c15 = sf.conv(mg3, 128, 'c15')
c16 = sf.conv(c15, 128, 'c16')
u4 = tf.keras.layers.UpSampling2D()(c16)
# u4 = L.conv2d_transpose(c16, 128, 3, 2, padding='same')
uc4 = sf.conv(u4, 64, 'uc4', ker_size=2)
mg4 = tf.concat([c2, uc4], axis=3, name='mg4')
c17 = sf.conv(mg4, 64, 'c17')
c18 = sf.conv(c17, 64, 'c18')
c19 = sf.conv(c18, 2, 'c19')
# tf.summary.histogram('c19', c19)
prd = sf.conv(c19, 1, 'prd', ker_size=1, act=tf.nn.sigmoid)
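# 1x1 sigmoid conv maps the 2-channel feature map to a single-channel image in [0, 1]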
return prd
# discriminator
def disc(sf, prd):
with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
c1 = sf.conv(prd, 16, 'c1')
p1 = L.max_pooling2d(c1, 2, 2, name='p1')
c2 = sf.conv(p1, 16, 'c2')
p2 = L.max_pooling2d(c2, 2, 2, name='p2')
c3 = sf.conv(p2, 32, 'c3')
p3 = L.max_pooling2d(c3, 2, 2, name='p3')
c4 = sf.conv(p3, 32, 'c4')
p4 = L.max_pooling2d(c4, 2, 2, name='p4')
c5 = sf.conv(p4, 64, 'c5')
p5 = L.max_pooling2d(c5, 2, 2, name='p5')
c6 = sf.conv(p5, 64, 'c6')
p6 = L.max_pooling2d(c6, 2, 2, name='p6')
c7 = sf.conv(p6, 128, 'c7')
p7 = L.max_pooling2d(c7, 2, 2, name='p7')
c8 = sf.conv(p7, 128, 'c8', pad='valid')
line1 = tf.reshape(c8, [c8.shape[0], -1])
fc1 = L.dense(line1, 128, activation=tf.nn.leaky_relu)
d1 = L.dropout(fc1, 0.5, training=sf.istraining)
fc2 = L.dense(d1, 1, activation=tf.nn.sigmoid)
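# single sigmoid unit: D outputs a per-image 'real' probability in (0, 1)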
return fc2
def conv(sf, x, filters, name, ker_size=3,
act=tf.nn.leaky_relu,
pad='same',
init=tf.contrib.layers.xavier_initializer()):
return L.conv2d(x, filters, ker_size,
activation=act,
padding=pad,
kernel_initializer=init,
name=name)
def get_xy(sf, train_flag):
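# returns (x, y) from the train iterator when train_flag is True, otherwise from the
# test iterator; tf.cond selects which get_next() actually runs at execution time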
def fn1():
xy_dict = sf.train_iterator.get_next()
return xy_dict[0], xy_dict[1]
def fn2():
xy_dict = sf.test_iterator.get_next()
return xy_dict[0], xy_dict[1]
return tf.cond(train_flag, fn1, fn2)
def average_gradients(sf, tower_grads):
average_grads = []
for grad_and_vars in zip(*tower_grads):
grads = []
for g, _ in grad_and_vars:
expanded_g = tf.expand_dims(g, 0)
grads.append(expanded_g)
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(grad, 0)
v = grad_and_vars[0][1]
grad_and_var = (grad, v)
average_grads.append(grad_and_var)
return average_grads
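# minimal smoke test (not part of the original code): builds the graph only, assuming
# config.py and the weights file loaded by vgg16.Vgg16 are available
if __name__ == '__main__': UNET()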