博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch怎么抽取中间的特征或者梯度
阅读量:6825 次
发布时间:2019-06-26

本文共 1878 字,大约阅读时间需要 6 分钟。

for i, (input, target) in enumerate(trainloader):

# measure data loading time

data_time.update(time.time() - end)

 

input, target = input.cuda(), target.cuda()

if i==2:
def for_hook(module,input, output):
print('output values:',output)
handle2 = model.module.conv1.register_forward_hook(for_hook)

 

# compute output

output = model(input)
# output = output1*2
if i==2:
def variable_hook(grad):
print('grad:',grad)
hook_handle = output.register_hook(variable_hook)

 

# output2 = 2*output_

# output = 0.5*output2
loss = criterion(output, target)

 

 

 

 

 

# measure accuracy and record loss
prec = accuracy(output, target)[0]
losses.update(loss.item(), input.size(0))
top1.update(prec.item(), input.size(0))

 

# compute gradient and do SGD step

optimizer.zero_grad()
loss.backward()
# print('output.grad:',output1.grad)

 

# print('input.grad:',input.grad)

# print('input.is_leaf:',input.is_leaf)
# output.register_hook(print)
# zz.backward()

 

 

 

# print("output.grad:",output.grad)

optimizer.step()

 

# measure elapsed time

batch_time.update(time.time() - end)
end = time.time()

 

if i==2:

# print('the input is :', input)
print('the output is :', output)
hook_handle.remove()
handle2.remove()
# print('the target is :', target)
# print('parameters:',optimizer.param_groups)

 

if i % args.print_freq == 0:

print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
epoch, i, len(trainloader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1))

 

 

解释:

#定义构前向函数def for_hook(module,input, output):    print('output values:',output)#什么要抽取的层 model.module.avgpoolhandle2 = model.module.avgpool.register_forward_hook(for_hook) #前向,为勾函数准备output = model(input)#删除勾函数handle2.remove()

 

 

 

转载于:https://www.cnblogs.com/Wanggcong/p/10269823.html

你可能感兴趣的文章
Mongodb数据库安装及使用
查看>>
08-Windows Server 2012 R2 会话远程桌面-标准部署-使用PowerShell进行部署2-1
查看>>
centos7 systemctl 启动 Redis 失败
查看>>
The Hacker's Guide To Python 单元测试
查看>>
编程王道,唯“慢”不破
查看>>
SQL 必知必会·笔记<13>插入数据
查看>>
Openfire与XMPP协议
查看>>
在.NET下如何实现密码Hash化
查看>>
缩略图不变形
查看>>
【计算机视觉必读干货】图像分类、定位、检测,语义分割和实例分割方法梳理...
查看>>
SSIS Execute SQL Task 用法
查看>>
使用枚举和结构输出日期
查看>>
面试题:单词翻转(代码简洁&效率)
查看>>
使用oledb读写excel出现“操作必须使用一个可更新的查询”的解决办法
查看>>
Windows Azure Cloud Service (11) PaaS之Web Role, Worker Role(上)
查看>>
OEA 中 WPF 树型表格虚拟化设计方案
查看>>
《Android深度探索(卷1):HAL与驱动开发》虚拟实验环境(Ubuntu Linux)免费下载,不需要CPU虚拟化支持...
查看>>
linux 终端提示符
查看>>
C# 实现多线程的同步方法详解
查看>>
[转贴]当前Java开发中的若干问题
查看>>