松鼠乐园 松鼠乐园
  • 注册
  • 登录
  • 首页
  • 快捷入口
    • Vue
    • Tensorflow
    • Springboot
    • 语言类
      • CSS
      • ES5
      • ES6
      • Go
      • Java
      • Javascript
    • 工具类
      • Git
      • 工具推荐
    • 服务器&运维
      • Centos
      • Docker
      • Linux
      • Mac
      • MySQL
      • Nginx
      • Redis
      • Windows
    • 资源类
      • 论文
      • 书籍推荐
      • 后端资源
      • 前端资源
      • html网页模板
      • 代码
    • 性能优化
    • 测试
  • 重大新闻
  • 人工智能
  • 开源项目
  • Vue2.0从零开始
  • 广场
首页 › 人工智能 › 如何使用深度学习构建水果分类器

如何使用深度学习构建水果分类器

迦娜王
1年前人工智能
944 0 0

今天,我们将使用fast.ai库构建一个图像分类器,以对Fruit-360数据集进行分类。(https://www.kaggle.com/moltean/fruits)

步骤:

  • 加载数据和预训练模型。
  • 使用lr_find()查找合适的学习率并使用precomputed = True训练最后一层1-2个epochs。
  • 使用几个epochs的数据增强训练最后一层。
  • 解冻所有层。为较早的层设置较低的学习率并在机器学习模型上执行完整的训练。
  • 再次使用lr_find()。训练全网几个epochs。
  • 使用测试时间增加来改进预测。
  • 通过混淆矩阵理解分类器。

导入所需Python库

fromfrom fastai.transformsfastai. import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *

第1步:加载数据和预训练模型

首先,我们需要绘制数据的分布,Python代码如下:

dataset_path = \'../dataset/kaggle/fruits-360/\'
train_path = dataset_path   \'/train\'
val_path = dataset_path   \'/test\'
 
def get_nrof_images_of_classes(path):
 dic = {}
 class_names = [os.path.basename(x) for x in glob.glob(path   \'/*\')]
 for class_name in class_names:
 dic[class_name] = len(glob.glob(path   \'/\'   class_name   \'/*\'))
 return dic
 
def plot_nrof_images_histogram(path):
 nrof_images_of_train_classes = get_nrof_images_of_classes(path)
 values = nrof_images_of_train_classes.values()
 f = plt.figure(figsize=(20,20))
 plt.bar(range(len(values)), values)
 plt.show()
 
plot_nrof_images_histogram(train_path)
plot_nrof_images_histogram(val_path)
如何使用深度学习构建水果分类器

训练样品分布

如何使用深度学习构建水果分类器

测试样品分布

来自数据的一些样本:

如何使用深度学习构建水果分类器

Apple Braeburn

如何使用深度学习构建水果分类器

Apricot

如何使用深度学习构建水果分类器

Pepino

如何使用深度学习构建水果分类器

Carambula

因为这是一个水果数据集,所以我们想要将数据随机旋转并颠倒以创建良好的增强数据。

设置数据和模型,Python代码如下:

input_size = 224
model = resnet50
tfms = tfms_from_model(model, input_size, aug_tfms=transforms_side_on   [RandomDihedral()])
data = ImageClassifierData.from_paths(path=dataset_path, bs=64, tfms=tfms, val_name=\'test\')
learner = ConvLearner.pretrained(f=model, data=data, precompute=True)

那么什么是precompute = True?

当precompute为True时,库从一开始就计算所有的激活,并为倒数第二层保存计算。因此,当我们训练最后一层时,我们只需要将precompute提供给最后一层,我们不必一直向前和向后计算神经网络的所有层。这节省了很多时间!

如果precompute = True,则Augmentation 将不起作用。

这是因为precompute 需要特定输入才能准确计算该输入的激活。Augmentation 过程将产生大量随机输入,这就是为什么当我们设置precompute = True时它们将被禁用。

第2步:找到学习率并训练最后一层

lrs = learner.lr_find()
learner.sched.plot()
如何使用深度学习构建水果分类器

我们应该使学习率小于最佳学习率。在这里,我们选择学习率= 0.004,并通过precompute 激活首次训练最后一层。

learner.fit(lrs=0.04, n_cycle=1)
如何使用深度学习构建水果分类器

步骤3:关闭precompute 并使用增强数据训练模型

首先,我们检查增强数据,以确保它们有意义。Python代码如下:

def get_augs():
 data = ImageClassifierData.from_paths(dataset_path, bs=2, tfms=tfms, num_workers=1, val_name=\'test\')
 x,_ = next(iter(data.aug_dl))
 return data.trn_ds.denorm(x)[1]
 
ims = np.stack([get_augs() for i in range(6)])
plots(ims, rows=2)
如何使用深度学习构建水果分类器

增强数据(随机旋转)

设置precompute = False并训练最后一层3个epochs

learner.precompute = False
learner.fit(lrs=0.04, n_cycle=3, cycle_len=1)
如何使用深度学习构建水果分类器

第四步:解冻所有层。在我们的模型上执行完整的训练

我们将模型分为三个模块,三个不同的学习速率。我们为早期的块设置了较低的学习率,因为我们不想破坏ImageNet训练的权重(ImageNet是具有大量图像的数据集)。我们只是想稍微改变一下,让它们更拟合我们的数据。

learner.unfreeze()
lr = 0.04
lrs=np.array([lr/500, lr/50, lr])
learner.fit(lrs=lrs, n_cycle=2, cycle_len=2, cycle_mult=2)
如何使用深度学习构建水果分类器

步骤5:再次找到学习率并更多地训练几个epochs

如何使用深度学习构建水果分类器

learner.fit(lrs=0.0005, n_cycle=2, cycle_len=2, cycle_mult=2)
如何使用深度学习构建水果分类器

步骤6:使用测试时间增加(TTA)来改进预测。

TTA生成测试数据的增强数据。通过得到输出分数的平均值,我们会得到稍微好一点的结果。

preds_tta, y = learner.TTA()
preds = np.argmax(np.mean(preds_tta, 0), axis=1)
acc = len(np.where((preds==data.val_ds.y)==True)[0]) / len(data.val_ds.y); 
acc

acc = 0.9904878576061108

第7步:使用混淆矩阵理解分类器

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y, preds)
plot_confusion_matrix(cm, data.classes, figsize=(35,35))
如何使用深度学习构建水果分类器

Confusion matrix

从混淆矩阵中,我们可以看到一些水果通常被错误分类。例如:Cherry 1 vs Cherry 2,Pepino vs Grape White,机器学习模型还需要进一步调整。

如何使用深度学习构建水果分类器

Cherry 1 vs Cherry 2

如何使用深度学习构建水果分类器

Apple Braeburn vs Apple Golden 2

你可以在这里看到完整的代码:

https://github.com/intheroom/Jupyter-Notebook/blob/master/Easy Image Classifier.ipynb

0
谷歌实时端到端双目系统深度学习网络stereonet
上一篇
Java程序员必备的15个框架
下一篇
评论 (0)

请登录以参与评论。

现在登录
聚合文章
在Gitee收获近 5k Star,更新后的Vue版RuoYi有哪些新变化?
2月前
vue3.x reactive、effect、computed、watch依赖关系及实现原理
2月前
Vue 3 新特性:在 Composition API 中使用 CSS Modules
2月前
新手必看的前端项目去中心化和模块化思想
2月前
标签
AI AI项目 css docker Drone Elaticsearch es5 es6 Geometry Go gru java Javascript jenkins lstm mysql mysql优化 mysql地理位置索引 mysql索引 mysql规范 mysql设计 mysql配置文件 mysql面试题 mysql高可用 nginx Redis redis性能 rnn SpringBoot Tensorflow tensorflow2.0 UI设计 vue vue3.0 vue原理 whistle ZooKeeper 开源项目 抓包工具 日志输出 机器学习 深度学习 神经网络 论文 面试题
相关文章
我收集了12款自动生成器,无聊人士自娱自乐专用
输入一张图,就能让二次元老婆动起来,宛如3D:这全是为了科学啊
使用ONNX+TensorRT部署人脸检测和关键点250fps
基于 Keras 的烟火检测
松鼠乐园

资源整合,创造价值

小伙伴
墨魇博客 无同创意
目录
重大新闻 Centos CSS Docker ES5 ES6 Go Java Javascript Linux Mac MySQL Nginx Redis Springboot Tensorflow Vue Vue2.x从零开始 Windows 书籍推荐 人工智能 前端资源 后端资源 壁纸 开源项目 测试 论文
Copyright © 2018-2021 松鼠乐园. Designed by nicetheme. 浙ICP备15039601号-4
  • 重大新闻
  • Centos
  • CSS
  • Docker
  • ES5
  • ES6
  • Go
  • Java
  • Javascript
  • Linux
  • Mac
  • MySQL
  • Nginx
  • Redis
  • Springboot
  • Tensorflow
  • Vue
  • Vue2.x从零开始
  • Windows
  • 书籍推荐
  • 人工智能
  • 前端资源
  • 后端资源
  • 壁纸
  • 开源项目
  • 测试
  • 论文
热门搜索
  • jetson nano
  • vue
  • java
  • mysql
  • 人工智能
  • 人脸识别
迦娜王
坚持才有希望
1224 文章
33 评论
235 喜欢
  • 0
  • 0
  • Top