深入理解GBDT回归算法
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
来源:公众号 Microstrong 授权转载
目录:
1. GBDT简介
2. GBDT回归算法
2.1 GBDT回归算法推导
2.2 GBDT回归算法实例
3. 手撕GBDT回归算法
3.1 用Python3实现GBDT回归算法
3.2 用sklearn实现GBDT回归算法
4. GBDT回归任务常见的损失函数
5. GBDT的正则化
6. 关于GBDT若干问题的思考
7. 总结
8. Reference
本文的主要内容概览:

1. GBDT简介
决策树有以下优点:
决策树可以认为是if-then规则的集合,易于理解,可解释性强,预测速度快。
决策树算法相比于其他的算法需要更少的特征工程,比如可以不用做特征标准化。
决策树可以很好的处理字段缺失的数据。
决策树能够自动组合多个特征,也有特征选择的作用。
对异常点鲁棒
可扩展性强,容易并行。
决策树有以下缺点:
缺乏平滑性(回归预测时输出值只能输出有限的若干种数值)。
不适合处理高维稀疏数据。
单独使用决策树算法时容易过拟合。
2. GBDT回归算法
2.1 GBDT回归算法推导
棵树组成的加法模型,其对应的公式如下:
为输入样本;
为模型参数;
为分类回归树;
为每棵树的权重。GBDT算法的实现过程如下:给定训练数据集:
其中,
,
为输入空间,
为输出空间,损失函数为
,我们的目标是得到最终的回归树
。
(1)初始化第一个弱学习器
:

(2)对于建立M棵分类回归树
:
,计算第
棵树对应的响应值(损失函数的负梯度,即伪残差):
,利用CART回归树拟合数据
,得到第
棵回归树,其对应的叶子节点区域为
,其中
,且
为第
棵回归树叶子节点的个数。
个叶子节点区域
,计算出最佳拟合值:
:
(3)得到强学习器
的表达式:

2.2 GBDT回归算法实例
(1)数据集介绍


(2)模型训练阶段
参数设置:
学习率:learning_rate = 0.1
迭代次数:n_trees = 5
树的深度:max_depth = 3
1)初始化弱学习器:

。
令导数等于0:

取值为所有训练样本标签值的均值。
,此时得到的初始化学习器为
。2)对于建立M棵分类回归树
:
首先计算负梯度,根据上文损失函数为平方损失时,负梯度就是残差,也就是
与上一轮得到的学习器
的差值:

现将残差的计算结果列表如下:

此时将残差作为样本的真实值来训练弱学习器
,即下表数据:

为左节点的平方损失,
为右节点的平方损失,找到使平方损失和
最小的那个划分节点,即为最佳划分节点。
,右节点包括样本
,则
、
、
,所有可能的划分情况如下表所示:

对于左节点,只含有0,1两个样本,根据下表结果我们选择年龄7为划分点(也可以选体重30)。

对于右节点,只含有2,3两个样本,根据下表结果我们选择年龄30为划分点(也可以选体重70)。

现在我们的第一棵回归树长下面这个样子:

,来拟合残差。
,其实就是标签值的均值。这个地方的标签值不是原始的
,而是本轮要拟合的标残差
。
此时的树长这下面这个样子:

表示。更新公式为:
,即学习率为1,很容易一步学到位导致GBDT过拟合。重复此步骤,直到
结束,最后生成5棵树。
下面将展示每棵树最终的结构,这些图都是我GitHub上用Python3实现GBDT代码生成的,感兴趣的同学可以去运行一下代码。地址:https://github.com/Microstrong0305/WeChat-zhihu-csdnblog-code/tree/master/Ensemble%20Learning/GBDT_Regression
第一棵树:

第二棵树:

第三棵树:

第四棵树:

第五棵树:

3)得到最后的强学习器:

(3)模型预测阶段

在
中,测试样本的年龄为25,大于划分节点21岁,又小于30岁,所以被预测为0.2250。在
中,测试样本的年龄为25,大于划分节点21岁,又小于30岁,所以被预测为0.2025。在
中,测试样本的年龄为25,大于划分节点21岁,又小于30岁,所以被预测为0.1823。在
中,测试样本的年龄为25,大于划分节点21岁,又小于30岁,所以被预测为0.1640。在
中,测试样本的年龄为25,大于划分节点21岁,又小于30岁,所以被预测为0.1476。
最终预测结果为:

3. 手撕GBDT回归算法
本篇文章所有数据集和代码均在我的GitHub中,地址:https://github.com/Microstrong0305/WeChat-zhihu-csdnblog-code/tree/master/Ensemble%20Learning
3.1 用Python3实现GBDT回归算法
需要的Python库:
pandas、PIL、pydotplus、matplotlib3.2 用sklearn实现GBDT回归算法
import numpy as npfrom sklearn.ensemble import GradientBoostingRegressorgbdt = GradientBoostingRegressor(loss='ls', learning_rate=0.1, n_estimators=5, subsample=1, min_samples_split=2, min_samples_leaf=1, max_depth=3, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False)train_feat = np.array([[1, 5, 20],[2, 7, 30],[3, 21, 70],[4, 30, 60],])train_id = np.array([[1.1], [1.3], [1.7], [1.8]]).ravel()test_feat = np.array([[5, 25, 65]])test_id = np.array([[1.6]])print(train_feat.shape, train_id.shape, test_feat.shape, test_id.shape)gbdt.fit(train_feat, train_id)pred = gbdt.predict(test_feat)total_err = 0for i in range(pred.shape[0]):print(pred[i], test_id[i])err = (pred[i] - test_id[i]) / test_id[i]total_err += err * errprint(total_err / pred.shape[0])

4. GBDT回归任务常见的损失函数
(1)均方差,这个是最常见的回归损失函数了,公式如下:

对应的负梯度误差为:

(2)绝对损失,这个损失函数也很常见,公式如下:

对应的负梯度误差为:

(3)Huber损失,它是均方差和绝对损失的折衷产物,对于远离中心的异常点,采用绝对损失,而中心附近的点采用均方差。这个界限一般用分位数点度量。损失函数如下:

对应的负梯度误差为:

(4)分位数损失,它对应的是分位数回归的损失函数,表达式为:

其中,
为分位数,需要我们在回归前指定。对应的负梯度误差为:

5. GBDT的正则化
。系数
也被称为学习率(learning rate),因为它可以对梯度提升的步长进行调整,也就是它可以影响我们设置的回归树个数。对于前面的弱学习器的迭代:
如果我们加上了正则化项,则有:

的取值范围为
。对于同样的训练集学习效果,较小的
意味着我们需要更多的弱学习器的迭代次数。通常我们用学习率和迭代最大次数一起来决定算法的拟合效果。即参数learning_rate会强烈影响到参数n_estimators(即弱学习器个数)。learning_rate的值越小,就需要越多的弱学习器数来维持一个恒定的训练误差(training error)常量。经验上,推荐小一点的learning_rate会对测试误差(test error)更好。在实际调参中推荐将learning_rate设置为一个小的常数(e.g. learning_rate <= 0.1),并通过early stopping机制来选n_estimators。6. 关于GBDT若干问题的思考
(1)GBDT与AdaBoost的区别与联系?
(2)GBDT与随机森林(Random Forest,RF)的区别与联系?
(3)我们知道残差=真实值-预测值,明明可以很方便的计算出来,为什么GBDT的残差要用负梯度来代替?为什么要引入麻烦的梯度?有什么用呢?
7. 总结
8. Reference
由于参考的文献较多,我把每一部分都重点参考了哪些文章详细标注一下。
GBDT简介与GBDT回归算法:
【1】Friedman J H . Greedy Function Approximation: A Gradient Boosting Machine[J]. The Annals of Statistics, 2001, 29(5):1189-1232.
【2】Friedman, Jerome & Hastie, Trevor & Tibshirani, Robert. (2000). Additive Logistic Regression: A Statistical View of Boosting. The Annals of Statistics. 28. 337-407. 10.1214/aos/1016218223.
【3】机器学习-一文理解GBDT的原理-20171001 - 谋杀电视机的文章 - 知乎 https://zhuanlan.zhihu.com/p/29765582
【4】GBDT算法原理深入解析,地址:https://www.zybuluo.com/yxd/note/611571
【5】GBDT的原理和应用 - 文西的文章 - 知乎 https://zhuanlan.zhihu.com/p/30339807
【6】ID3、C4.5、CART、随机森林、bagging、boosting、Adaboost、GBDT、xgboost算法总结 - yuyuqi的文章 - 知乎 https://zhuanlan.zhihu.com/p/34534004
【7】GBDT:梯度提升决策树,地址:https://www.jianshu.com/p/005a4e6ac775
【8】机器学习算法中 GBDT 和 XGBOOST 的区别有哪些?- wepon的回答 - 知乎 https://www.zhihu.com/question/41354392/answer/98658997
【9】http://wepon.me/files/gbdt.pdf
【10】GBDT详细讲解&常考面试题要点,地址:https://mp.weixin.qq.com/s/M2PwsrAnI1S9SxSB1guHdg
【11】Gradient Boosting Decision Tree,地址:http://gitlinux.net/2019-06-11-gbdt-gradient-boosting-decision-tree/
【12】《推荐系统开发实战》之基于点击率预估的推荐算法介绍和案例开发实战,地址:https://mp.weixin.qq.com/s/2VATflDlelfxhOQkcXHSqw
【13】GBDT算法原理以及实例理解,地址:https://blog.csdn.net/zpalyq110/article/details/79527653
手撕GBDT回归算法:
【14】GBDT_Simple_Tutorial(梯度提升树简易教程),GitHub地址:https://github.com/Freemanzxp/GBDT_Simple_Tutorial
【15】SCIKIT-LEARN与GBDT使用案例,地址:https://blog.csdn.net/superzrx/article/details/47073847
【16】手写原始gbdt,地址:https://zhuanlan.zhihu.com/p/82406112?utm_source=wechat_session&utm_medium=social&utm_oi=743812915018104832
GBDT回归任务常见的损失函数与正则化:
【17】Regularization on GBDT,地址:http://chuan92.com/2016/04/11/regularization-on-gbdt
【18】Early stopping of Gradient Boosting,地址:https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_early_stopping.html
【19】Rashmi K V, Gilad-Bachrach R. DART: Dropouts meet Multiple Additive Regression Trees[C]//AISTATS. 2015: 489-497.
关于GBDT若干问题的思考:
【20】关于adaboost、GBDT、xgboost之间的区别与联系,地址:https://blog.csdn.net/HHTNAN/article/details/80894247
【21】[校招-基础算法]GBDT/XGBoost常见问题 - Jack Stark的文章 - 知乎 https://zhuanlan.zhihu.com/p/81368182
【22】gbdt的残差为什么用负梯度代替?- 知乎 https://www.zhihu.com/question/63560633
【23】gbdt的残差为什么用负梯度代替?- 奥奥奥奥噢利的回答 - 知乎 https://www.zhihu.com/question/63560633/answer/581670747
下载1:OpenCV-Contrib扩展模块中文版教程 在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。 下载2:Python视觉实战项目52讲 在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。 下载3:OpenCV实战项目20讲 在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。 交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

