1111
1212我们在第 8 章中介绍了线性回归的一些强大的方法,但这些方法创建的模型需要拟合所有的样本点(局部加权线性回归除外)。当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法就显得太难了,也略显笨拙。而且,实际生活中很多问题都是非线性的,不可能使用全局线性模型来拟合任何数据。
1313
14- 一种可行的方法是将数据集切分成很多分易建模的数据 ,然后利用我们的线性回归技术来建模。如果首次切分后仍然难以拟合线性模型就继续切分。在这种切分方式下,树回归和回归法就相当有用。
14+ 一种可行的方法是将数据集切分成很多份易建模的数据 ,然后利用我们的线性回归技术来建模。如果首次切分后仍然难以拟合线性模型就继续切分。在这种切分方式下,树回归和回归法就相当有用。
1515
1616除了我们在 第3章 中介绍的 决策树算法,我们介绍一个新的叫做 CART(Classification And Regression Trees, 分类回归树) 的树构建算法。该算法既可以用于分类还可以用于回归。
1717
1818## 1、树回归 原理
1919
20- 为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。第3章使用树进行分类,会在给点节点时计算数据的混乱度。那么如何计算连续型数值的混乱度呢?
20+ ### 1.1、树回归 原理概述
21+
22+ 为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。第3章使用树进行分类,会在给定节点时计算数据的混乱度。那么如何计算连续型数值的混乱度呢?
2123
2224在这里,计算连续型数值的混乱度是非常简单的。首先计算所有数据的均值,然后计算每条数据的值到均值的差值。为了对正负差值同等看待,一般使用绝对值或平方值来代替上述差值。
2325
@@ -33,6 +35,8 @@ CART 是十分著名且广泛记载的树构建算法,它使用二元切分来
3335
3436回归树与分类树的思路类似,但是叶节点的数据类型不是离散型,而是连续型。
3537
38+ #### 1.2.1、附加 各常见树构造算法的划分分支方式
39+
3640还有一点要说明,构建决策树算法,常用到的是三个方法: ID3, C4.5, CART.
3741三种方法区别是划分树的分支的方式:
38421 . ID3 是信息增益分支
@@ -47,7 +51,17 @@ CART 和 C4.5 之间主要差异在于分类结果上,CART 可以回归分析
4751
4852### 1.3、树回归 工作原理
4953
50- 函数 createTree()
54+ 1、找到数据集切分的最佳位置,函数 chooseBestSplit() 伪代码大致如下:
55+
56+ ```
57+ 对每个特征:
58+ 对每个特征值:
59+ 将数据集切分成两份
60+ 计算切分的误差
61+ 如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
62+ 返回最佳切分的特征和阈值
63+ ```
64+ 2、树构建算法,函数 createTree() 伪代码大致如下:
5165
5266```
5367找到最佳的待切分特征:
@@ -109,11 +123,12 @@ data1.txt 文件中存储的数据格式如下:
109123
110124> 分析数据:绘出数据的二维可视化显示结果,以字典方式生成树
111125
126+ 基于 CART 算法构建回归树的简单数据集
112127![ 基于 CART 算法构建回归树的简单数据集] ( ../images/9.TreeRegression/RegTree_1.png )
113- 基于 CART 算法构建回归树的简单数据集
114128
129+ 用于测试回归树的分段常数数据集
115130![ 用于测试回归树的分段常数数据集] ( ../images/9.TreeRegression/RegTree_2.png )
116- 用于测试回归树的分段常数数据集
131+
117132
118133> 训练算法: 构造树的数据结构
119134
@@ -255,7 +270,7 @@ def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
255270
256271### 2.2、后剪枝(postpruning)
257272
258- 决策树构造完成后进行剪枝。剪枝的过程是对拥有同样父节点的一组节点进行检查,判断如果将其合并,熵的增加量是否小于某一阈值。如果确实小,则这一组节点可以合并一个节点,其中包含了所有可能的结果。后剪枝是目前最普遍的做法。
273+ 决策树构造完成后进行剪枝。剪枝的过程是对拥有同样父节点的一组节点进行检查,判断如果将其合并,熵的增加量是否小于某一阈值。如果确实小,则这一组节点可以合并一个节点,其中包含了所有可能的结果。合并也被称作 ` 塌陷处理 ` ,在回归树中一般采用取需要合并的所有子树的平均值。 后剪枝是目前最普遍的做法。
259274
260275后剪枝的剪枝过程是删除一些子树,然后用其叶子节点代替,这个叶子节点所标识的类别通过大多数原则(majority class criterion)确定。所谓大多数原则,是指剪枝过程中, 将一些子树删除而用叶节点代替,这个叶节点所标识的类别用这棵子树中大多数训练样本所属的类别来标识,所标识的类 称为majority class ,(majority class 在很多英文文献中也多次出现)。
261276
@@ -276,11 +291,28 @@ def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
276291``` python
277292# 判断节点是否是一个字典
278293def isTree (obj ):
294+ """
295+ Desc:
296+ 测试输入变量是否是一棵树,即是否是一个字典
297+ Args:
298+ obj -- 输入变量
299+ Returns:
300+ 返回布尔类型的结果。如果 obj 是一个字典,返回true,否则返回 false
301+ """
279302 return (type (obj).__name__ == ' dict' )
280303
281304
282305# 计算左右枝丫的均值
283306def getMean (tree ):
307+ """
308+ Desc:
309+ 从上往下遍历树直到叶节点为止,如果找到两个叶节点则计算它们的平均值。
310+ 对 tree 进行塌陷处理,即返回树平均值。
311+ Args:
312+ tree -- 输入的树
313+ Returns:
314+ 返回 tree 节点的平均值
315+ """
284316 if isTree(tree[' right' ]):
285317 tree[' right' ] = getMean(tree[' right' ])
286318 if isTree(tree[' left' ]):
@@ -290,6 +322,15 @@ def getMean(tree):
290322
291323# 检查是否适合合并分枝
292324def prune (tree , testData ):
325+ """
326+ Desc:
327+ 从上而下找到叶节点,用测试数据集来判断将这些叶节点合并是否能降低测试误差
328+ Args:
329+ tree -- 待剪枝的树
330+ testData -- 剪枝所需要的测试数据 testData
331+ Returns:
332+ tree -- 剪枝完成的树
333+ """
293334 # 判断是否测试数据集没有数据,如果没有,就直接返回tree本身的均值
294335 if shape(testData)[0 ] == 0 :
295336 return getMean(tree)
@@ -304,7 +345,9 @@ def prune(tree, testData):
304345 if isTree(tree[' right' ]):
305346 tree[' right' ] = prune(tree[' right' ], rSet)
306347
307- # 如果左右两边同时都不是dict字典,那么分割测试数据集。
348+ # 上面的一系列操作本质上就是将测试数据集按照训练完成的树拆分好,对应的值放到对应的节点
349+
350+ # 如果左右两边同时都不是dict字典,也就是左右两边都是叶节点,而不是子树了,那么分割测试数据集。
308351 # 1. 如果正确
309352 # * 那么计算一下总方差 和 该结果集的本身不分枝的总方差比较
310353 # * 如果 合并的总方差 < 不合并的总方差,那么就进行合并
@@ -331,15 +374,15 @@ def prune(tree, testData):
331374### 3.1、模型树 简介
332375用树来对数据建模,除了把叶节点简单地设定为常数值之外,还有一种方法是把叶节点设定为分段线性函数,这里所谓的 ` 分段线性(piecewise linear) ` 是指模型由多个线性片段组成。
333376
334- 我们看一下图 9-4 中的数据,如果使用两条直线拟合是否比使用一组常数来建模好呢?答案显而易见。可以设计两条分别 0.0~ 0.3、从 0.3~ 1.0 的直线,于是就可以得到两个线性模型。因为数据集里的一部分数据(0.0~ 0.3)以某个线性模型建模,而另一部分数据(0.3~ 1.0)则以另一个线性模型建模,因此我们说采用了所谓的分段线性模型。
377+ 我们看一下图 9-4 中的数据,如果使用两条直线拟合是否比使用一组常数来建模好呢?答案显而易见。可以设计两条分别从 0.0~ 0.3、从 0.3~ 1.0 的直线,于是就可以得到两个线性模型。因为数据集里的一部分数据(0.0~ 0.3)以某个线性模型建模,而另一部分数据(0.3~ 1.0)则以另一个线性模型建模,因此我们说采用了所谓的分段线性模型。
335378
336379决策树相比于其他机器学习算法的优势之一在于结果更易理解。很显然,两条直线比很多节点组成一棵大树更容易解释。模型树的可解释性是它优于回归树的特点之一。另外,模型树也具有更高的预测准确度。
337380
338381![ 分段线性数据] ( ../images/9.TreeRegression/RegTree_3.png )
339382
340383将之前的回归树的代码稍作修改,就可以在叶节点生成线性模型而不是常数值。下面将利用树生成算法对数据进行划分,且每份切分数据都能很容易被线性模型所表示。这个算法的关键在于误差的计算。
341384
342- 那么为了找到最佳切分,应该怎样计算误差呢?前面用于回归树的误差计算方法这里不能再用。稍加变化,对于给定的数据集,应该先用模型来对它进行拟合,然后计算真实的目标值与模型预测值键的差值 。最后将这些差值的平方求和就得到了所需的误差。
385+ 那么为了找到最佳切分,应该怎样计算误差呢?前面用于回归树的误差计算方法这里不能再用。稍加变化,对于给定的数据集,应该先用模型来对它进行拟合,然后计算真实的目标值与模型预测值间的差值 。最后将这些差值的平方求和就得到了所需的误差。
343386
344387### 3.2、模型树 代码
345388
@@ -349,12 +392,28 @@ def prune(tree, testData):
349392# 得到模型的ws系数:f(x) = x0 + x1*featrue1+ x3*featrue2 ...
350393# create linear model and return coeficients
351394def modelLeaf (dataSet ):
395+ """
396+ Desc:
397+ 当数据不再需要切分的时候,生成叶节点的模型。
398+ Args:
399+ dataSet -- 输入数据集
400+ Returns:
401+ 调用 linearSolve 函数,返回得到的 回归系数ws
402+ """
352403 ws, X, Y = linearSolve(dataSet)
353404 return ws
354405
355406
356407# 计算线性模型的误差值
357408def modelErr (dataSet ):
409+ """
410+ Desc:
411+ 在给定数据集上计算误差。
412+ Args:
413+ dataSet -- 输入数据集
414+ Returns:
415+ 调用 linearSolve 函数,返回 yHat 和 Y 之间的平方误差。
416+ """
358417 ws, X, Y = linearSolve(dataSet)
359418 yHat = X * ws
360419 # print corrcoef(yHat, Y, rowvar=0)
@@ -363,6 +422,16 @@ def modelErr(dataSet):
363422
364423 # helper function used in two places
365424def linearSolve (dataSet ):
425+ """
426+ Desc:
427+ 将数据集格式化成目标变量Y和自变量X,执行简单的线性回归,得到ws
428+ Args:
429+ dataSet -- 输入数据
430+ Returns:
431+ ws -- 执行线性回归的回归系数
432+ X -- 格式化自变量X
433+ Y -- 格式化目标变量Y
434+ """
366435 m, n = shape(dataSet)
367436 # 产生一个关于1的矩阵
368437 X = mat(ones((m, n)))
@@ -395,7 +464,7 @@ def linearSolve(dataSet):
395464
396465前面介绍了模型树、回归树和一般的回归方法,下面测试一下哪个模型最好。
397466
398- 这些模型将在某个数据上进行测试,该数据涉及人的智力水平和自行车的速度的关系。
467+ 这些模型将在某个数据上进行测试,该数据涉及人的智力水平和自行车的速度的关系。当然,数据是假的。
399468
400469#### 4.1.2、开发流程
401470
@@ -430,12 +499,33 @@ def linearSolve(dataSet):
430499用树回归进行预测的代码
431500``` python
432501# 回归树测试案例
502+ # 为了和 modelTreeEval() 保持一致,保留两个输入参数
433503def regTreeEval (model , inDat ):
504+ """
505+ Desc:
506+ 对 回归树 进行预测
507+ Args:
508+ model -- 指定模型,可选值为 回归树模型 或者 模型树模型,这里为回归树
509+ inDat -- 输入的测试数据
510+ Returns:
511+ float(model) -- 将输入的模型数据转换为 浮点数 返回
512+ """
434513 return float (model)
435514
436515
437516# 模型树测试案例
517+ # 对输入数据进行格式化处理,在原数据矩阵上增加第0列,元素的值都是1,
518+ # 也就是增加偏移值,和我们之前的简单线性回归是一个套路,增加一个偏移量
438519def modelTreeEval (model , inDat ):
520+ """
521+ Desc:
522+ 对 模型树 进行预测
523+ Args:
524+ model -- 输入模型,可选值为 回归树模型 或者 模型树模型,这里为模型树模型
525+ inDat -- 输入的测试数据
526+ Returns:
527+ float(X * model) -- 将测试数据乘以 回归系数 得到一个预测值 ,转化为 浮点数 返回
528+ """
439529 n = shape(inDat)[1 ]
440530 X = mat(ones((1 , n+ 1 )))
441531 X[:, 1 : n+ 1 ] = inDat
@@ -444,7 +534,21 @@ def modelTreeEval(model, inDat):
444534
445535
446536# 计算预测的结果
537+ # 在给定树结构的情况下,对于单个数据点,该函数会给出一个预测值。
538+ # modelEval是对叶节点进行预测的函数引用,指定树的类型,以便在叶节点上调用合适的模型。
539+ # 此函数自顶向下遍历整棵树,直到命中叶节点为止,一旦到达叶节点,它就会在输入数据上
540+ # 调用modelEval()函数,该函数的默认值为regTreeEval()
447541def treeForeCast (tree , inData , modelEval = regTreeEval):
542+ """
543+ Desc:
544+ 对特定模型的树进行预测,可以是 回归树 也可以是 模型树
545+ Args:
546+ tree -- 已经训练好的树的模型
547+ inData -- 输入的测试数据
548+ modelEval -- 预测的树的模型类型,可选值为 regTreeEval(回归树) 或 modelTreeEval(模型树),默认为回归树
549+ Returns:
550+ 返回预测值
551+ """
448552 if not isTree(tree):
449553 return modelEval(tree, inData)
450554 if inData[tree[' spInd' ]] <= tree[' spVal' ]:
@@ -461,10 +565,22 @@ def treeForeCast(tree, inData, modelEval=regTreeEval):
461565
462566# 预测结果
463567def createForeCast (tree , testData , modelEval = regTreeEval):
568+ """
569+ Desc:
570+ 调用 treeForeCast ,对特定模型的树进行预测,可以是 回归树 也可以是 模型树
571+ Args:
572+ tree -- 已经训练好的树的模型
573+ inData -- 输入的测试数据
574+ modelEval -- 预测的树的模型类型,可选值为 regTreeEval(回归树) 或 modelTreeEval(模型树),默认为回归树
575+ Returns:
576+ 返回预测值矩阵
577+ """
464578 m = len (testData)
465579 yHat = mat(zeros((m, 1 )))
580+ # print yHat
466581 for i in range (m):
467582 yHat[i, 0 ] = treeForeCast(tree, mat(testData[i]), modelEval)
583+ # print "yHat==>", yHat[i, 0]
468584 return yHat
469585```
470586[ 完整代码地址] ( https://github.com/apachecn/MachineLearning/blob/master/src/python/9.RegTrees/regTrees.py ) : < https://github.com/apachecn/MachineLearning/blob/master/src/python/9.RegTrees/regTrees.py >
0 commit comments