不写R包的分析师不是好全栈

mxnet例子,图像识别

    技术学习

介绍下MXnet里面图像识别的一个小例子,原文在这里

需要下载预先训练好的Inception-BatchNorm network模型,参考文献见[1]


基本代码


主要完成的任务:



  1. 加载相关包

  2. 读取图片,并转换为2242243的RGB矩阵

  3. 放入训练模型中

  4. 得到预测结果


预先训练的模型包括1000个类,(y是1000个分类变量)分布于各种不同的动物


将原文件解压后运行示例代码:


library(mxnet)
library(imager)
# 载入模型
model = mx.model.load("Inception/Inception_BN", iteration=39)
# 载入mean image
mean.img = as.array(mx.nd.load("Inception/mean_224.nd")[["mean_img"]])

使用imager自带的鹦鹉图片来进行测试,数据读取,展示:


im <- load.image(system.file("extdata/parrots.png", package="imager"))
plot(im)


将输入图像转换为224*224的图像,作为模型输入,这部分不属于mxnet的范畴,可以看做是数据规范的过程


preproc.image <-function(im, mean.image) {
# crop the image
shape <- dim(im)
short.edge <- min(shape[1:2])
yy <- floor((shape[1] - short.edge) / 2) + 1
yend <- yy + short.edge - 1
xx <- floor((shape[2] - short.edge) / 2) + 1
xend <- xx + short.edge - 1
croped <- im[yy:yend, xx:xend,,]
# resize to 224 x 224, needed by input of the model.
resized <- resize(croped, 224, 224)
# convert to array (x, y, channel)
arr <- as.array(resized)
dim(arr) = c(224, 224, 3)
# substract the mean
normed <- arr - mean.img
# Reshape to format needed by mxnet (width, height, channel, num)
dim(normed) <- c(224, 224, 3, 1)
return(normed)
}

normed <- preproc.image(im, mean.img)


模型的预测


预测(好棒,我就喜欢这样的预测风格):


prob <- predict(model, X=normed)

预测完毕后查看下概率最大的那个类的名称:


max.idx <- max.col(t(prob))
synsets <- readLines("Inception/synset.txt")
print(paste0("Predicted Top-class: ", synsets[[max.idx]]))

## [1] "Predicted Top-class: n01818515 macaw"

恩,macaw的意思是金刚鹦鹉




其他的实验


我另外又从百度下载下来一个猫和两个狗的图片做了一点测试:


先把以上的功能打包成一个函数:


showPic = function(input){
cat <- load.image(system.file(input, package="imager"))
plot(cat)
normed <- preproc.image(cat, mean.img)
prob <- predict(model, X=normed)
max.idx <- max.col(t(prob))
print(paste0("Predicted Top-class: ", synsets[[max.idx]]))
}

结果解读下:



  1. 我第一个输入的是布偶猫,不是暹罗猫,二者很像~

  2. 第二输入的是二哈,正确识别,husky

  3. 随便搞了个拉布拉多输入,恩,也是OK的


showPic("extdata/cat.jpg")


## [1] "Predicted Top-class: n02123597 Siamese cat, Siamese"

showPic("extdata/husky.jpg")


## [1] "Predicted Top-class: n02110185 Siberian husky"

showPic("extdata/labuladuo.jpg")


## [1] "Predicted Top-class: n02099712 Labrador retriever"

(本来我以为能识别是猫是狗就够厉害了,发现这么渣的配置下,连品种都能预测到一部分,好点赞~)


[1] Ioffe, Sergey, and Christian Szegedy. “Batch normalization: Accelerating deep network training by reducing internal covariate shift.” arXiv preprint arXiv:1502.03167 (2015).



page PV:  ・  site PV:  ・  site UV: