tensorflow 获取checkpoint中的变量列表实例

(编辑:jimmy 日期: 2024/9/24 浏览:2)

方式1:静态获取,通过直接解析checkpoint文件获取变量名及变量值

通过

reader = tf.train.NewCheckpointReader(model_path)

或者通过:

from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(model_path)

代码:

model_path = "./checkpoints/model.ckpt-75000"
## 下面两个reader作用等价
#reader = pywrap_tensorflow.NewCheckpointReader(model_path)
reader = tf.train.NewCheckpointReader(model_path)
 
## 用reader获取变量字典,key是变量名,value是变量的shape
var_to_shape_map = reader.get_variable_to_shape_map()
for var_name in var_to_shape_map.keys():
 #用reader获取变量值
 var_value = reader.get_tensor(var_name)
 
 print("var_name",var_name)
 print("var_value",var_value)

方式2:动态获取,先加载checkpoint模型,然后用graph.get_tensor_by_name()获取变量值

代码 (注意:要先在脚本中构建model中对应的变量及scope):

 model_path = "./checkpoints/model.ckpt-75000"
 config = tf.ConfigProto()
 config.gpu_options.allow_growth = True
 with tf.Session(config=config) as sess:
  ## 获取待加载的变量列表
  trainable_vars = tf.trainable_variables()
  g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="generator")
  d_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='discriminator')
  flow_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='flow_net')
  var_restore = g_vars + d_vars
 
  ## 仅加载目标变量
  loader = tf.train.Saver(var_restore)
  loader.restore(sess,model_path)
 
  ## 显示加载的变量值
  graph = tf.get_default_graph()
  for var in var_restore:
   tensor = graph.get_tensor_by_name(var.name)
   print("=======变量名=======",tensor)
   print("-------变量值-------",sess.run(tensor))

以上这篇tensorflow 获取checkpoint中的变量列表实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

一句话新闻
一文看懂荣耀MagicBook Pro 16
荣耀猎人回归!七大亮点看懂不只是轻薄本,更是游戏本的MagicBook Pro 16.
人们对于笔记本电脑有一个固有印象:要么轻薄但性能一般,要么性能强劲但笨重臃肿。然而,今年荣耀新推出的MagicBook Pro 16刷新了人们的认知——发布会上,荣耀宣布猎人游戏本正式回归,称其继承了荣耀 HUNTER 基因,并自信地为其打出“轻薄本,更是游戏本”的口号。
众所周知,寻求轻薄本的用户普遍更看重便携性、外观造型、静谧性和打字办公等用机体验,而寻求游戏本的用户则普遍更看重硬件配置、性能释放等硬核指标。把两个看似难以相干的产品融合到一起,我们不禁对它产生了强烈的好奇:作为代表荣耀猎人游戏本的跨界新物种,它究竟做了哪些平衡以兼顾不同人群的各类需求呢?