博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Pytorch(十四) —— hook
阅读量:2135 次
发布时间:2019-04-30

本文共 675 字,大约阅读时间需要 2 分钟。

HOOK——获取神经网络特征和梯度的有效工具

为了更深入地理解神经网络模型,有时候我们需要观察它训练得到的卷积核、特征图或者梯度等信息,这在CNN可视化研究中经常用到。其中,卷积核最易获取,将模型参数保存即可得到;特征图是中间变量,所对应的图像处理完即会被系统清除,否则将严重占用内存;梯度跟特征图类似,除了叶子结点外,其它中间变量的梯度都被会内存释放,因而不能直接获取。

最容易想到的获取方法就是改变模型结构,在forward的最后不但返回模型的预测输出,还返回所需要的特征图等信息。

如何在不改变模型结构的基础上获取特征图、梯度等信息呢?

Pytorch的hook编程可以在不改变网络结构的基础上有效获取、改变模型中间变量以及梯度等信息。

hook可以提取或改变Tensor的梯度,也可以获取nn.Module的输出和梯度(这里不能改变)。

 

Pytorch在进行完一次反向传播后,出于节省内存的考虑,只会存储叶子节点的梯度信息,并不会存储中间变量的梯度信息。然而有些时候我们又不得不使用中间变量的梯度信息完成某些工作(如获取中间层的梯度,获取中间层的特征图),这时候hook()函数就可以派上用场啦

主要有四种钩子函数:

  • ①torch.Tensor.register_hook
  • ②torch.nn.Module.register_backward_hook
  • ③torch.nn.Module.register_forward_hook
  • ④torch.nn.Module.register_forward_pre_hook,接下来分别对他们进行介绍
     

 

 

 

转载地址:http://feygf.baihongyu.com/

你可能感兴趣的文章
【LEETCODE】119-Pascal's Triangle II
查看>>
【LEETCODE】88-Merge Sorted Array
查看>>
【LEETCODE】19-Remove Nth Node From End of List
查看>>
【LEETCODE】125-Valid Palindrome
查看>>
【LEETCODE】28-Implement strStr()
查看>>
【LEETCODE】6-ZigZag Conversion
查看>>
【LEETCODE】8-String to Integer (atoi)
查看>>
【LEETCODE】14-Longest Common Prefix
查看>>
【LEETCODE】38-Count and Say
查看>>
【LEETCODE】278-First Bad Version
查看>>
【LEETCODE】303-Range Sum Query - Immutable
查看>>
【LEETCODE】21-Merge Two Sorted Lists
查看>>
【LEETCODE】231-Power of Two
查看>>
【LEETCODE】172-Factorial Trailing Zeroes
查看>>
【LEETCODE】112-Path Sum
查看>>
【LEETCODE】9-Palindrome Number
查看>>
【极客学院】-python学习笔记-Python快速入门(面向对象-引入外部文件-Web2Py创建网站)
查看>>
【LEETCODE】190-Reverse Bits
查看>>
【LEETCODE】67-Add Binary
查看>>
【LEETCODE】7-Reverse Integer
查看>>