knn算法简介


  • 行业应用:比如文字识别、面部识别、预测某人是否喜欢推荐电影
        基因模式识别:比如用于检测某种疾病,更适合于稀有事件的分类问题(客户流失识别)
  • 应用场合:通常最近邻分类器适用特征与目标类之间的关系比较复杂的数字类型或者二者关系难以理解,但是相似类间的特征都是相似的

特点:
1. 简单有效,对数据分布没有假设,数据训练也很快。

  1. 但是他没有模型输出,因此限制了对特征的理解。

  2. 分类阶段比较慢。

  3. 需要标准化(nominal)特征以及缺少数据需要预先处理

优点 缺点
简单且有效 不产生模型.在发现特彻之间关系上的能力有限
对数据的分布没有要求 分类阶段很慢, 需要大量的内存
训练阶段很快 名义变量(特征变量)和缺失数据需要额外处理

k的取值:
1. k通常在3~10之间直接取值(看分析者的心情)
2. 可采用一般方法:k等于训练数据个数的平方根(15个数据,k可能取4)

1. 导入数据

#导入数据
# import the CSV file
wbcd <- read.csv("wisc_bc_data.csv", stringsAsFactors = FALSE)

# 查看一下数据结构,发现除了要预测的变量diagnosis是字符型变量其余全是数字型变量
str(wbcd)
## 'data.frame':    569 obs. of  32 variables:
##  $ id               : int  87139402 8910251 905520 868871 9012568 906539 925291 87880 862989 89827 ...
##  $ diagnosis        : chr  "B" "B" "B" "B" ...
##  $ radius_mean      : num  12.3 10.6 11 11.3 15.2 ...
##  $ texture_mean     : num  12.4 18.9 16.8 13.4 13.2 ...
##  $ perimeter_mean   : num  78.8 69.3 70.9 73 97.7 ...
##  $ area_mean        : num  464 346 373 385 712 ...
##  $ smoothness_mean  : num  0.1028 0.0969 0.1077 0.1164 0.0796 ...
##  $ compactness_mean : num  0.0698 0.1147 0.078 0.1136 0.0693 ...
##  $ concavity_mean   : num  0.0399 0.0639 0.0305 0.0464 0.0339 ...
##  $ points_mean      : num  0.037 0.0264 0.0248 0.048 0.0266 ...
##  $ symmetry_mean    : num  0.196 0.192 0.171 0.177 0.172 ...
##  $ dimension_mean   : num  0.0595 0.0649 0.0634 0.0607 0.0554 ...
##  $ radius_se        : num  0.236 0.451 0.197 0.338 0.178 ...
##  $ texture_se       : num  0.666 1.197 1.387 1.343 0.412 ...
##  $ perimeter_se     : num  1.67 3.43 1.34 1.85 1.34 ...
##  $ area_se          : num  17.4 27.1 13.5 26.3 17.7 ...
##  $ smoothness_se    : num  0.00805 0.00747 0.00516 0.01127 0.00501 ...
##  $ compactness_se   : num  0.0118 0.03581 0.00936 0.03498 0.01485 ...
##  $ concavity_se     : num  0.0168 0.0335 0.0106 0.0219 0.0155 ...
##  $ points_se        : num  0.01241 0.01365 0.00748 0.01965 0.00915 ...
##  $ symmetry_se      : num  0.0192 0.035 0.0172 0.0158 0.0165 ...
##  $ dimension_se     : num  0.00225 0.00332 0.0022 0.00344 0.00177 ...
##  $ radius_worst     : num  13.5 11.9 12.4 11.9 16.2 ...
##  $ texture_worst    : num  15.6 22.9 26.4 15.8 15.7 ...
##  $ perimeter_worst  : num  87 78.3 79.9 76.5 104.5 ...
##  $ area_worst       : num  549 425 471 434 819 ...
##  $ smoothness_worst : num  0.139 0.121 0.137 0.137 0.113 ...
##  $ compactness_worst: num  0.127 0.252 0.148 0.182 0.174 ...
##  $ concavity_worst  : num  0.1242 0.1916 0.1067 0.0867 0.1362 ...
##  $ points_worst     : num  0.0939 0.0793 0.0743 0.0861 0.0818 ...
##  $ symmetry_worst   : num  0.283 0.294 0.3 0.21 0.249 ...
##  $ dimension_worst  : num  0.0677 0.0759 0.0788 0.0678 0.0677 ...
dim(wbcd)
## [1] 569  32

乳腺癌数据包括 569 例细胞活检案例, 每个案例有32 个特征。一个特征是识别号码(id变量),一 个特征是癌症诊断结果(diagnosis变量), 其他 30 个特征是数值型的实验室测挝结果。癌症诊断结果用编码“M”表示恶性,用编码“B”表示良性。

2. 数据预处理————(因子变量转化为数字变量)

#第一个名为ID的整形变量(作用起唯一性,不能提供有用的信息)
wbcd <- wbcd[-1]

#对目标属性重新编码为因子类型
wbcd$diagnosis <- factor(wbcd$diagnosis, levels = c("B", "M"),
                         labels = c("良性B", "恶性M"))


# 查看变量diagnosis目标属性的结果数目
table(wbcd$diagnosis)
## 
## 良性B 恶性M 
##   357   212
#计算变量diagnosis目标属性的占比情况
round(prop.table(table(wbcd$diagnosis)) * 100, digits = 1)
## 
## 良性B 恶性M 
##  62.7  37.3
#分析其余30个特征都是数字型变量,因此不需要进一步处理
#但是仔细观察每个变量之间存在数量级差异,则需要标准化——数据处理

3.数据预处理————标准化处理(max-min)和划分训练集以及测试集

normalize <- function(x) {
  return ((x - min(x)) / (max(x) - min(x)))
}
wbcd_n <- as.data.frame(lapply(wbcd[2:31], normalize))
##由于排列随机性,所以我们之间选取则可
wbcd_train <- wbcd_n[1:469, ]
wbcd_test <- wbcd_n[470:569, ]
##训练数据和测试数据的目标变量(分类结果先行保留)
wbcd_train_labels <- wbcd[1:469, 1]
wbcd_test_labels <- wbcd[470:569, 1]

4. knn算法实现(class包)

创建分类器并进行预测:
p <- knn(train, test, class, k)
            train: 一个包含数值型训练数据的数据框
            test:   一个包含数值型测试数据的数据框
            class :  包含训练数据每一行分类的一个因子向量
             k:    标识最近邻数目的一个整数

该函数返回一个因子向量,  该向量含有测试数据框中每一行的预测分类。
例子:
wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test,
                      cl = wbcd_train_labels, k=3)

建议k用奇数,这样会减少各个类别票数相等这一情况的发生

library(class)
wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test,
                      cl = wbcd_train_labels, k=21)

5. 模型性能的评估(gmodels包)

library(gmodels)

# 创建两个识别向量的交叉表(类似table),prop.chisq=FALSE讲话从输出中除去不需要的卡方值
CrossTable(x = wbcd_test_labels, y = wbcd_test_pred,
           prop.chisq=FALSE)
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  100 
## 
##  
##                  | wbcd_test_pred 
## wbcd_test_labels |     良性B |     恶性M | Row Total | 
## -----------------|-----------|-----------|-----------|
##            良性B |        61 |         0 |        61 | 
##                  |     1.000 |     0.000 |     0.610 | 
##                  |     0.968 |     0.000 |           | 
##                  |     0.610 |     0.000 |           | 
## -----------------|-----------|-----------|-----------|
##            恶性M |         2 |        37 |        39 | 
##                  |     0.051 |     0.949 |     0.390 | 
##                  |     0.032 |     1.000 |           | 
##                  |     0.020 |     0.370 |           | 
## -----------------|-----------|-----------|-----------|
##     Column Total |        63 |        37 |       100 | 
##                  |     0.630 |     0.370 |           | 
## -----------------|-----------|-----------|-----------|
## 
## 
library(caret)
confusionMatrix(wbcd_test_labels,wbcd_test_pred,positive = "良性B")#confusionMatrix(真实值,预测值,positive =“  ” 阳性积极的为什么,根据不同模型数据而不同
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction 良性B 恶性M
##      良性B    61     0
##      恶性M     2    37
##                                           
##                Accuracy : 0.98            
##                  95% CI : (0.9296, 0.9976)
##     No Information Rate : 0.63            
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.9576          
##                                           
##  Mcnemar's Test P-Value : 0.4795          
##                                           
##             Sensitivity : 0.9683          
##             Specificity : 1.0000          
##          Pos Pred Value : 1.0000          
##          Neg Pred Value : 0.9487          
##              Prevalence : 0.6300          
##          Detection Rate : 0.6100          
##    Detection Prevalence : 0.6100          
##       Balanced Accuracy : 0.9841          
##                                           
##        'Positive' Class : 良性B           
## 

6. 改善模型的性能

  • 可尝试 用不同的标准化
  • 可尝试用不同的k值
  • 可尝试10折交叉重复验证(随机)
  • 。。。。。。

7. 分析最终结果

按上述5 来分析,只有两个分析错了,有可能是病人的原因。。。。。。。。当然原因有很多种,需要分析者对这个结果有充分的解释能力以及判断能力.

表格中单元格的百分比表示落在4个分类里的值所占的比例。在左上角的单元格(标记为TN)中,是真阴性(True Negative)的结果。100个值中有61个值标识肿块是良性的,而kNN算法也正确地把它们标识为良性的。在右下角的单元格(标记为TP)中,显示的是真阳性(True Positive)的结果,这里表示的是分类器和临床确定的标签一致认为肿块是恶性的情形。100个预测值中有37个是真阳性(True Positive)的。

落在另一条对角线上的单元格包含了kNN算法与真实标签不一致的案例计数。位于左下角FN单元格的2个案例是假阴性(False Negative)的结果。在这种情况下,预测的值是良性的,但肿瘤实际上是恶性的。这个方向上的错误可能会产生极其高昂的代价,因为它们可能导致一位病人认为自己没有癌症,而实际上这种疾病可能会继续蔓延。如果右上角标记为FP的单元格里有值.它包含的是假阳性(False Positive)的结果。当模型把肿块标识为恶性时而事实上它是良性时就会产生这里的值。尽管这类错误没有假阴性(False Negative)的结果那么危险,但这类错误也应该避免,因为它们可能会导致医疗系统的额外财政负担,或者病人的额外压力,毕竟这需要提供额外的检查或者治疗。

一共有2%,即根据kNN算法,100个肿块中.有2个是被错误分类的。虽然对于仅用几行的R代码,就得到98%的准确度似乎令入印象深刻,但是我们可以尝试一些其他的模型迭代方法来看看我们是否可以提高性能并减少错误分类值的数量,特别当错误是危险的假阴性(False Negative)结果时




附件

knn实现方法

  1. class包 —- knn
  2. caret包 —- knn3
  3. Rweka包 —- IBK
library(RWeka)

Rweka_knn=IBk(Species~.,data=iris)#,control=Weka_control(k=21,x=TRUE))
table(predict(Rweka_knn,iris[1:4]),iris$Species)#不用交叉验证,直接预测
##             
##              setosa versicolor virginica
##   setosa         50          0         0
##   versicolor      0         50         0
##   virginica       0          0        50
#上面knn中的k好像自己选取1,那么我们需要自动选取怎么办呢?用Weka_control()调参
#如下,自动选取1:k=20里面最合适的参数
Rweka_knns=IBk(Species~.,data=iris,control = Weka_control(K = 20,X = TRUE))#注意k、x的大小写,这里都是大写
table(predict(Rweka_knns,iris[1:4]),iris$Species)#不用交叉验证,直接预测
##             
##              setosa versicolor virginica
##   setosa         50          0         0
##   versicolor      0         49         4
##   virginica       0          1        46
#Rweka_knn为分类器,evaluate_Weka_classifier()这个函数把分类器的数据平均分成10分,做10折交叉验证,查看结果
#类似于CrossTable()函数
evaluate_Weka_classifier(Rweka_knn,numFolds = 10)
## === 10 Fold Cross Validation ===
## 
## === Summary ===
## 
## Correctly Classified Instances         143               95.3333 %
## Incorrectly Classified Instances         7                4.6667 %
## Kappa statistic                          0.93  
## Mean absolute error                      0.0401
## Root mean squared error                  0.1748
## Relative absolute error                  9.0146 %
## Root relative squared error             37.0711 %
## Total Number of Instances              150     
## 
## === Confusion Matrix ===
## 
##   a  b  c   <-- classified as
##  50  0  0 |  a = setosa
##   0 47  3 |  b = versicolor
##   0  4 46 |  c = virginica
evaluate_Weka_classifier(Rweka_knns,numFolds = 10)
## === 10 Fold Cross Validation ===
## 
## === Summary ===
## 
## Correctly Classified Instances         142               94.6667 %
## Incorrectly Classified Instances         8                5.3333 %
## Kappa statistic                          0.92  
## Mean absolute error                      0.0468
## Root mean squared error                  0.156 
## Relative absolute error                 10.5235 %
## Root relative squared error             33.0967 %
## Total Number of Instances              150     
## 
## === Confusion Matrix ===
## 
##   a  b  c   <-- classified as
##  50  0  0 |  a = setosa
##   0 47  3 |  b = versicolor
##   0  5 45 |  c = virginica
sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Mojave 10.14.5
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] zh_CN.UTF-8/zh_CN.UTF-8/zh_CN.UTF-8/C/zh_CN.UTF-8/zh_CN.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] RWeka_0.4-43    caret_6.0-86    ggplot2_3.3.2   lattice_0.20-41
## [5] gmodels_2.18.1  class_7.3-17   
## 
## loaded via a namespace (and not attached):
##  [1] gtools_3.8.2         tidyselect_1.1.0     xfun_0.17           
##  [4] purrr_0.3.4          reshape2_1.4.4       rJava_0.9-13        
##  [7] splines_4.0.2        colorspace_1.4-1     vctrs_0.3.2         
## [10] generics_0.0.2       stats4_4.0.2         htmltools_0.5.0     
## [13] yaml_2.2.1           survival_3.1-12      prodlim_2019.11.13  
## [16] rlang_0.4.7          e1071_1.7-3          ModelMetrics_1.2.2.2
## [19] pillar_1.4.6         glue_1.4.1           withr_2.2.0         
## [22] foreach_1.5.0        lifecycle_0.2.0      plyr_1.8.6          
## [25] lava_1.6.7           stringr_1.4.0        timeDate_3043.102   
## [28] munsell_0.5.0        blogdown_0.20        gtable_0.3.0        
## [31] recipes_0.1.13       codetools_0.2-16     evaluate_0.14       
## [34] RWekajars_3.9.3-2    knitr_1.29           Rcpp_1.0.5          
## [37] scales_1.1.1         gdata_2.18.0         ipred_0.9-9         
## [40] digest_0.6.25        stringi_1.4.6        bookdown_0.20       
## [43] dplyr_1.0.1          grid_4.0.2           tools_4.0.2         
## [46] magrittr_1.5         tibble_3.0.3         crayon_1.3.4        
## [49] pkgconfig_2.0.3      MASS_7.3-51.6        ellipsis_0.3.1      
## [52] Matrix_1.2-18        data.table_1.13.0    pROC_1.16.2         
## [55] lubridate_1.7.9      gower_0.2.2          rmarkdown_2.3       
## [58] iterators_1.0.12     R6_2.4.1             rpart_4.1-15        
## [61] nnet_7.3-14          nlme_3.1-148         compiler_4.0.2

次;