Cross Decomposition Regression
What is a cross decomposition algorithm?
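- A family of supervised methods, implemented in scikit-learn's `sklearn.cross_decomposition` module (partial least squares, PLS, and canonical correlation analysis, CCA), that find the fundamental relations between two matrices X and Y. They are latent-variable approaches to modelling the covariance structure of the two spaces: they look for the multidimensional direction in the X space that explains the direction of maximum multidimensional variance in the Y space. The extracted components can be used for dimensionality reduction (`transform`) and, in the case of PLSRegression, for regression (`predict`).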
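Stated as an optimisation problem, and assuming X (n × p) and Y (n × q) have been column-centred, the first pair of PLS directions can be sketched as follows (CCA differs by maximising correlation rather than covariance):

$$
(u_1, v_1) = \arg\max_{\lVert u \rVert = \lVert v \rVert = 1} \operatorname{cov}(Xu,\, Yv) = \arg\max_{\lVert u \rVert = \lVert v \rVert = 1} \frac{u^\top X^\top Y v}{n - 1}
$$

The maximum is attained by the leading left and right singular vectors of the cross-covariance matrix $X^\top Y$. Later components are obtained either by deflating X and Y and repeating (PLSCanonical, PLSRegression) or by taking further singular vectors of the same matrix (PLSSVD).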
How should one understand the PLSCanonical algorithm in scikit-learn?
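- A partial least squares (PLS) transformer and regressor. It implements the two-block canonical PLS of Wold's original algorithm: X and Y are treated symmetrically, and each block is deflated by its own scores after every component.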
```python
>>> from sklearn.cross_decomposition import PLSCanonical
>>> X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
>>> Y = [[0.1, -0.2], [0.9, 1.1], [6.2, 5.9], [11.9, 12.3]]
>>> plsca = PLSCanonical(n_components=2)
>>> plsca.fit(X, Y)
PLSCanonical()
>>> X_c, Y_c = plsca.transform(X, Y)
>>> X_c.shape, Y_c.shape
((4, 2), (4, 2))
>>> X_c
array([[-1.39700475, -0.1179672 ],
       [-1.19678754,  0.17050027],
       [ 0.56032252, -0.0991593 ],
       [ 2.03346977,  0.04662624]])
>>> Y_c
array([[-1.22601804, -0.01674181],
       [-0.9602955 ,  0.04216316],
       [ 0.32491535, -0.04379   ],
       [ 1.86139819,  0.01836865]])
```
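The symmetric deflation that characterises PLSCanonical can be sketched in plain NumPy. This is an illustration under stated assumptions, not scikit-learn's code: the helper `pls_canonical_scores` is hypothetical, each weight pair is taken from a direct SVD of the deflated cross-covariance (the idea behind `algorithm='svd'`; the default `'nipals'` reaches the same vectors by power iteration), both blocks are assumed to be standardised with ddof=1 as with `scale=True`, and no sign convention is applied, so columns may come out with flipped signs.

```python
import numpy as np


def pls_canonical_scores(X, Y, n_components=2):
    """Hypothetical helper sketching canonical-mode PLS (not sklearn's code)."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    # Assumed preprocessing: centre and scale each column (ddof=1).
    Xk = (X - X.mean(axis=0)) / X.std(axis=0, ddof=1)
    Yk = (Y - Y.mean(axis=0)) / Y.std(axis=0, ddof=1)
    x_scores, y_scores = [], []
    for _ in range(n_components):
        # Leading singular pair of the current (deflated) cross-covariance.
        U, s, Vt = np.linalg.svd(Xk.T @ Yk, full_matrices=False)
        w, c = U[:, 0], Vt[0, :]
        t, u = Xk @ w, Yk @ c  # paired scores for this component
        # Canonical mode: each block is deflated by its OWN scores.
        Xk = Xk - np.outer(t, Xk.T @ t) / (t @ t)
        Yk = Yk - np.outer(u, Yk.T @ u) / (u @ u)
        x_scores.append(t)
        y_scores.append(u)
    return np.column_stack(x_scores), np.column_stack(y_scores)


X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
Y = [[0.1, -0.2], [0.9, 1.1], [6.2, 5.9], [11.9, 12.3]]
X_c, Y_c = pls_canonical_scores(X, Y, n_components=2)
# Off-diagonal entries are ~0: the X scores are mutually orthogonal.
print(np.round(X_c.T @ X_c, 6))
```

Because each block is deflated by its own scores, the columns of `X_c` are mutually orthogonal, and so are the columns of `Y_c`.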
How should one understand PLSSVD in scikit-learn?
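- Partial least squares SVD. It performs a single singular value decomposition of the cross-covariance matrix X'Y (computed on centred and, by default, scaled data): X is projected onto the left singular vectors and Y onto the right singular vectors. Because all components come from one SVD, no deflation takes place between them; PLSSVD is a transformer only and has no `predict` method.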
```python
>>> import numpy as np
>>> from sklearn.cross_decomposition import PLSSVD
>>> X = np.array([[0., 0., 1.],
...               [1., 0., 0.],
...               [2., 2., 2.],
...               [2., 5., 4.]])
>>> Y = np.array([[0.1, -0.2],
...               [0.9, 1.1],
...               [6.2, 5.9],
...               [11.9, 12.3]])
>>> pls = PLSSVD(n_components=2).fit(X, Y)
>>> X_c, Y_c = pls.transform(X, Y)
>>> X_c.shape, Y_c.shape
((4, 2), (4, 2))
>>> X_c
array([[-1.39700475, -0.10283021],
       [-1.19678754,  0.17159333],
       [ 0.56032252, -0.10849725],
       [ 2.03346977,  0.03973413]])
>>> Y_c
array([[-1.22601804, -0.01930121],
       [-0.9602955 ,  0.04015847],
       [ 0.32491535, -0.04311171],
       [ 1.86139819,  0.02225445]])
```
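The PLSSVD recipe is short enough to write out directly. A minimal NumPy sketch, assuming the same ddof=1 standardisation as `scale=True`; `pls_svd_scores` is a hypothetical helper, and signs may differ from scikit-learn because no sign flipping is applied:

```python
import numpy as np


def pls_svd_scores(X, Y, n_components=2):
    """Hypothetical helper: project X and Y onto the singular vectors of X'Y."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    # Assumed preprocessing: centre and scale each column (ddof=1).
    Xs = (X - X.mean(axis=0)) / X.std(axis=0, ddof=1)
    Ys = (Y - Y.mean(axis=0)) / Y.std(axis=0, ddof=1)
    # One SVD of the cross-covariance matrix gives every component at once.
    U, s, Vt = np.linalg.svd(Xs.T @ Ys, full_matrices=False)
    x_weights = U[:, :n_components]     # left singular vectors -> X
    y_weights = Vt.T[:, :n_components]  # right singular vectors -> Y
    return Xs @ x_weights, Ys @ y_weights, s[:n_components]


X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
Y = [[0.1, -0.2], [0.9, 1.1], [6.2, 5.9], [11.9, 12.3]]
X_c, Y_c, s = pls_svd_scores(X, Y)
# X_c.T @ Y_c equals diag(s): each component pair captures one singular value.
print(np.round(X_c.T @ Y_c, 6))
print(np.round(s, 6))
```

Since `X_c.T @ Y_c` equals `U_k' (X'Y) V_k`, it is diagonal with the singular values on the diagonal. This also explains why the first columns of the PLSSVD and PLSCanonical outputs coincide: both start from the leading singular pair of the same matrix, and only the later components differ.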
How should one understand PLSRegression in scikit-learn?
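- The PLSRegression estimator is similar to PLSCanonical with `algorithm='nipals'`, except that the targets Y are deflated using the projection of X (the X scores) instead of their own projection, so the loadings, and therefore the output of `transform` and `predict`, differ. Depending on the number of targets, it is also known as PLS1 (a single target) or PLS2 (multiple targets).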
```python
>>> from sklearn.cross_decomposition import PLSRegression
>>> X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
>>> Y = [[0.1, -0.2], [0.9, 1.1], [6.2, 5.9], [11.9, 12.3]]
>>> pls2 = PLSRegression(n_components=2)
>>> pls2.fit(X, Y)
PLSRegression()
>>> Y_pred = pls2.predict(X)
>>> Y_pred
array([[ 0.26087869,  0.15302213],
       [ 0.60667302,  0.45634164],
       [ 6.46856199,  6.48931562],
       [11.7638863 , 12.00132061]])
```
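To make regression-mode deflation concrete, here is a NumPy-only sketch of PLS2. Assumptions: the helper `pls_regression_predict` is hypothetical, each X weight vector is taken from an SVD of the deflated cross-covariance (NIPALS approximates the same direction iteratively), and X and Y are standardised with ddof=1 as with `scale=True`. It is not scikit-learn's implementation.

```python
import numpy as np


def pls_regression_predict(X, Y, n_components=2):
    """Hypothetical helper sketching PLS2 regression (not sklearn's code)."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    x_mean, x_std = X.mean(axis=0), X.std(axis=0, ddof=1)
    y_mean, y_std = Y.mean(axis=0), Y.std(axis=0, ddof=1)
    Xk = (X - x_mean) / x_std
    Yk = (Y - y_mean) / y_std
    W, P, Q = [], [], []
    for _ in range(n_components):
        U, s, Vt = np.linalg.svd(Xk.T @ Yk, full_matrices=False)
        w = U[:, 0]               # X weights
        t = Xk @ w                # X scores
        p = Xk.T @ t / (t @ t)    # X loadings
        q = Yk.T @ t / (t @ t)    # Y loadings: Y regressed on the X scores
        # Regression mode: BOTH blocks are deflated with the X scores t.
        Xk = Xk - np.outer(t, p)
        Yk = Yk - np.outer(t, q)
        W.append(w)
        P.append(p)
        Q.append(q)
    W, P, Q = np.column_stack(W), np.column_stack(P), np.column_stack(Q)
    # Coefficients in the standardised space: B = W (P'W)^-1 Q'
    B = W @ np.linalg.solve(P.T @ W, Q.T)
    # Map predictions back to the original units of Y.
    return ((X - x_mean) / x_std) @ B * y_std + y_mean


X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
Y = [[0.1, -0.2], [0.9, 1.1], [6.2, 5.9], [11.9, 12.3]]
Y_pred = pls_regression_predict(X, Y, n_components=2)
print(np.round(Y_pred, 4))
# With 3 components (the rank of the centred X) the fit reproduces Y on this tiny set.
Y_full = pls_regression_predict(X, Y, n_components=3)
print(np.allclose(Y_full, np.asarray(Y)))
```

Compared with canonical-mode PLS, the only change is the Y deflation: Y is deflated with the X scores t rather than with its own scores, which is what lets the model collapse into a single linear map B from X to Y. With as many components as the rank of the centred X, the X scores span the same space as the centred X, so the predictions coincide with ordinary least squares; with only four samples, that fit is exact.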