ResNet神经网络搭建

一、定义残差结构

BasicBlock   

18层、34层网络对应的残差结构

浅层网络主线由两个3x3的卷积层链接,相加后通过relu激活函数输出。还有一个shortcut捷径

参数解释

        expansion = 1  : 判断对应主分支的残差结构有无变化

        downsample=None : 下采样参数,默认为none

        stride步距为1,对应实线残差结构 ; 步距为2,对应虚线残差结构

        self.conv2 = nn.Conv2d(in_channels=out_channel :卷积层1的输出即为输入

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample
  

定义正向传播

 identity = x : shortcut捷径上的输出值

identity = self.downsample(x) : 将输出特征矩阵x输入到下采样函数中得到捷径分支的输出 

 def forward(self, x):   #定义正向传播
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity  # 相加后通过relu激活函数输出
        out = self.relu(out)

        return out

BottleBlock  

50层、101层、152层神经网路对应的残差结构

深层网络主线由一个1x1的降维卷积层,3x3卷积层、1x1升维卷积层和一个shortcut捷径组成。

 按照残差结构进行定义,大致与BasicBlock参数一样,不同的是expansion=4,卷积核个数是之前的4倍。

class Bottleneck(nn.Module):
  
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

 定义正向传播

 if self.downsample is not None :  is None是实线  is not None 是虚线

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

二、 定义网络结构

block,  根据定义不同的层结构传入不同的block
blocks_num,  所使用残差结构的数目、参数列表
num_classes=1000, 分类个数
include_top=True, 在ResNet基础上搭建其他的网络

self.layer1 对应conv2_x
self.layer2 对应conv3_x
self.layer3 对应conv4_x
self.layer4 对应conv5_x 这一系列的残差结构都通过_make_layer函数实线

self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)  不管输入是多杀,通过自适应平均池化下采样都会输出(1,1)

self.fc = nn.Linear(512 * block.expansion, num_classes)  通过全连接输出节点层,输入的节点个数是通过平均池化下采样层后的特征矩阵展平后所得到的节点个数,但是由于节点的高和宽都是1,所以节点的个数=深度

class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top   # 将参数参入变为类变量
        self.in_channel = 64  # 表格中通过maxpooling后得到的深度

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

三、定义_make_layer函数

block,  定义的残差结构

channel, 残差结构中卷积层使用卷积核的个数

block_num, 该层包含了几个残差结构

block(传入第一层残差结构

        self.in_channel,   
        channel,   主分支第一个卷积层卷积核的个数
        downsample=downsample,  下采样函数

 for _ in range(1, block_num): 通过循环,将剩下一系列的实线残差结构压入进去

 return nn.Sequential(*layers)  非关键字传入,将定义的一系列层结构组合并返回。

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:   # 18、34层会跳过这一部分 ;50、101、152层不会
            downsample = nn.Sequential(  # 生成下采样函数
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []  # 定义空的列表
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

 定义正向传播

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)  # 展平后链接
            x = self.fc(x)

        return x

四、建立不同的层结构 

传入的参数分别按照定义的顺序传入。

def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

后续继续补充。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/611241.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CellChat包文献介绍

Inference and analysis of cell-cell communication using CellChat - PubMed (nih.gov) 目录 在线数据 摘要 基础介绍 分析结果 1,概述 2,识别预测通路 3,连续的信号转导 4,预测空间共定位细胞群之间的关键信号转导事件…

企业活动想联系媒体报道宣传如何联系媒体?

在企业的宣传推广工作中,我曾经历过一段费事费力、效率极低的时期。那时,每当公司有重要活动或新项目需要媒体报道时,我便要一家家地联系媒体,发送邮件、打电话,甚至亲自登门拜访,只为求得一篇报道。然而,这样的过程充满了不确定性和挑战,时常让我感到焦虑和压力山大。 记得有一…

vue3专栏项目 -- 项目介绍以及准备工作

这是vue3TS的项目,是一个类似知乎的网站,可以展示专栏和文章的详情,可以登录、注册用户,可以创建、删除、修改文章,可以上传图片等等。 这个项目全部采用Composition API 编写,并且使用了TypeScript&#…

亚马逊产品排名提升全攻略:自养号测评干货

之前我们一同探讨了亚马逊产品排名的多种类型,现在让我们回到正题,探讨一下如何才能有效地提升产品排名,从而吸引并抓住平台的流量,最终将其转化为可观的销量。 首先,卖家必须明晰亚马逊的排名机制,它主要基…

网页版Figma汉化

最近学习Figma,简单介绍一下网页版Figma的汉化方法 1.打开网址:Figma软件汉化-Figma中文版下载-Figma中文社区 2.下载汉化插件离线包 解压汉化包 3.点开谷歌的管理扩展程序 4.点击加载已解压的扩展程序,选择刚刚解压的包 这样就安装好了汉化…

从0到1开发一个vue3+ts项目(一)

1. 环境配置 1.1 安装node 使用官方安装程序 前往 Node.js 官网:访问 Node.js 官网,下载适合你操作系统的安装程序。运行安装程序:下载完成后,双击安装程序并按照提示进行安装。验证安装:安装完成后,在终…

顺序表经典算法OJ题-- 力扣27,88

题1: 移除元素 题2: 合并两个有序数组 一:题目链接:. - 力扣(LetCode) 思路:(双指针法) 创建两个变量src,dst 1)若src指向的值为val&#xf…

Qt复习第二天

1、菜单栏工具栏状态栏 #include "mainwindow.h" #include "ui_mainwindow.h" #pragma execution_character_set("utf-8"); MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) {ui->setupUi(this);//菜…

粤嵌—2024/4/26—跳跃游戏 ||

代码实现&#xff1a; 方法一&#xff1a;回溯 历史答案剪枝优化——超时 int *dis;void dfs(int k, int startindex, int *nums, int numsSize) {if (dis[startindex] < k) {return;}dis[startindex] k;for (int i 0; i < nums[startindex]; i) {if (startindex i &…

嫁接打印的技术要点

所谓嫁接打印&#xff0c;是一种增减材混合制造的方式。它将已成形的模具零件当作基座&#xff0c;在此基础上“生长”出打印的零件。其中基座通常采用传统加工方式制造&#xff0c;而打印部分则使用专用的金属粉末&#xff0c;通过 3D 打印技术成型。 嫁接打印之所以备受欢迎&…

4.nginx.pid打开失败以及失效的解决方案

一. nginx.pid打开失败以及失效的解决方案 1.错误图片&#xff1a; 2.解决方法 步骤1&#xff1a;进入这个目录 /var/run/nginx,提示没有文件或目录&#xff0c;则使用mkdir创建这个目录。 步骤2&#xff1a;然后 ./nginx -s reload 运行,是一个无效的PID 步骤3&#xff1a;使…

SMI接口

目录 SMI 接口帧格式读时序写时序 IP 设计IP 例化界面IP 接口IP 验证 SMI 接口 SMI&#xff08;Serial Management Interface&#xff09;串行管理接口&#xff0c;也被称作 MII 管理接口&#xff08;MII Management Interface&#xff09;&#xff0c;包括 MDC 和 MDIO 两条信…

【字符串】Leetcode 二进制求和

题目讲解 67. 二进制求和 算法讲解 为了方便计算&#xff0c;我们将两个字符串的长度弄成一样的&#xff0c;在短的字符串前面添加字符0&#xff1b;我们从后往前计算&#xff0c;当遇到当前计算出来的字符是> 2’的&#xff0c;那么就需要往前面进位和求余 注意&#xf…

《QT实用小工具·六十二》基于QT实现贝塞尔曲线画炫酷的波浪动画

1、概述 源码放在文章末尾 该项目实现了通过贝塞尔曲线画波浪动画&#xff0c;可控制 颜色密度速度加速度 安装与运行环境 语言&#xff1a;C 框架&#xff1a;Qt 11.3 平台&#xff1a;Windows 将屏幕水平平均分为10块&#xff0c;在一定范围内随机高度的12个点&#xff08;…

OAuth 2.0 和 OAuth 2.1

OAuth 2.0 和 OAuth 2.1比较&#xff1a; OAuth 2.0 和 OAuth 2.1 是授权框架的不同版本&#xff0c;它们用于允许应用程序安全地访问用户在另一个服务上的数据。以下是它们之间的一些主要区别&#xff1a; 安全性增强&#xff1a;OAuth 2.1 旨在提高安全性&#xff0c;它整合…

C语言/数据结构——每日一题(移除链表元素)

一.前言 今天在leetcode刷到了一道关于单链表的题。想着和大家分享一下。废话不多说&#xff0c;让我们开始今天的知识分享吧。 二.正文 1.1题目要求 1.2思路剖析 我们可以创建一个新的单链表&#xff0c;然后通过对原单链表的遍历&#xff0c;将数据不等于val的节点移到新…

MySQL索引(聚簇索引、非聚簇索引)

了解MySQL索引详细&#xff0c;本文只做整理归纳&#xff1a;https://blog.csdn.net/wangfeijiu/article/details/113409719 概念 索引是对数据库表中一列或多列的值进行排序的一种结构&#xff0c;使用索引可快速访问数据库表中的特定信息。 索引分类 主键索引&#xff1a…

微信群发用什么软件最安全?微信群发软件哪个好?微信群发助手软件在哪里?

今天给大家推荐一款我们目前在使用的电脑群发工具掘金小蜜&#xff0c;不仅可以无限多开&#xff0c;方便你同时管理多个账号&#xff0c;群发功能更是十分强大&#xff0c;轻松释放你的双手。 掘金小蜜&#xff08;只支持Win7及以上操作系统&#xff0c;没有推Mac版和手机客户…

【算法入门赛】B. 自助店评分(C++、STL、推荐学习)题解与代码

比赛地址&#xff1a;https://www.starrycoding.com/contest/8 题目描述 在上一场的入门教育赛中&#xff0c;牢 e e e找到了所有自助店的位置&#xff0c;但是他想发现一些“高分好店”&#xff0c;于是他利用爬虫技术从“小众点评APP”中爬取了武汉所有自助店的评分。 评分…

[笔试训练](十八)

目录 052:字符串压缩 053:chika和蜜柑 054:01背包 052:字符串压缩 压缩字符串(一)_牛客题霸_牛客网 (nowcoder.com) 题目&#xff1a; 题解&#xff1a; 双指针模拟 class Solution { public:string compressString(string param) {int nparam.size();string ret;int num…