松鼠乐园 松鼠乐园
  • 注册
  • 登录
  • 首页
  • 快捷入口
    • Vue
    • Tensorflow
    • Springboot
    • 语言类
      • CSS
      • ES5
      • ES6
      • Go
      • Java
      • Javascript
    • 工具类
      • Git
      • 工具推荐
    • 服务器&运维
      • Centos
      • Docker
      • Linux
      • Mac
      • MySQL
      • Nginx
      • Redis
      • Windows
    • 资源类
      • 论文
      • 书籍推荐
      • 后端资源
      • 前端资源
      • html网页模板
      • 代码
    • 性能优化
    • 测试
  • 重大新闻
  • 人工智能
  • 开源项目
  • Vue2.0从零开始
  • 广场
首页 › 人工智能 › 深度学习任务面临非平衡数据问题?试试这个简单方法

深度学习任务面临非平衡数据问题?试试这个简单方法

迦娜王
1年前人工智能
264 0 0

对于数据科学或机器学习研究者而言,当解决任何机器学习问题时,可能面临的最大问题之一就是训练数据不平衡的问题。本文将尝试使用图像分类问题来揭示训练数据中不平衡类别的奥秘。

深度学习任务面临非平衡数据问题?试试这个简单方法

数据不平衡问题是什么?

在一个分类问题中,当你想要预测一个或多个类中的样本数量极少时,可能会遇到数据中类不平衡的问题,即部分类的样本数量远远大于其它类中的样本数量。

例子

  • 欺诈预测(真实交易的欺诈数量要低得多);
  • 自然灾害预测(坏事件发生的频率将远远低于好事);
  • 识别图像分类中的恶性肿瘤(具有肿瘤的图像将比训练样本内的无肿瘤的图像少得多);

为什么这会是个问题?

不平衡课程造成问题主要是由于以下两个原因:

  • 由于模型/算法从来没有充分地查看全部类别信息,对于实时不平衡的类别没有得到最优化的结果;
  • 由于少数样本类的观察次数极少,这会产生一个验证或测试样本的问题,即很难在类中进行表示;

解决这个问题的方法有哪些?

解决这个问题的方法主要有三种,三种各有各自的优缺点:

  • 下采样(Undersampling):随机删除具有足够观察多样本的类,以便数据中类的数量比较平衡。虽然这种方法非常简单,但很有可能删除的数据中可能包含有关预测的重要信息。
  • 过采样(Oversampling):对于不平衡类(样本数少的类),随机地增加观测样本的数量,这些观测样本只是现有样本的副本,虽然增加了样本的数量,但过采样可能导致训练数据过拟合。
  • 合成取样(SMOT):该技术要求综合地制造不平衡类的样本,类似于使用最近邻分类。问题是当观察的数目是极其罕见的类时不知道怎么做。

    尽管每种方法都有各自的优点,但没有什么固定的使用方式,需要根据实际问题不断自己尝试。现在将使用深度学习特定的图像分类问题来详细研究这个问题。

图像分类中的不平衡类

在本节中,将分析一个图像分类问题(其中存在不平衡类问题),然后使用一种简单有效的技术来解决它。

问题:在kaggle上选择了“驼背鲸识别挑战”任务,期望解决不平衡类别的挑战(理想情况下,所分类的鲸鱼数量少于未分类的鲸类)。

Kagele上任务说明:在这场比赛中,面临的挑战是要建立一个算法来识别图像中的鲸鱼种类。将分析Happy Whale数据库(包含25,000多张图像),这些数据来自研究机构和公共贡献者。通过竞赛,你将有助于为全球海洋哺乳动物种群动态开启丰富的理解领域。

查看Happy Whale数据集

由于这是一个多标签图像分类问题,首先想要检查数据是如何在类中分布的。

深度学习任务面临非平衡数据问题?试试这个简单方法

上图表明,在4251张训练图像中,每个类只有一张图像的超过了2000张。还有一些类只有2~5张图像。可见这是一个严重的不平衡类问题。我们不能期望深度学习模型每个类别仅使用一张图像进行训练。这也会产生一个问题,即如何在训练和验证样本之间创建一个分界线,理想情况下希望每个类都在训练样本和验证样本中都有表示。

接下来应该做什么?

本文考虑了两个特别的选项:

  • 选项1:对训练样本进行严格的数据增强(只需要针对特定类的数据增强,单这可能无法完全解决本文的问题)。
  • 选项2:类似于之前提到的过采样技术。只是使用不同的图像增强技术将不平衡类的图像复制到训练数据中15次。

    在开始使用选项2处理数据之前,可以从训练样本中查看少量图像。

深度学习任务面临非平衡数据问题?试试这个简单方法

从图像中可以看到,图像是特定于鲸鱼的尾巴,因此,识别将可能与图像的方向有关。同时注意到数据中有很多图像是特定的黑白或只有R/G/B通道。

根据这些观察结果,使用以下代码对训练样本中不平衡类的图像进行小幅改动并保存:

import osfrom PIL import Imagefrom PIL import ImageFilter
filelist = train[\'Image\'].loc[(train[\'cnt_freq\']<10)].tolist()for count in range(0,2): 
 for imagefile in filelist:
 os.chdir(\'/home/paperspace/fastai/courses/dl1/data/humpback/train\')
 im=Image.open(imagefile)
 im=im.convert("RGB")
 r,g,b=im.split()
 r=r.convert("RGB")
 g=g.convert("RGB")
 b=b.convert("RGB")
 im_blur=im.filter(ImageFilter.GaussianBlur)
 im_unsharp=im.filter(ImageFilter.UnsharpMask)
 os.chdir(\'/home/paperspace/fastai/courses/dl1/data/humpback/copy\')
 r.save(str(count) \'r_\' imagefile)
 g.save(str(count) \'g_\' imagefile)
 b.save(str(count) \'b_\' imagefile)
 im_blur.save(str(count) \'bl_\' imagefile)
 im_unsharp.save(str(count) \'un_\' imagefile)

以上代码对不平衡类中的每张图像(频率小于10)都进行如下处理:

  • 将每张图像的增强副本保存为R / B&G ;
  • 保存每张图像的增强副本;
  • 保存每张图像未锐化的增强副本;

    在上面的代码中可以看到,使用pillow库来严格执行此练习,现在已经为所有不平衡的类分配了至少10个样本。接下来进行训练。

图像增强:只想确保模型能够获得鲸鱼fluke的详细视图。为此,将缩放合并成图像增强。

深度学习任务面临非平衡数据问题?试试这个简单方法

学习率设定:从图中可以看到,将学习率定为0.01时效果最好。

深度学习任务面临非平衡数据问题?试试这个简单方法

使用Resnet50模型(第一层参数不变)进行了很少的迭代训练就能取得很好的效果,这是由于imagenet数据库中也有鲸鱼图像。

epoch trn_loss val_loss accuracy 
 0 1.827677 0.492113 0.895976 
 1 0.93804 0.188566 0.964128 
 2 0.844708 0.175866 0.967555 
 3 0.571255 0.126632 0.977614 
 4 0.458565 0.116253 0.979991 
 5 0.410907 0.113607 0.980544 
 6 0.42319 0.109893 0.981097

测试数据集上效果如何?

在kaggle排行榜上可以看到模型在测试集上的效果,本文提出的解决方案在本次比赛中排名34,平均精度均值(MAP)为0.41928。

深度学习任务面临非平衡数据问题?试试这个简单方法

结论

有时候,最简单的方法是最合乎逻辑的(如果你没有更多的数据,只需要复制现有的数据,并有轻微的变化即可),也是最有效的。

作者信息

Shubrashankh Chatterjee,深度学习和数据科学爱好者

本文由阿里云云栖社区组织翻译。

文章原标题《Deep Learning Tips and Tricks》,译者:海棠,审校:Uncle_LLD。

0
深度学习到顶,AI寒冬将至!
上一篇
高德技术团队:深度学习在导航速度预测中的探索与实践
下一篇
评论 (0)

请登录以参与评论。

现在登录
聚合文章
在Gitee收获近 5k Star,更新后的Vue版RuoYi有哪些新变化?
2月前
vue3.x reactive、effect、computed、watch依赖关系及实现原理
2月前
Vue 3 新特性:在 Composition API 中使用 CSS Modules
2月前
新手必看的前端项目去中心化和模块化思想
2月前
标签
AI AI项目 css docker Drone Elaticsearch es5 es6 Geometry Go gru java Javascript jenkins lstm mysql mysql优化 mysql地理位置索引 mysql索引 mysql规范 mysql设计 mysql配置文件 mysql面试题 mysql高可用 nginx Redis redis性能 rnn SpringBoot Tensorflow tensorflow2.0 UI设计 vue vue3.0 vue原理 whistle ZooKeeper 开源项目 抓包工具 日志输出 机器学习 深度学习 神经网络 论文 面试题
相关文章
我收集了12款自动生成器,无聊人士自娱自乐专用
输入一张图,就能让二次元老婆动起来,宛如3D:这全是为了科学啊
使用ONNX+TensorRT部署人脸检测和关键点250fps
基于 Keras 的烟火检测
松鼠乐园

资源整合,创造价值

小伙伴
墨魇博客 无同创意
目录
重大新闻 Centos CSS Docker ES5 ES6 Go Java Javascript Linux Mac MySQL Nginx Redis Springboot Tensorflow Vue Vue2.x从零开始 Windows 书籍推荐 人工智能 前端资源 后端资源 壁纸 开源项目 测试 论文
Copyright © 2018-2021 松鼠乐园. Designed by nicetheme. 浙ICP备15039601号-4
  • 重大新闻
  • Centos
  • CSS
  • Docker
  • ES5
  • ES6
  • Go
  • Java
  • Javascript
  • Linux
  • Mac
  • MySQL
  • Nginx
  • Redis
  • Springboot
  • Tensorflow
  • Vue
  • Vue2.x从零开始
  • Windows
  • 书籍推荐
  • 人工智能
  • 前端资源
  • 后端资源
  • 壁纸
  • 开源项目
  • 测试
  • 论文
热门搜索
  • jetson nano
  • vue
  • java
  • mysql
  • 人工智能
  • 人脸识别
迦娜王
坚持才有希望
1224 文章
33 评论
235 喜欢
  • 0
  • 0
  • Top