基于torch的图像识别训练策略与常用模块

数据预处理部分:

  • 数据增强:torchvision中transforms模块自带功能,比较实用
  • 数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可
  • DataLoader模块直接读取batch数据

网络模块设置:

  • 加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习
  • 需要注意的是别人训练好的任务跟咱们的可不是完全一样,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务
  • 训练时可以全部重头训练,也可以只训练最后咱们任务的层,因为前几层都是做特征提取的,本质任务目标是一致的

网络模型保存与测试

  • 模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
  • 读取模型进行实际测试
data_transforms = {
    'train': 
        transforms.Compose([
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ]),
    'valid': 
        transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

选择性的权重更新

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

自定义修改模型输出层,以resnet18为例

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    
    model_ft = models.resnet18(pretrained=use_pretrained)
    set_parameter_requires_grad(model_ft, feature_extract)
    
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 102)#类别数自己根据自己任务来
                            
    input_size = 64#输入大小根据自己配置来

    return model_ft, input_size

训练权重 选择

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

#GPU还是CPU计算
model_ft = model_ft.to(device)

# 模型保存,名字自己起
filename='checkpoint.pth'

# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

基本训练代码

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,filename='best.pt'):
    #咱们要算时间的
    since = time.time()
    #也要记录最好的那一次
    best_acc = 0
    #模型也得放到你的CPU或者GPU
    model.to(device)
    #训练过程中打印一堆损失和指标
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    #学习率
    LRs = [optimizer.param_groups[0]['lr']]
    #最好的那次模型,后续会变的,先初始化
    best_model_wts = copy.deepcopy(model.state_dict())
    #一个个epoch来遍历
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 训练和验证
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # 训练
            else:
                model.eval()   # 验证

            running_loss = 0.0
            running_corrects = 0

            # 把数据都取个遍
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)#放到你的CPU或GPU
                labels = labels.to(device)

                # 清零
                optimizer.zero_grad()
                # 只有训练的时候计算和更新梯度
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                # 训练阶段更新权重
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # 计算损失
                running_loss += loss.item() * inputs.size(0)#0表示batch那个维度
                running_corrects += torch.sum(preds == labels.data)#预测结果最大的和真实值是否一致
                
            
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)#算平均
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            time_elapsed = time.time() - since#一个epoch我浪费了多少时间
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            

            # 得到最好那次的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                  'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重
                  'best_acc': best_acc,
                  'optimizer' : optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                #scheduler.step(epoch_loss)#学习率衰减
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
        
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()
        scheduler.step()#学习率衰减

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 训练完后用最好的一次当做模型最终的结果,等着一会测试
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

调用训练

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20)
def im_convert(tensor):
    """ 展示数据"""
    
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)
    return image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/552404.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

两阶段提交进阶

两阶段提交之进阶 上一节我们讲了,两阶段提交逻辑上的表现,其实较为肤浅,并且偏向理论,可能大家都能看懂,但是如果放入实际的mysql应用中并联系事务和日志进行分析,又会怎么样呢? 这次就专门分…

Unity类银河恶魔城学习记录13-1 p142 Save system源代码

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释,可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili FileDataHandler.cs using System; using System.IO; using UnityEngine; p…

软考133-上午题-【软件工程】-软件项目估算

一、COCOMO 估算模型 COCOMO 模型是一种精确的、易于使用的成本估算模型。 COCOMO 模型按其详细程度分为:基本 COCOMO 模型、中级 COCOMO 模型和详细 COCOMO 模型。 1)基本 COCOMO 模型 基本 COCOMO 模型是一个静态单变量模型,用于对整个软…

内衣裤洗衣机如何选购?掌握这六个挑选技巧,轻松选购!

这两年内衣裤洗衣机可以称得上较火的小电器,小小的身躯却有大大的能力,一键可以同时启动洗、漂、脱三种全自动为一体化功能,在多功能和性能的提升上,还可以解放我们双手的同时将衣物给清洗干净,让越来越多小伙伴选择一…

node基础 第二篇

01 ffmpeg开源跨平台多媒体处理工具,处理音视频,剪辑,合并,转码等 FFmpeg 的主要功能和特性:1.格式转换:FFmpeg 可以将一个媒体文件从一种格式转换为另一种格式,支持几乎所有常见的音频和视频格式,包括 MP…

Node Version Manager(nvm):轻松管理 Node.js 版本的利器

文章目录 前言一、名词解释1、node.js是什么?2、nvm是什么? 二、安装1.在 Linux/macOS 上安装2.在 Windows 上安装 二、使用1.查看可安装的node版本2.安装node3. 查看已安装node4.切换node版本5.其它 总结 前言 Node.js 是现代 Web 开发中不可或缺的一部…

docker-compose 安装MongoDB续创建用户及赋权

文章目录 1. 问题描述2. 分析2.1 admin2.2 config2.3 local 3. 如何连接3.解决 1. 问题描述 在这一篇使用docker-compose创建MongoDB环境的笔记里,我们创建了数据库,但是似乎没有办法使用如Robo 3T这样的工具去连接数据库。连接的时候会返回这样的错误&…

c语言,单链表的实现----------有全代码!!!!

1.单链表的定义和结构 单链表是一种链式的数据结构,它用一组不连续的储存单元存反线性表中的数据元素。链表中的数据是以节点的形式来表示的,节点和节点之间相互连接 一般来说节点有两部分组成 1.数据域 :数据域用来存储各种类型的数据&…

基于SpringBoot+Vue的疾病防控系统设计与实现(源码+文档+包运行)

一.系统概述 在如今社会上,关于信息上面的处理,没有任何一个企业或者个人会忽视,如何让信息急速传递,并且归档储存查询,采用之前的纸张记录模式已经不符合当前使用要求了。所以,对疾病防控信息管理的提升&a…

windows 如何安装 perl ?

链接:https://strawberryperl.com/ 我们选择安装 “草莓 perl” 下载后根据引导安装就行了

node.jd版本降级/升级

第一步.先清空本地安装的node.js版本 按健winR弹出窗口,键盘输入cmd,然后敲回车(或者鼠标直接点击电脑桌面最左下角的win窗口图标弹出,输入cmd再点击回车键) 进入命令控制行窗口,输入where node,查看本地…

双指针的引入和深入思考(持续更新中)

目录 1.引入双指针 2.使用场景 3.例题引入 1.引入双指针 当我们需要维护某个区间性质的或者是求满足某些性质的区间的长度时,对于一个区间是由左右端点的,我们有简单的枚举左右端点的O()的时间的做法,当时在大多数题目中是不可行的&#…

DataX案例,MongoDB数据导入HDFS与MySQL

【尚硅谷】Alibaba开源数据同步工具DataX技术教程_哔哩哔哩_bilibili 目录 1、MongoDB 1.1、MongoDB介绍 1.2、MongoDB基本概念解析 1.3、MongoDB中的数据存储结构 1.4、MongoDB启动服务 1.5、MongoDB小案例 2、DataX导入导出案例 2.1、读取MongoDB的数据导入到HDFS 2…

论文笔记:Does Writing with Language Models Reduce Content Diversity?

iclr 2024 reviewer评分 566 1 intro 大模型正在迅速改变人们创造内容的方式 虽然基于LLM的写作助手有可能提高写作质量并增加作者的生产力,但它们也引入了算法单一文化——>论文旨在评估与LLM一起写作是否无意中降低了内容的多样性论文设计了一个控制实验&…

Kubernetes部署应用利器Helm详解

文章目录 一、helm概述&安装1.为什么需要Helm2.Helm介绍3.Helm架构4.部署Helm客户端5.Helm基本使用5.1 创建Chart示例 二、Helm 应用部署、升级1.创建项目(chat所需目录、文件)2.创建/拷贝项目的yaml文件到templates目录下3.使用Helm进行部署项目4.H…

第十五届蓝桥杯复盘python大学A组——试题B 召唤数学精灵

按照正常思路解决,由于累乘消耗大量时间,因此这不是一个明智的解决方案。 这段代码执行速度非常慢的原因在于它试图计算非常大的数的阶乘(累乘),并且对于每一个i的值都执行这个计算。阶乘的增长是极其迅速的&#xff…

49.HarmonyOS鸿蒙系统 App(ArkUI)Tab导航组件的使用

HarmonyOS鸿蒙系统 App(ArkUI)Tab导航组件的使用 图片显示 Row() {Image($r(app.media.leaf)).height(100).width(100)Image($r(app.media.icon)).height(100).width(100) } 左侧导航 import prompt from ohos.prompt; import promptAction from ohos.promptAction; Entry C…

vue2知识点1 ———— (vue指令,vue的响应式基础)

vue2的知识点,更多前端知识在主页,还有其他知识会持续更新 Vue 指令 Vue指令是Vue.js中的一个重要概念,用于向DOM元素添加特定行为或功能。Vue指令以v-开头,例如v-bind、v-if、v-for等。 v-bind 动态绑定属性 用法&#xff1a…

windows ubuntu 子系统:肿瘤全外篇,2. fq 数据质控,比对。

首先我们先下载一组全外显子测序数据。nabi sra库,随机找了一个。 来自受试者“16177_CCPM_1300019”(SRR28391647, SRR28398576)的样本“16177_CCPM_1300019_BB5”的基因组DNA配对端文库“0369547849_Illumina_P5-Popal_P7-Hefel”的Illumina随机外显子测序 下载下…

SGI_STL空间配置器源码剖析(一)总览

SGI 全称为 Silicon Graphics [Computer System] Inc. 硅图[计算机系统] 公司,SGI_STL是SGI实现的C的标准模板库。 SGI STL的空间配置器包括一级和二级两种。 一级空间配置器allocator采用malloc和free来管理内存,这与C标准库中提供的allocator是相似的…
最新文章