神经网络输出中间特征图

news/2024/9/1 23:35:33 标签: 神经网络, 人工智能, 深度学习

在进行神经网络的训练过程中,会生成不同的特征图信息,这些特征图中包含大量图像信息,如轮廓信息,细节信息等,然而,我们一般只获取最终的输出结果,至于中间的特征图则很少关注。

前两天师弟突然问起了这个问题,但我也没有头绪,后来和师弟研究了一下,大概有了一个思路。

即每个特征提取模块都会输出一个特征图,这些特征图的每个像素实际上就是一些数值,那么只需要将这些数值保存,再以图像的形式展现出来便OK了。

基于这个思路,我们来进行设计。在观测输出的特征图时,我们可以使用推理代码来进行输出,因为推理时所消耗的资源较少且推理时可以很明确我们输入的图像是什么。
至于要想实现的效果:

原图:

在这里插入图片描述

输出的特征图:

在这里插入图片描述
那么该如何进行呢?
首先是要明确你要输出哪个阶段的特征图像,博主分别选择了主干网络四个阶段的输出结果,输出的特征图大小分别为:

x的shape: torch.Size([1, 64, 200, 300])
x的shape: torch.Size([1, 128, 100, 150])
x的shape: torch.Size([1, 320, 50, 75])
x的shape: torch.Size([1, 512, 25, 38])

代码实现

在要输出特征图的模块后面讲特征图保存为numpy的格式:

sb = x.cpu().data.numpy()
np.save('matric'+str(i)+'.npy', sb)#这里的i是对应四个阶段的id

读取numpy格式数据并转换为特征图:

import numpy as np
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

def normalization(data):  # NORMALIZE TO [0,1]
    _range = np.max(data) - np.min(data)
    data = (data - np.min(data)) / _range  # [0,1]
    return data

def fm_vis(feats, save_dir, save_name):
    save_dir = os.path.join(save_dir, save_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    feats = normalization(feats[0].cpu().data.numpy())
    for idx in range(min(feats.shape[0], 200*300)):  # CHANNLE NUMBER
        fms = feats[idx, :, :]
        plt.imshow(fms)
        plt.savefig(os.path.join(save_dir, save_name + '_' + str(idx) + ".png"))
        
for i in range(0,4):
    s_b1 = np.load('matric'+str(i)+'.npy')
    print(s_b1)
    s_b2 = torch.from_numpy(s_b1)
    out_dir = "outputs"
    s_b = s_b2.reshape(1, 64, 200, 300)
    fm_vis(s_b, out_dir, "s_b_vis"+str(i))

最终结果:输出四个阶段的特征图,博主选了其中几张:

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述


http://www.niftyadmin.cn/n/5013262.html

相关文章

windows10使用wheel安装tensorflow2.13.0/2.10.0 (保姆级教程)

安装过程 安装虚拟环境安装virtualenv安装满足要求的python版本使用virtualenv创建指定python版本的虚拟环境 安装tensorflow安装tensorflow-docs直接下载使用wheel下载 在VSCode编辑器中使用虚拟环境下的包常见错误 注意: tensorflow 2.10.0是最后一个支持GPU的版本…

Python 统一地铁线路名称

最近在做一个文本挖掘项目时遇到一个很实际的问题:文本里对地铁线路名称的表述很杂乱,如何统一。 比如,地铁1号线,可能表述为1号线、地铁1号线、轨道1号线、轨道交通1号线、1号地铁、一号线、地铁一号线、轨道一号线、轨道交通一…

DataGridView选中的单元格求和

DataGridView单元格求和功能的基本思路是先得到选中的单元格, 1,在内存中定义两张表,一张存放列名,一张存放列名和数个。这样这两张表就开成了一对多的父子关系。 2,在将两张定及他们的父子关系添加到DataSet对象中 4…

vscode配置conda环境

vscode配置conda环境 写在最前面安装vscodeanaconda3 配置vscode中文vscode配置anaconda环境步骤 新建.ipynb项目 写在最前面 之前一直是jupyter notebookpycharm 帮朋友配置环境的时候发现:vscode结合了cell自动补齐,狠狠心动了 于是安装配置vscode 参…

手撕 队列

队列的基本概念 只允许在一端进行插入数据操作,在另一端进行删除数据操作的特殊线性表,队列具有先进先出 入队列:进行插入操作的一端称为队尾 出队列:进行删除操作的一端称为队头 队列用链表实现 队列的实现 队列的定义 队列…

如何预防CSRF攻击

CSRF 攻击的防范措施 CSRF(Cross-Site Request Forgery)攻击是一种常见的 Web 攻击,即攻击者在用户不知情的情况下,利用用户已登录的身份,向目标网站发送恶意请求,从而实现攻击目的。本文将介绍 CSRF 攻击…

微信小程序开发---页面导航

目录 一、页面导航的概念 二、页面导航的实现 (1)声明式导航 1、概念 2、导航到tabBar页面 3、导航非tabBar页面 4、后退导航 (2)编程式导航 1、导航到tabBar页面 2、导航到非tabBar页面 3、后退导航 三、导航传参 &…

【回溯算法】77. 组合

77. 组合 解题思路 回溯结束条件 track长度等于k 然后收集当前的路径遍历所有的节点 然后选择当前节点 通过start参数控制遍历 避免产生重复的子集移除当前的选择 class Solution {List<List<Integer>> res new LinkedList<>();LinkedList<Integer>…