机器学习实战(九)

数回归

分类回归树 Classification And Regression Trees 分类回归树。该算法既可以用于回归还可以用于分类。

复杂数据的局部性建模

数回归

优点:可以对复杂和线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型数据

第三章使用的树构建的算法是ID3。ID3的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来切分。也就取值,那么数据将被切分成4份,一但按某种特征切分后,该特征在之后的算法执行过程中将不会再起作用,所以有观点认为这种切分方式过于迅速。另一种方法是二元切分发,即每次吧数据集切分成两份。如果数据的某个特征等于切分所要求的值,那么这些数据就进入树的左子树,反之则进入树的右子树。

除了切分过于迅速外,ID3算法还存在另一个问题,它不能直接处理连续型特征。只有事先将连续型特征转换成离散型,才能在ID3算法中使用。但这种转换过程会破坏连续型变量的内在性质。而使用二元切分法则对于树构建过程进行调整以处理连续型特征。

具体处理方法是:

如果特征值大于给定值就走左子树,否则就走右子树。

另外,二元切分法也节省了树的构建时间,但这点意义也不是特别大因为这些树的构建一般是离线完成的。

CART是十分著名且广泛记载的树构建算法,它使用二元切分来处理连续型变量。对CART稍作修改就可以处理回归问题。

回归树的一般方法:

  1. 收集数据
  2. 准备数据:需要数值型的数据,标称型数据应该映射成二值型数据
  3. 分析数据:绘出数据的二维可视化显示结果,以字典方式生成树
  4. 训练算法:大部分时间都花费在叶节点树模型的构建上
  5. 测试算法:使用测试数据上的R^2值来分析模型的效果
  6. 使用算法:使用训练出的树做预测

连续和离散型特征的树的构建

在树的构建过程中,需要解决多种类型数据的存储问题。这里将使用字典来存储树的数据结构,该字典将包含以下四种元素。

待切分的特征
待切分的特征值
右子树。当不需要切分时,也可以是单值
左子树。与右子树类似

CART算法只做二元切分,所以这里可以固定树的数据结构。树包含左键和右键,可以存储另一颗树或者单个值。字典还包含特征和特征值这两个键,它们给出的切分算法所有的特征和特征值。

接下来构建两种树,第一种是回归树(regression tree)其中每个叶节点包含单个值,第二种是是模型树(model tree)其中每个叶节点包含一个线性方程。

createTree()的伪代码大致如下:

找到最佳的待切分特征:
    如果该节点不能再分,将该节点存为叶节点
    执行二元切分
    在右子树调用createTree()方法
    在左子树调用createTree()方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from numpy import *
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = map(float, curLine)
fltLine = list(fltLine)
dataMat.append(fltLine)
fr.close()
return dataMat

def regLeaf(dataSet): # returns the value used for each leaf
return mean(dataSet[:,-1])

def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
# 满足条件时返回叶节点值
if feat == None:
return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree

def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
return mat0, mat1
1
2
testMat = mat(eye(4))
testMat
matrix([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
1
mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)
1
mat0
matrix([[0., 1., 0., 0.]])
1
mat1
matrix([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
1
nonzero(testMat[:, 1] > 0.5)[0][0]
1

将CART算法用于回归

要对数据的复杂关系建模,我们已经决定借用树结构来帮助切分数据,那么如何实现数据的切分呢?

为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。在给定的节点计算数据的混乱度,计算数据混乱度的方法,首先计算所有数据的均值,然后计算每条数据的值到均值的差值,为了对正负差值同等看待,一般使用绝对值或平方值来代替上述差值。类似方差的计算,唯一不同是方差是平方误差的均值,而这里需要的是平方误差的总值,总方差可以通过均方差乘以数据集中样本点的个数来得到。

构建树

构建回归树,需要补充一些新的代码,首先要做的就是实现chooseBestSplit()函数,给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。另外该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。因此chooseBestSplit()函数需要完成两件事:用最佳方式切分数据集和生成相应的叶节点。

leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,而ops是一个用户定义的参数构成的元组,用已完成树的构建。

伪代码如下:

对每个特征:
    对每个特征值:
        将数据集切分成两份
        计算切分的误差
        如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
tolS = ops[0]
tolN = ops[1]

if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 如果所有值相等则退出
return None, leafType(dataSet)

m, n = shape(dataSet)
S = errType(dataSet)
bestS = inf
bestIndex = 0
bestValue = 0

for featIndex in range(n-1):
for splitVal in set(dataSet[:, featIndex].tolist()[0]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN) :
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS

if (S - bestS) < tolS: # 如果误差减小不大则退出
return None, leafType(dataSet)

mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)

if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 如果切分出的数据集很小则退出
return None, leafType(dataSet)

return bestIndex, bestValue

regLeaf()它负责生成叶节点。当chooseBestSplit()函数确定不再对数据进行切分时,将调用该regLeaf()函数来得到叶节点的模型,在回归树中,该模型就是目标变量的均值。

regErr()是误差估计函数,该函数在给定数据上计算目标变量的平方误差,当然也可以先计算出均值,然后计算每个差值再平方。因为这里需要总方差,所以用均方差函数var()的结果乘以数据集中的样本个数。

chooseBestSplit()该函数的目的是找到数据的最佳二元切分方式。如果找不到一个好的二元切分,该函数返回None并同时调用createTree()方法来产生叶节点,叶节点的值也将返回None。ops设定了tolS和tolN两个值,tolS是容许的误差下降值,tolN是切分的最少样本数。

运行代码

1
myDat = loadDataSet('MLiA_SourceCode/Ch09/ex00.txt')
1
myMat = mat(myDat)
1
createTree(myMat)
{'spInd': 0,
 'spVal': 0.036098,
 'left': 0.5878577680412371,
 'right': 0.050698999999999994}
1
2
3
4
5
6
7
import matplotlib.pyplot as plt
def plotScatter(data):
fig = plt.figure()
ax = fig.add_subplot(111)
#print(myMat)
ax.scatter(data[:, 0].T.tolist()[0], data[:, 1].T.tolist()[0], 5, c='red')
plt.show()
1
plotScatter(myMat)

png

1
2
3
myDat = loadDataSet('MLiA_SourceCode/Ch09/ex0.txt')
myMat = mat(myDat)
createTree(myMat)
{'spInd': 1,
 'spVal': 0.409175,
 'left': {'spInd': 1,
  'spVal': 0.663687,
  'left': {'spInd': 1,
   'spVal': 0.725426,
   'left': 3.7206952592592595,
   'right': 2.998615611111111},
  'right': 2.2076016800000002},
 'right': 0.45470547435897446}
1
plotScatter(myMat[:, 1:])

png

树剪枝

通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。

预剪枝

树构建算法其实对输入的参数tolS和tolN非常敏感,对ops参数调整就是预剪枝。

后剪枝

利用测试集来对树进行剪枝,不需要用户指定参数,为后剪枝。

prune()伪代码如下

基于已有的树切分测试数据:
    如果存在任一子集是一棵树,则在该子集递归剪枝过程
    计算将当前两个叶节点合并后的误差
    计算不合并的误差
    如果合并降低误差,就将叶节点合并
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def isTree(obj):
return (type(obj).__name__=='dict')

def getMean(tree):
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['left']+tree['right'])/2.0

def prune(tree, testData):
if shape(testData)[0] == 0:
return getMean(tree) # if we have no test data collapse the tree
if (isTree(tree['right']) or isTree(tree['left'])): # if the branches are not trees try to prune them
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
# if they are now both leafs, see if we can merge them
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
sum(power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
return treeMean
else:
return tree
else:
return tree
1
2
3
4
5
myDat2 = loadDataSet('MLiA_SourceCode/Ch09/ex2.txt')
myMat2 = mat(myDat2)
myTree = createTree(myMat2, ops=(0, 1))
myDatTest = loadDataSet('MLiA_SourceCode/Ch09/ex2test.txt')
myMat2Test = mat(myDatTest)
1
prune(myTree, myMat2Test)
{'spInd': 0,
 'spVal': 0.228628,
 'left': {'spInd': 0,
  'spVal': 0.965969,
  'left': 92.5239915,
  'right': 65.53919801898735},
 'right': -1.1055498250000002}

模型树

用树来对数据建模,除了吧叶节点简单地设定为常数值之外,还有一种方法是把叶节点设定为分段线性函数,这里所谓的分段性(piecewise linear)是指模型由多个线性片段组成。

决策树相比其他机器学习算法的优势之一在于结果更易理解。很显然,两条直线比很多节点组成一棵大树更容易解释。模型树的可解释性是它优于回归树的特点之一。另外,模型树也具有更高的预测准确度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def linearSolve(dataSet):
m, n = shape(dataSet)
X = mat(ones((m, n)))
Y = mat(ones((m, 1)))
X[:, 1:n] = dataSet[:, 0:n-1]
Y = dataSet[:, -1]
xTx = X.T*X

if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverse,\
try increasing the second value of ops')

ws = xTx.I * (X.T * Y)
return ws, X, Y

def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)

def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat, 2))
1
createTree(myMat2, modelLeaf, modelErr, (1, 10))
{'spInd': 0, 'spVal': 0.228628, 'left': None, 'right': None}

总结

这一章提供的code有很多错误,修正后并不能得到书中的答案。如果要使用树算法,还是建议使用sklearn,而非自己编写。

CART算法可以用于构建二元树并处理离散型或连续型数据的切分。若使用不同的误差准则,就可以通过CART算法构建模型树和回归树。该算法构建出的树会倾向于对数据过拟合。过拟合的树十分复杂,剪枝可以解决这个问题。

0%