2017年7月12日星期三

TensorFlow 编程手册: Variables: Creation, Initialization, Saving, and Loading

(本文来自官方手册,笔者为初学者,记下以作学习)
当你训练一个模型的时候,你可以使用variables去存放和更新参数。Variables就是包含tensor的内存缓冲区。你必须明确的初始化,他们能够在训练期间或训练结束被保存到磁盘。之后,你可以恢复他们的值去训练和分析模型。
下面列出了本文引用了TF的类,
     The tf.Variable class
     The tf.train.Saver class


1.Creation
当调用tf.Variable的时候,就在向图中添加op.

#create two variable
weights = tf.Varaible(tf.random_normal([784,200],stddev=0.35),name="weights")
biases = tf.Variable(tf.zeros([200]),name="biases")


2.Initialization
变量的初始化必须在运行其他操作前运行,最方便的方法是加一个op去run所有的变量初始化,如下所示:
#create two variables
weights = tf.Variable(tf.random_normal([784,200],stddev=0.35),name="weights")
biases = tf.Variable(tf.zeros([200]),name="biases")
...
#Add an op to initialize the varaibles
init_op = tf.global_variables_initializer()

#later,when launching the model
with tf.Session() as sess:
    #Run the init operation
    sess.run(init_op)
    ...
    #Use the model
    ...

2.1 Initialization from another Variable
有时候你要用另一个变量的值来初始化变量,由于用tf.global_variables_initializer()初始化所有的变量,所有使用的时候必须小心。这时候要用到initialized_value()属性
# Create a variable with a random value.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
# Create another variable with the same value as 'weights'.
w2 = tf.Variable(weights.initialized_value(), name="w2")
# Create another variable with twice the value of 'weights'
w_twice = tf.Variable(weights.initialized_value() * 2.0, name="w_twice")


3.Saving and restoring
保存和恢复模型的最简单方法是使用tf.train.Saver对象。 构造函数将图形中的所有变量或指定列表的图形添加到图形中。 保护对象提供了运行这些操作的方法,指定检查点文件写入或读取的路径。
 请注意,要恢复没有图形的模型检查点,必须首先从元图形文件导入图形(典型扩展名为.meta)。 这是通过tf.train.import_meta_graph完成的,而tf.train.import_meta_graph又返回一个可以执行还原的Saver。

 3.1.Checkpoint Files
 变量保存在二进制文件中,大致包含从变量名到张量值的映射。创建Saver对象时,您可以选择为检查点文件中的变量名称。 默认情况下,它为每个变量使用tf.Variable.name属性的值。要了解检查点中的哪些变量,可以使用inspect_checkpoint库,特别是print_tensors_in_checkpoint_file函数。

 3.2 Saving Variables
通过tf.train.Saver()创建一个Saver,可以用来管理模型中所有变量
#create some variables
v1 = tf.Variable(...,name="v1")
v2 = tf.Variable(...,name="v2")
 ...
#Add an op to initialize the variables
init_op = tf.global_variables_initializer()

#Add ops to save and restore all the variable
saver = tf.train.Saver()

#Later,launch the model,initialize the variables,do some work,
#save the variables to disk
with tf.Session() as sess:
    sess.run(init_op)
    #do some work with the model
    ..
    #save the variables to disk
    save_path = saver.save(sess,"/tmp/model.ckpt")
    print("Model saved in file:%s" %save_path)

3.3 Restoring Variables
同样的Saver对象被用来恢复对象。当你从文件还原变量时,就不必事先进行初始化。
#create some variables
v1 = tf.Variable(...,name="v1")
v2 = tf.Variable(...,name="v2")
 ...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model
  ...

 3.4 Choosing which variables to save and restore

没有评论:

发表评论

leetcode 17

17.   Letter Combinations of a Phone Number Medium Given a string containing digits from   2-9   inclusive, return all possible l...