小白编程,时常看tf代码看得头痛,也没有自己的一点思路。今天就结合网上的一些资料以及我自己的一个代码,整理了一下tensorflow编程一般思路。
一般我们从GitHub上下载的tensorflow的代码文档,主要包含如下几个文件:
- 训练与测试数据集文件夹datasets;
- 保存的模型文件夹snapshots;
- 数据传输接口image_reader.py;
- 网络定义文件net.py;
- 训练主控文件train.py;
- 测试主控文件evaluate.py;
- 辅助文件utils.py;
下面简单地概述一下每个文档的大概作用。
1.datasets
这个文档里面就是存放训练和测试的数据集,包括train_dataset和test_dataset两个文件夹,记录了训练与测试图片;另外还有两个train_data.txt和test_data.txt文档,记录了对于所有训练与测试图片的索引,作用是给数据传送接口调用。对于txt文件可以使用python编写(比如用上图中的write_txt.py和write_test_txt.py撰写训练/测试数据集索引),索引的内容也可自己定制。一般直接添加图片和对应标签的名字,使用空格分离。
2.snapshots
训练完成后,模型会保存在这个文件夹下。同时在进行测试时,会对这个文件夹里面的模型参数进行读取。
3.image_reader.py
一般不采用tensorflow的官方数据传送接口,而是根据用户自己定义。如下列出一个我自己的文件。
import os
import numpy as np
import cv2
def ImageReader(file_name, picture_path, label_path, picture_format=".png", label_format=".jpg", size=256):
picture_name = picture_path + file_name + picture_format # 得到图片名称和路径
label_name = label_path + file_name + label_format # 得到标签名称和路径
picture = cv2.imread(picture_name, 1) # 读取图片
label = cv2.imread(label_name, 1) # 读取标签
height = picture.shape[0] # 得到图片的高
width = picture.shape[1] # 得到图片的宽
picture_resize_t = cv2.resize(picture, (size, size)) # 调整图片的尺寸,改变成网络输入的大小
picture_resize = picture_resize_t / 127.5 - 1. # 归一化图片
label_resize_t = cv2.resize(label, (size, size)) # 调整标签的尺寸,改变成网络输入的大小
label_resize = label_resize_t / 127.5 - 1. # 归一化标签
return picture_resize, label_resize, height, width # 返回网络输入的图片,标签,还有原图片和标签的长宽
在进行ImageReader定制的过程中,往往结合train_data.txt,从外存中读取数据,再送入网络进行迭代,用户可根据自己的需求进行各式各样的定制。
4. net.py
很多GitHub上的代码都把网络定义单独写一个.py文件,这是因为训练/测试主控程序是一个层次较高的代码,在进行实验的时候是不关心神经网络的具体细节的。因此,网络定义的细节应该是定义在net.py文件中。在进行net.py的撰写中,网络底层代码与网络高层代码相互分离,网络高层代码调用网络底层代码。
那么,何谓网络高层代码,何谓网络底层代码?网络高层代码协定了网络的设计架构,而网络底层代码则制约了网络底层(比如神经网络层)的规范。在训练主控代码中,程序只调用网络高层代码。对于高层代码与底层代码的封装,如下我的net.py文件中的部分代码:
import numpy as np
import tensorflow as tf
import math
# 构造可训练参数
def make_var(name, shape, trainable=True):
return tf.get_variable(name, shape, trainable=trainable)
############################底层部分#############################
# 卷积层
def conv2d(input_, output_dim, kernel_size, stride, padding="SAME", name="conv2d", biased=False):
input_dim = input_.get_shape()[-1]
with tf.variable_scope(name):
kernel = make_var(name='weights', shape=[kernel_size, kernel_size, input_dim, output_dim])
output = tf.nn.conv2d(input_, kernel, [1, stride, stride, 1], padding=padding)
if biased:
biases = make_var(name='biases', shape=[output_dim])
output = tf.nn.bias_add(output, biases)
return output
# 空洞卷积层
def atrous_conv2d(input_, output_dim, kernel_size, dilation, padding="SAME", name="atrous_conv2d", biased=False):
input_dim = input_.get_shape()[-1]
with tf.variable_scope(name):
kernel = make_var(name='weights', shape=[kernel_size, kernel_size, input_dim, output_dim])
output = tf.nn.atrous_conv2d(input_, kernel, dilation, padding=padding)
if biased:
biases = make_var(name='biases', shape=[output_dim])
output = tf.nn.bias_add(output, biases)
return output
# 反卷积层
def deconv2d(input_, output_dim, kernel_size, stride, padding="SAME", name="deconv2d"):
input_dim = input_.get_shape()[-1]
input_height = int(input_.get_shape()[1])
input_width = int(input_.get_shape()[2])
with tf.variable_scope(name):
kernel = make_var(name='weights', shape=[kernel_size, kernel_size, output_dim, input_dim])
output = tf.nn.conv2d_transpose(input_, kernel, [1, input_height * 2, input_width * 2, output_dim],
[1, 2, 2, 1], padding="SAME")
return output
############################高层部分#############################
# 生成器,采用UNet架构,主要由8个卷积层和8个反卷积层组成
def generator(image, gf_dim=64, reuse=False, name="generator"):
input_dim = int(image.get_shape()[-1]) # 获取输入通道
dropout_rate = 0.5 # 定义dropout的比例
with tf.variable_scope(name):
if reuse:
tf.get_variable_scope().reuse_variables()
else:
assert tf.get_variable_scope().reuse is False
# 第一个卷积层,输出尺度[1, 128, 128, 64]
e1 = batch_norm(conv2d(input_=image, output_dim=gf_dim, kernel_size=4, stride=2, name='g_e1_conv'),
name='g_bn_e1')
# 第二个卷积层,输出尺度[1, 64, 64, 128]
e2 = batch_norm(conv2d(input_=lrelu(e1), output_dim=gf_dim * 2, kernel_size=4, stride=2, name='g_e2_conv'),
name='g_bn_e2')
# 第三个卷积层,输出尺度[1, 32, 32, 256]
e3 = batch_norm(conv2d(input_=lrelu(e2), output_dim=gf_dim * 4, kernel_size=4, stride=2, name='g_e3_conv'),
name='g_bn_e3')
# 第四个卷积层,输出尺度[1, 16, 16, 512]
e4 = batch_norm(conv2d(input_=lrelu(e3), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_e4_conv'),
name='g_bn_e4')
# 第五个卷积层,输出尺度[1, 8, 8, 512]
e5 = batch_norm(conv2d(input_=lrelu(e4), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_e5_conv'),
name='g_bn_e5')
# 第六个卷积层,输出尺度[1, 4, 4, 512]
e6 = batch_norm(conv2d(input_=lrelu(e5), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_e6_conv'),
name='g_bn_e6')
# 第七个卷积层,输出尺度[1, 2, 2, 512]
e7 = batch_norm(conv2d(input_=lrelu(e6), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_e7_conv'),
name='g_bn_e7')
# 第八个卷积层,输出尺度[1, 1, 1, 512]
e8 = batch_norm(conv2d(input_=lrelu(e7), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_e8_conv'),
name='g_bn_e8')
# 第一个反卷积层,输出尺度[1, 2, 2, 512]
d1 = deconv2d(input_=tf.nn.relu(e8), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_d1')
d1 = tf.nn.dropout(d1, dropout_rate) # 随机扔掉一般的输出
d1 = tf.concat([batch_norm(d1, name='g_bn_d1'), e7], 3)
# 第二个反卷积层,输出尺度[1, 4, 4, 512]
d2 = deconv2d(input_=tf.nn.relu(d1), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_d2')
d2 = tf.nn.dropout(d2, dropout_rate) # 随机扔掉一般的输出
d2 = tf.concat([batch_norm(d2, name='g_bn_d2'), e6], 3)
# 第三个反卷积层,输出尺度[1, 8, 8, 512]
d3 = deconv2d(input_=tf.nn.relu(d2), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_d3')
d3 = tf.nn.dropout(d3, dropout_rate) # 随机扔掉一般的输出
d3 = tf.concat([batch_norm(d3, name='g_bn_d3'), e5], 3)
# 第四个反卷积层,输出尺度[1, 16, 16, 512]
d4 = deconv2d(input_=tf.nn.relu(d3), output_dim=gf_dim * 8, kernel_size=4, stride=2, name='g_d4')
d4 = tf.concat([batch_norm(d4, name='g_bn_d4'), e4], 3)
# 第五个反卷积层,输出尺度[1, 32, 32, 256]
d5 = deconv2d(input_=tf.nn.relu(d4), output_dim=gf_dim * 4, kernel_size=4, stride=2, name='g_d5')
d5 = tf.concat([batch_norm(d5, name='g_bn_d5'), e3], 3)
# 第六个反卷积层,输出尺度[1, 64, 64, 128]
d6 = deconv2d(input_=tf.nn.relu(d5), output_dim=gf_dim * 2, kernel_size=4, stride=2, name='g_d6')
d6 = tf.concat([batch_norm(d6, name='g_bn_d6'), e2], 3)
# 第七个反卷积层,输出尺度[1, 128, 128, 64]
d7 = deconv2d(input_=tf.nn.relu(d6), output_dim=gf_dim, kernel_size=4, stride=2, name='g_d7')
d7 = tf.concat([batch_norm(d7, name='g_bn_d7'), e1], 3)
# 第八个反卷积层,输出尺度[1, 256, 256, 3]
d8 = deconv2d(input_=tf.nn.relu(d7), output_dim=input_dim, kernel_size=4, stride=2, name='g_d8')
return tf.nn.tanh(d8)
在上述代码的网络底层代码中,规定了卷积、空洞卷积和转置卷积操作的定义,里面分别使用了tf.nn.conv2d、tf.nn.atrous_conv2d和tf.nn.conv2d_transpose接口,而网络高层代码看不见tensorflow的网络层接口。在高层代码中,制定了网络的具体架构,如上所示,定义了一个采用8个卷积层和8个反卷积层组成的生成器,而训练主控程序就调用net函数。从net中也可以方便地看出网络架构,并使得训练主控程序简洁明了。
这样做还有一个好处,就是在参数设置方面。注意到tensorflow的权重参数名称设置,在构造训练代码中,tf.get_variable(name, )和with tf.variable_scope(name)中的name是相当重要的。因为在tensorflow中,参数名称是按照类似堆栈的架构一级一级由底往上堆叠的,而每一次添加scope中的name就组成了堆栈的一部分。因此,如果说要满足同一个网络有两套不同的参数,又不需重新定义网络结构,应该怎么做呢?就将参数名称堆栈的栈顶换掉就行了,而栈顶以下的部分是不需要改动的。在网络定义文件中,网络高层代码和网络底层代码共同完成了对参数名称堆栈的除栈顶以外部分的定义与约束,而参数的栈顶部分则可以在训练主控程序中定制,这样就可以实现同一网络结构配置多套参数。并且,代码更加层次分明!
5. 训练主控文件train.py
对于训练的主控程序,首先往往需要在其中定义训练参数,这样方便在代码中修改。然后,可能需要一些输入占位符定义,其次最主要的是进行网络前向传播得到训练结果,并根据这个结果计算loss,之后需要使用训练器对网络进行反向传播并更新参数梯度。最后,根据需要导入fine-tune的参数,进行初始化,然后就是不停地送入数据并且进行网络的训练。
简而言之,训练的主控程序逻辑如下:
获取用户定义训练程序参数–>置占位符(placeholder)–>进行网络前向传播–>计算loss–>设置训练器并且进行网络反向传播–>保存sammary(可选)–>进行fine-tune参数的导入–>初始化训练过程–>传入数据进行网络训练。
from __future__ import print_function
#导入必要的库
import argparse
from random import shuffle
import os
import tensorflow as tf
import cv2
...
#导入自己写好的库
from image_reader import *
from net import *
...
#用户自定义训练参数
parser = argparse.ArgumentParser(description='')
parser.add_argument("--snapshot_dir", default='./snapshots', help="path of snapshots")
...
args = parser.parse_args() # 用来解析命令行参数
##训练主函数
def main():
if not os.path.exists(args.snapshot_dir): # 如果保存模型参数的文件夹不存在则创建
os.makedirs(args.snapshot_dir)
if not os.path.exists(args.out_dir): # 如果保存训练中可视化输出的文件夹不存在则创建
os.makedirs(args.out_dir)
#设置占位符
image = tf.placeholder(tf.float32, shape=[1, args.image_size, args.image_size, image_channels],
name='image') # 输入的训练图像
label = tf.placeholder(tf.float32, shape=[1, args.image_size, args.image_size, label_channels],
name='label') # 输入的与训练图像匹配的标签
#进行网络前向传播
net_output = net(image=image, reuse=False, name='net')
#计算loss
loss = comput_loss(image, label)
#得到loss_summary(可选)
loss_sum = tf.summary.scalar("loss", loss)
summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph())
#设置训练器并进行网络训练,这里使用了一个Adam优化器
vars = [v for v in tf.trainable_variables() if 'net' in v.name]
optim = tf.train.AdamOptimizer(learning_rate)
grads_and_vars = optim.compute_gradients(loss, var_list=vars)
train_op = optim.apply_gradients(grads_and_vars)
#设置tensorflow会话层
config = tf.ConfigProto()
config.gpu_options.allow_growth = True # 设定显存不超量使用
sess = tf.Session(config=config)
#进行fine-tune参数导入
restore_vars = [v for v in tf.trainable_variables() if ...]#设置需要导入的参数
loader = tf.train.Saver(var_list = restore_vars)
loader.restore(sess, 'load parameters path')
#初始化训练过程
init = tf.global_variables_initializer() # 参数初始化器
sess.run(init) # 初始化所有可训练参数
#保存器
saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=max_to_keep_nums)
for step in range(training_steps):#进行网络训练
load_image, load_label = ImageReader(image_path, step, ...)#按需传入读取图片的参数
feed_dict = {image : load_image, label : load_label}
oss_value, loss_sum_value, _= sess,run([loss, loss_sum, train_op], feed_dict = feed_dict)
if step % add_summary_per_step == 0: #按需更新summary
summary_writer.add_summary(loss_sum_value, loss)
if step % save_per_step == 0: #按需保存模型参数
saver.save(sess, 'save_path', step)
print('...') #按需打印loss等
if __name__ == '__main__':
main()
在进行训练器设置并让网络自动求解梯度更新训练参数时,使用了两个接口:compute_gradients和apply_gradients,为什么要这么做呢?在进行参数更新的时候,先求解出参数和对应的梯度,然后可以按需对梯度做出转换,最后再利用改变的梯度来更新参数,这样可以有效地满足一些用户的自定制需求。
6. 测试主控文件evaluate.py
对于测试的程序,与训练主程序不同的地方是,测试程序并不需要计算loss和设置训练器进行网络参数的训练,当然也不需要保存summary。但是测试程序重要的是需要导入所需要的网络的全部参数(从snapshots文件中导入),这样网络才能完成前向传播过程,并且按需进行对结果的处理,测试精度,得到可视化效果等。
简而言之,测试的主控程序逻辑如下:
获取用户定义测试程序参数–>设置占位符(placeholder)–>进行网络前向传播得到结果–>进行前向传播所需参数的导入–>送数据进网络–>进行结果后处理
#导入必要的库
import argparse
import os
import tensorflow as tf
import cv2
...
#用户自定义训练参数
parser = argparse.ArgumentParser(description='')
parser.add_argument("--snapshot_dir", default='./snapshots', help="path of snapshots")
...
args = parser.parse_args() # 用来解析命令行参数
#测试主函数
def main():
if not os.path.exists(args.out_dir): # 如果保存测试结果的文件夹不存在则创建
os.makedirs(args.out_dir)
#设置占位符
image = tf.placeholder(tf.float32,shape=[1, image_height, image_height, image_channels],name='image')
#进行网络前向传播
net_output = net(image=image, reuse=False, name='net')
#设置tensorflow会话层
config = tf.ConfigProto()
config.gpu_options.allow_growth = True # 设定显存不超量使用
sess = tf.Session(config=config) # 建立会话层
#进行测试所需参数导入
restore_vars = [v for v in tf.global_variables() if ...] #设置测试需要导入的参数
saver = tf.train.Saver(var_list=restore_var, max_to_keep=1) # 导入模型参数时使用
checkpoint = tf.train.latest_checkpoint(args.snapshots) # 读取模型参数
saver.restore(sess, checkpoint) # 导入模型参数
for step in range(testing_steps): #送入测试样本
load_image, load_label = ImageReader(image_path, step, ...) #按需传入读取图片的参数
feed_dict = {image : load_image, label : load_label}
net_output = sess,run(net_output, feed_dict = feed_dict)
result = postprocess(net_output) #进行后处理(可选)
accuracy = compute_accuracy(result, label)#进行精度计算(可选)
print('...') #按需打印信息等
if __name__ == '__main__':
main()
7. utils.py
utils.py文件往往就是一些中小型的tools,比如,经常将训练图像的前处理、后处理、数据转换等一些小型函数放在utils.py里面,起到一些辅助与查缺补漏的作用。如下我将save函数、l1和l2loss等写入其中。
import tensorflow as tf
import numpy as np
import os
def save(saver, sess, logdir, step): # 保存模型的save函数
model_name = 'model' # 保存的模型名前缀
checkpoint_path = os.path.join(logdir, model_name) # 模型的保存路径与名称
if not os.path.exists(logdir): # 如果路径不存在即创建
os.makedirs(logdir)
saver.save(sess, checkpoint_path, global_step=step) # 保存模型
print('The checkpoint has been created.')
def l1_loss(src, dst): # 定义l1_loss
return tf.reduce_mean(tf.abs(src - dst))
def l2_loss(src, dst): #定义l2_loss
return tf.reduce_mean((src - dst)**2)
本文章主要是在https://blog.csdn.net/jiongnima/article/details/78337783一文上编写的。个人写只为记录方便自己时常学习,如有侵权联系我立删。