Python-代码阅读-将一个神经网络模型的参数复制到另一个模型中(2)

1.代码

def copy_model_parameters(sess, qnet1, qnet2):
    # 获取qnet1和qnet2中的可训练变量(参数)
    q1_params = [t for t in tf.trainable_variables() if t.name.startswith(qnet1.scope)]
    q1_params = sorted(q1_params, key=lambda v: v.name)
    q2_params = [t for t in tf.trainable_variables() if t.name.startswith(qnet2.scope)]
    q2_params = sorted(q2_params, key=lambda v: v.name)
    update_ops = []
    # 遍历qnet1和qnet2中的参数,创建更新操作
    for q1_v, q2_v in zip(q1_params, q2_params):
        # 创建将qnet1中参数值赋值给qnet2中参数的操作
        op = q2_v.assign(q1_v)
        # 将更新操作添加到update_ops列表中
        update_ops.append(op)
    # 在TensorFlow会话中运行所有的更新操作,从而将qnet1的参数复制到qnet2中
    sess.run(update_ops)

2.代码阅读

这个函数用于将一个神经网络模型的参数复制到另一个模型中。函数接受三个输入参数:

  1. sess: TensorFlow会话对象,表示当前执行计算图的会话。
  2. qnet1: 源神经网络模型,从该模型复制参数。
  3. qnet2: 目标神经网络模型,将参数复制到该模型。

函数首先使用tf.trainable_variables()函数获取qnet1qnet2中的可训练变量(参数),并根据它们的作用域(假设每个模型都有唯一的作用域)对其进行筛选。qnet1qnet2中的可训练变量分别存储在q1_paramsq2_params列表中。

接着,函数通过遍历q1_paramsq2_params中的变量,为每一对变量创建一个赋值操作q2_v.assign(q1_v))来将qnet1中的变量值复制到qnet2中。这些更新操作被存储在update_ops列表中。

最后,函数使用sess.run(update_ops)在TensorFlow会话中运行所有的更新操作,从而执行将qnet1的参数复制到qnet2中的操作。执行完这个函数后,qnet2的参数将被更新为与qnet1相同的参数值,实现了从一个模型复制参数到另一个模型的目的。

2.1 tf.trainable_variables()

q1_params = [t for t in tf.trainable_variables() if t.name.startswith(qnet1.scope)]

这行代码使用列表推导式从所有的可训练变量(tf.trainable_variables())中筛选出具有指定作用域(qnet1.scope)前缀的变量,并将其保存在q1_params列表中。

具体而言,tf.trainable_variables()函数返回当前图中所有的可训练变量的列表,每个变量都包含了变量的名称、值和其他属性。t.name表示变量的名称,而startswith(qnet1.scope)则检查变量的名称是否以qnet1.scope作为前缀,从而筛选出具有指定作用域前缀的变量。

例如,如果qnet1.scope的值为"qnet1/",那么q1_params列表将包含所有名称以"qnet1/"作为前缀的可训练变量。这样可以方便地获取qnet1模型中的所有参数,以便后续进行参数复制操作。

这一行代码使用了列表推导式(List Comprehension)的结构,是一种简洁的 Python 编码方式,用于从一个可迭代对象中生成新的列表。

列表推导式的结构如下:

[expression for item in iterable if condition]

[表达式 for 迭代变量 in 可迭代对象 [if 条件表达式] ]

其中:

  • expression:表示对每个item执行的表达式,用于生成新的列表中的元素。
  • item:表示迭代的对象中的每个元素。
  • iterable:表示要迭代的对象,可以是列表、元组、集合、字典等。
  • condition:表示可选的条件表达式,用于筛选出符合条件的元素。

在这行代码中,expressiont,表示对于可训练变量列表中的每个元素t,将其添加到q1_params列表中。itemtf.trainable_variables()函数返回的可训练变量列表中的每个元素,iterable就是tf.trainable_variables()函数返回的可训练变量列表。

conditiont.name.startswith(qnet1.scope),表示筛选出以qnet1.scope作为前缀的变量。

因此,这行代码的作用是从tf.trainable_variables()函数返回的所有可训练变量中,筛选出具有指定作用域前缀的变量,并将其保存在q1_params列表中。

2.2 sorted()函数

q1_params = sorted(q1_params, key=lambda v: v.name)

这行代码使用了sorted()函数对q1_params列表进行排序,排序的依据是变量的名称(v.name)。

sorted()函数是 Python 内置函数,用于对列表进行排序。它接受一个列表作为输入,并返回一个新的已排序的列表。其中,key参数是一个可选的函数,用于指定排序的依据。在这行代码中,使用了lambda表达式作为key参数,定义了一个匿名函数,其输入参数为变量v,输出为变量v.name,表示对变量的名称进行排序。

通过对q1_params列表进行排序,可以保证复制模型参数时的一致性,即按照变量名称的字典序对参数进行复制操作,从而确保了参数复制的顺序和对应关系一致。

2.3 zip()函数

    for q1_v, q2_v in zip(q1_params, q2_params):
        op = q2_v.assign(q1_v)
        update_ops.append(op)

这部分代码通过使用zip()函数将q1_paramsq2_params两个列表中的元素一一对应起来,然后使用q2_v.assign(q1_v)操作将q1_params中的变量值复制到q2_params中对应的变量中,并将复制操作的结果保存在op变量中。

zip()函数是 Python 内置函数,用于将多个列表中的元素按索引一一对应起来,生成一个新的可迭代对象(元组列表)。在这里,zip(q1_params, q2_params)q1_paramsq2_params中的元素按索引一一对应起来,生成了一个包含元组的列表,其中每个元组中的第一个元素来自q1_params,第二个元素来自q2_params,即q1_paramsq2_params中的对应位置的变量一一对应。

然后,通过q2_v.assign(q1_v)操作,将q1_params中的变量值复制到q2_params中对应的变量中。q2_vq1_v分别表示q2_paramsq1_params中对应位置的变量,assign()是 TensorFlow 中的赋值操作,用于将一个变量的值赋给另一个变量。

最后,将复制操作的结果op添加到update_ops列表中,以便在后续通过sess.run(update_ops)执行这些复制操作,从而实现模型参数的复制。

2.4 sess.run()

sess.run(update_ops)

sess.run(update_ops)是使用 TensorFlow 的会话(sess)执行一系列更新操作(update_ops)的语句。

update_ops是一个包含了一系列更新操作的列表,这些操作在前面的代码中通过q2_v.assign(q1_v)语句生成。这些操作的目的是将q1_params中的模型参数复制到q2_params中对应的模型参数中。

通过调用sess.run(update_ops),会话会依次执行update_ops列表中的每个更新操作,将q1_params中的模型参数的值复制到q2_params中对应的模型参数中,从而实现模型参数的复制操作。执行完成后,q2_params中的模型参数将与q1_params中的模型参数保持一致,完成了参数复制的操作。


http://www.niftyadmin.cn/n/232842.html

相关文章

centos7.6部署ELK集群(一)之elasticsearch7.7.0集群部署

32.3. 部署es7.7.0 32.3.1. 下载es(各节点都做) wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.7.0-linux-x86_64.tar.gz 32.3.2. 解压至安装目录(各节点都做) tar -xvf elasticsearch-7.7.0-li…

Hbase伪分布安装配置

Hbase安装配置 文章目录Hbase安装配置Hbase安装前提下载Hbase压缩包软件版本兼容性Hadoop和HbaseHbase和JDK软件安装软件位置创建数据保存和日志保存文件夹修改配置文件修改hbase-site.xml文件修改hbase-env.sh文件修改~/.bashrc文件启动hbase并验证权限问题Permission denied修…

JavaScript的作用域、闭包、高阶函数、柯里化、函数缓存和纯函数

作用域 作用域就是当前执行环境的上下文,它限制了变量、函数的有效范围。 在当前作用域下声明的变量、函数只能在当前作用域内以及它嵌套的子作用域内有效。这样避免变量和函数的命名冲突,还可以形成私有数据,从而保证数据不会被外部作用域篡…

Scrum敏捷研发和项目管理

Scrum是全球运用最广泛的敏捷管理框架,Leangoo基于Scrum框架提供了一系列的流程和模板,可以帮助敏捷团队快速启动Scrum敏捷开发。 Leangoo完美支持Scrum敏捷框架,它提供了灵活的敏捷模板和极致的协作体验,可以让团队快速上手&am…

升级长江存储最新闪存,忆恒创源发布新一代企业级NVMe SSD

2023年4月11日 —— 北京忆恒创源科技股份有限公司(Memblaze)正式发布搭载高品质国产闪存的PBlaze6 6541 系列企业级PCIe 4.0 NVMe SSD。作为 MUFP 平台化开发的最新作品,PBlaze6 6541 采用长江存储最新一代晶栈 Xtacking 3D NAND&#xff0c…

异常的讲解(2)

目录 throws异常处理 基本介绍 throws异常处理注意事项和使用细节 自定义异常 基本概念 自定义异常的步骤 throw 和throws的区别 本章作业 第一题 第二题 第三题 第四题 throws异常处理 基本介绍 1)如果一个方法(中的语句执行时)可能生成某种异常,但是…

刷题day55:省份数量

题意描述: 有 n 个城市,其中一些彼此相连,另一些没有相连。如果城市 a 与城市 b 直接相连,且城市 b 与城市 c 直接相连,那么城市 a 与城市 c 间接相连。 省份 是一组直接或间接相连的城市,组内不含其他没…

【散文诗】单片机运行下和非运行下的 ROM 和 RAM

目录一、两种处理器的结构体系1. 哈佛结构体系2. 冯诺依曼结构体系3. 两种结构的总结二、单片机运行下和非运行下的内存分配1. 非运行时的单片机程序在ROM内的分布2. 运行时的单片机程序在RAM内的分布三、单片机程序和操作系统应用程序的对比四、编译流程一、两种处理器的结构体…