博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
AdaBoost元算法
阅读量:5276 次
发布时间:2019-06-14

本文共 11708 字,大约阅读时间需要 39 分钟。

boosting:不同的分类器是通过串行训练而获得的,每个新分类器都根据已经训练出的分类器的性能来进行训练。通过集中关注被已有分类器错分的那些样本来获得新的分类器。

权重alpha:弱分类器的线性组合系数,用来构成完整分类器。对每个数据的分类时,其结果是弱分类器结果的线性组合。

权重D:样本的权重向量,每个元素表征对应样本的重要性。m*1阶列向量。

基于单层决策树构建弱分类器:仅基于单个特征来做决策。

单层决策树生成函数:

from numpy import *def loadSimpData():    datMat = matrix([[ 1. ,  2.1],        [ 2. ,  1.1],        [ 1.3,  1. ],        [ 1. ,  1. ],        [ 2. ,  1. ]])    classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]    return datMat,classLabelsdef stumpClassify(dataMatrix,dimen,threshVal,threshIneq):    retArray=ones((shape(dataMatrix)[0],1))    if threshIneq=='lt':        retArray[dataMatrix[:,dimen]<=threshVal]=-1.0    else:        retArray[dataMatrix[:,dimen]>threshVal]=-1.0    return retArraydef buildStump(dataArr,classLabels,D):    dataMatrix=mat(dataArr)    labelMat=mat(classLabels).T    m,n=shape(dataMatrix)    numSteps=10.0    bestStump={}    bestClasEst=mat(zeros((m,1)))    minError=inf    for i in range(n):        rangeMin=dataMatrix[:,i].min()        rangeMax=dataMatrix[:,i].max()        stepSize=(rangeMax-rangeMin)/numSteps        for j in range(-1,int(numSteps)+1):            for inequal in ['lt','gt']:                threshVal=(rangeMin+float(j)*stepSize)                predictVals=stumpClassify(dataMatrix,i,threshVal,inequal)                errArr=mat(ones((m,1)))                errArr[predictVals==labelMat]=0                weightedError=D.T*errArr        #1*m m*1 ==>标量                print('split:dim %d,thresh:%.2f,thresh inequal:%s,the weighted error is %.3f'%(i,threshVal,inequal,weightedError))                if weightedError

stumpClassify(dataMatrix,dimen,threshVal,threshIneq):单层决策树,通过阈值比较对数据分类。所有在阈值一边的数据会分为-1,另一边的数据分为+1.该函数通过数组过滤来实现。分为两种模式:小于等于阈值分为-1,大于阈值分为+1;或者相反。

weightedError=D.T*errArr     #1*m m*1 ==>标量   将错误向量errArr和权重向量D的相应元素相乘并求和,得到数值weightedError,这就是AdaBoost与分类器交互的地方。这里基于权重向量D而不是其他错误计算指标来评价分类器。

输出

split:dim 0,thresh:0.90,thresh inequal:lt,the weighted error is 0.400split:dim 0,thresh:0.90,thresh inequal:gt,the weighted error is 0.600split:dim 0,thresh:1.00,thresh inequal:lt,the weighted error is 0.400split:dim 0,thresh:1.00,thresh inequal:gt,the weighted error is 0.600split:dim 0,thresh:1.10,thresh inequal:lt,the weighted error is 0.400split:dim 0,thresh:1.10,thresh inequal:gt,the weighted error is 0.600split:dim 0,thresh:1.20,thresh inequal:lt,the weighted error is 0.400split:dim 0,thresh:1.20,thresh inequal:gt,the weighted error is 0.600split:dim 0,thresh:1.30,thresh inequal:lt,the weighted error is 0.200split:dim 0,thresh:1.30,thresh inequal:gt,the weighted error is 0.800split:dim 0,thresh:1.40,thresh inequal:lt,the weighted error is 0.200split:dim 0,thresh:1.40,thresh inequal:gt,the weighted error is 0.800split:dim 0,thresh:1.50,thresh inequal:lt,the weighted error is 0.200split:dim 0,thresh:1.50,thresh inequal:gt,the weighted error is 0.800split:dim 0,thresh:1.60,thresh inequal:lt,the weighted error is 0.200split:dim 0,thresh:1.60,thresh inequal:gt,the weighted error is 0.800split:dim 0,thresh:1.70,thresh inequal:lt,the weighted error is 0.200split:dim 0,thresh:1.70,thresh inequal:gt,the weighted error is 0.800split:dim 0,thresh:1.80,thresh inequal:lt,the weighted error is 0.200split:dim 0,thresh:1.80,thresh inequal:gt,the weighted error is 0.800split:dim 0,thresh:1.90,thresh inequal:lt,the weighted error is 0.200split:dim 0,thresh:1.90,thresh inequal:gt,the weighted error is 0.800split:dim 0,thresh:2.00,thresh inequal:lt,the weighted error is 0.600split:dim 0,thresh:2.00,thresh inequal:gt,the weighted error is 0.400split:dim 1,thresh:0.89,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:0.89,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.00,thresh inequal:lt,the weighted error is 0.200split:dim 1,thresh:1.00,thresh inequal:gt,the weighted error is 0.800split:dim 1,thresh:1.11,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.11,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.22,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.22,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.33,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.33,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.44,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.44,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.55,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.55,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.66,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.66,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.77,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.77,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.88,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.88,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:1.99,thresh inequal:lt,the weighted error is 0.400split:dim 1,thresh:1.99,thresh inequal:gt,the weighted error is 0.600split:dim 1,thresh:2.10,thresh inequal:lt,the weighted error is 0.600split:dim 1,thresh:2.10,thresh inequal:gt,the weighted error is 0.400
View Code

 基于单层决策树的AdaBoost训练过程:

def adaBoostTrainDS(dataArr,classLabels,numIter=40):    weakClassArr=[]    m=shape(dataArr)[0]    D=mat(ones((m,1))/m)    #每个样本的权重均初始化为1/m    aggClassEst=mat(zeros((m,1)))    for i in range(numIter):        beatStump,error,classEst=buildStump(dataArr,classLabels,D)        print('D:',D.T)        alpha=float(0.5*log((1.0-error)/max(error,e-16)))        print('alpha:',alpha)        beatStump['alpha']=alpha        weakClassArr.append(beatStump)        print('分类估计:',classEst.T)        expon=multiply(-1*alpha*mat(classLabels).T,classEst)        D=multiply(D,exp(expon))        D=D/D.sum()        aggClassEst+=alpha*classEst        print('aggClassEst:',aggClassEst.T)        aggErrors=multiply(sign(aggClassEst)!=mat(classLabels).T,ones((m,1)))        errorRate=aggErrors.sum()/m        print('total error:',errorRate,'\n')        if errorRate==0.0:            break    return weakClassArr

D是概率分布向量,D中所有元素之和等于1.

首先利用前面的buildStump()函数建立一个单层决策树。该函数的输入为权重向量D,返回的则是利用D得到的具有最小错误率的单层决策树,同时返回的还有最小的错误率以及预测的类别向量。

alpha=float(0.5*log((1.0-error)/max(error,e-16)))  其中的max(error,e-16)是用来防止error很小时发生的除零溢出。

aggClassEst是m*1阶的列向量,用来存储运行时的类别估计值,符号代表预测结果,为正时表示目前此样本的预测类别为1,为负时表示-1.

aggClassEst+=alpha*classEst  用各弱分类器的分类结果与权重alpha的线性组合值作为最终的预测值。迭代一次,就产生一个弱分类器,相当于对最终的结果修正一次。

aggErrors=multiply(sign(aggClassEst)!=mat(classLabels).T,ones((m,1)))  将分类错误的样本对应位置设置为1,方便求出错误分类总数和错误率。

测试AdaBoost:

if __name__=='__main__':    D=mat(ones((5,1))/5)    datMat, classLabels=loadSimpData()    classifyArray=adaBoostTrainDS(datMat,classLabels,9)    print(classifyArray)

 输出:

D: [[0.2 0.2 0.2 0.2 0.2]]alpha: 0.6931471805599453分类估计: [[-1.  1. -1. -1.  1.]]aggClassEst: [[-0.69314718  0.69314718 -0.69314718 -0.69314718  0.69314718]]total error: 0.2 D: [[0.5   0.125 0.125 0.125 0.125]]alpha: 0.9729550745276565分类估计: [[ 1.  1. -1. -1. -1.]]aggClassEst: [[ 0.27980789  1.66610226 -1.66610226 -1.66610226 -0.27980789]]total error: 0.2 D: [[0.28571429 0.07142857 0.07142857 0.07142857 0.5       ]]alpha: 0.8958797346140273分类估计: [[1. 1. 1. 1. 1.]]aggClassEst: [[ 1.17568763  2.56198199 -0.77022252 -0.77022252  0.61607184]]total error: 0.0 [{'alpha': 0.6931471805599453, 'dim': 0, 'ineq': 'lt', 'thresh': 1.3}, {'alpha': 0.9729550745276565, 'dim': 1, 'ineq': 'lt', 'thresh': 1.0}, {'alpha': 0.8958797346140273, 'dim': 0, 'ineq': 'lt', 'thresh': 0.9}]

 classifyArray是数组,由三个弱分类器组成,包含了分类所需的所有信息。此时的训练错误率为0,以下讨论其测试错误率。


 

上述函数的返回值中含有弱分类器及其alpha值,容易进行测试:只需要将弱分类器提取出来作用到待分类数据上,每个弱分类器的结果以其对应的alpha值为权重,所有这些弱分类器的结果加权求和就得到了最后的结果。

if __name__=='__main__':    D=mat(ones((5,1))/5)    datMat, classLabels=loadSimpData()    classifyArray=adaBoostTrainDS(datMat,classLabels,9)    result=adaClassify([[5,5],[0,0]],classifyArray)    print('最终分类结果为:',result)

 输出:

aggClassEst: [[ 0.69314718] [-0.69314718]]aggClassEst: [[ 1.66610226] [-1.66610226]]aggClassEst: [[ 2.56198199] [-2.56198199]]最终分类结果为: [[ 1.] [-1.]]

 由aggClassEst可以看出,随着三个弱分类器的叠加,其预测结果越来越强,即为离分类边界值0的距离越来越远。


 

在一个难数据集上应用AdaBoost,预测疝病马能否存活。

 自适应数据加载函数,不需指定每个文件中的特征数目,并且假定最后一列数据是类别标签。

def loadDataSet(filename):    numFeatures = len(open(filename).readline().split('\t')) - 1    dataMat = []    labelMat = []    f = open(filename)    for line in f.readlines():        lineArr=[]        curLine=line.strip().split('\t')        for i in range(0,numFeatures):            lineArr.append(float(curLine[i]))        dataMat.append(lineArr)        labelMat.append(float(curLine[-1]))    return dataMat,labelMat

 用疝病马数据集测试元算法:

if __name__=='__main__':    dataArr,labelArr=loadDataSet('horseColicTraining2.txt')    classifyArray=adaBoostTrainDS(dataArr,labelArr,10)    testArr,testLabelArr=loadDataSet('horseColicTest2.txt')    prediction10=adaClassify(testArr,classifyArray)    errArr=mat(ones((67,1)))    count=errArr[prediction10!=mat(testLabelArr).T].sum()    print(prediction10)    print(count)

 输出:

total error: 0.2842809364548495 total error: 0.2842809364548495 total error: 0.24749163879598662 total error: 0.24749163879598662 total error: 0.25418060200668896 total error: 0.2408026755852843 total error: 0.2408026755852843 total error: 0.22073578595317725 total error: 0.24749163879598662 total error: 0.23076923076923078 [[ 1.] [ 1.] [ 1.] [-1.] [ 1.] [ 1.] [-1.] [ 1.] [ 1.] [-1.] [-1.] [-1.] [-1.] [ 1.] [ 1.] [ 1.] [ 1.] [-1.] [-1.] [-1.] [-1.] [ 1.] [-1.] [-1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [-1.] [-1.] [-1.] [-1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [-1.] [-1.] [ 1.] [-1.] [ 1.] [ 1.] [ 1.] [-1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [ 1.] [-1.] [ 1.] [-1.] [ 1.] [-1.] [-1.] [ 1.] [ 1.] [ 1.] [ 1.]]16.0
View Code

 

迭代了10次,产生10个弱分类器,训练错误率最终为:total error: 0.23076923076923078

测试数据集上有67个样本,分类结果中有16个错误,错误率为16/67=0.23880597014925373,比起logistic回归预测结果35%的错误率降低很多。

 

from numpy import *def loadSimpData():    datMat = matrix([[ 1. ,  2.1],        [ 2. ,  1.1],        [ 1.3,  1. ],        [ 1. ,  1. ],        [ 2. ,  1. ]])    classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]    return datMat,classLabelsdef loadDataSet(filename):    numFeatures = len(open(filename).readline().split('\t')) - 1    dataMat = []    labelMat = []    f = open(filename)    for line in f.readlines():        lineArr=[]        curLine=line.strip().split('\t')        for i in range(0,numFeatures):            lineArr.append(float(curLine[i]))        dataMat.append(lineArr)        labelMat.append(float(curLine[-1]))    return dataMat,labelMatdef stumpClassify(dataMatrix,dimen,threshVal,threshIneq):    retArray=ones((shape(dataMatrix)[0],1))    if threshIneq=='lt':        retArray[dataMatrix[:,dimen]<=threshVal]=-1.0    else:        retArray[dataMatrix[:,dimen]>threshVal]=-1.0    return retArraydef buildStump(dataArr,classLabels,D):    dataMatrix=mat(dataArr)    labelMat=mat(classLabels).T    m,n=shape(dataMatrix)    numSteps=10.0    bestStump={}    bestClasEst=mat(zeros((m,1)))    minError=inf    for i in range(n):        rangeMin=dataMatrix[:,i].min()        rangeMax=dataMatrix[:,i].max()        stepSize=(rangeMax-rangeMin)/numSteps        for j in range(-1,int(numSteps)+1):            for inequal in ['lt','gt']:                threshVal=(rangeMin+float(j)*stepSize)                predictVals=stumpClassify(dataMatrix,i,threshVal,inequal)                errArr=mat(ones((m,1)))                errArr[predictVals==labelMat]=0                weightedError=D.T*errArr        #1*m m*1 ==>标量                # print('split:dim %d,thresh:%.2f,thresh inequal:%s,the weighted error is %.3f'%(i,threshVal,inequal,weightedError))                if weightedError
完整代码

 

转载于:https://www.cnblogs.com/zhhy236400/p/9921574.html

你可能感兴趣的文章
PHP扩展-扩展的生成和编译
查看>>
Hello world!
查看>>
部分网站公开数据的汇总(2)
查看>>
03-29复利计算单元测试
查看>>
android中 onResume()方法什么时候执行 ??(转)
查看>>
angularjs用回车键动态添加数据,同时渲染到页面
查看>>
软件设计师2008年12月下午试题4(C语言 动态规划)
查看>>
python基础 ---- 使用pyCharm 调试
查看>>
(转)虚函数和纯虚函数区别
查看>>
[小北De编程手记] : Lesson 05 玩转 xUnit.Net 之 从Assert谈UT框架实践
查看>>
我的记忆,我的年
查看>>
search Paths $(SRCROOT)和$(PROJECT_DIR)区别
查看>>
项目中坑(一)
查看>>
bam文件测序深度统计-bamdst
查看>>
JS 小数的常用处理方法
查看>>
Dom4J两种节点添加方法比较
查看>>
BZOJ2212——线段树合并
查看>>
背包九讲之四(混合三种背包问题)
查看>>
radio选中
查看>>
uva 725 Division(暴力模拟)
查看>>