Skip to content

Commit 50e8abc

Browse files
authored
【Hackathon 8th No.1】add lu_solve api for paddle (#7052)
* 【Hackathon 8th No.1】add `lu_solve` api for paddle * pre-commit * add example * fix review
1 parent 2d970bc commit 50e8abc

File tree

4 files changed

+131
-0
lines changed

4 files changed

+131
-0
lines changed

docs/api/paddle/Tensor/Overview_en.rst

+1
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ Methods
195195
outer
196196
cov
197197
lu
198+
lu_solve
198199
lu_unpack
199200
cholesky_solve
200201
mod

docs/api/paddle/linalg/Overview_cn.rst

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ paddle.linalg 目录下包含飞桨框架支持的线性代数相关 API。具
8080
:widths: 10, 30
8181

8282
" :ref:`paddle.linalg.lstsq <cn_api_paddle_linalg_lstsq>` ", "求解线性方程组的最小二乘问题"
83+
" :ref:`paddle.linalg.lu_solve <cn_api_paddle_linalg_lu_solve>` ", "计算具有唯一解的线性方程组,方程左边为 LU 分解矩阵,右边为矩阵"
8384
" :ref:`paddle.linalg.solve <cn_api_paddle_linalg_solve>` ", "计算具有唯一解的线性方程组,方程左边为方阵,右边为矩阵"
8485
" :ref:`paddle.linalg.triangular_solve <cn_api_paddle_linalg_triangular_solve>` ", "计算具有唯一解的线性方程组,方程左边为上(下)三角方阵,右边为矩阵"
8586
" :ref:`paddle.linalg.cholesky_solve <cn_api_paddle_linalg_cholesky_solve>` ", "通过 Cholesky 分解矩阵,计算具有唯一解的线性方程组"
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
.. _cn_api_paddle_linalg_lu_solve:
2+
3+
lu_solve
4+
-------------------------------
5+
6+
.. py:function:: paddle.linalg.lu_solve(b, lu, pivots, trans="N", name=None)
7+
8+
给定 `A` 的 LU 分解结果 和列向量 `b` ,求解线性方程组的解 `x`。
9+
10+
:math:`A` 为一个或一批方阵,:math:`b` 一个或一批矩阵,当 `trans` 为 `N` 时,公式为:
11+
12+
.. math::
13+
b = A * X
14+
15+
当 `trans` 为 `T` 时,公式为:
16+
17+
.. math::
18+
b = A ^ {T} * X
19+
20+
当 `trans` 为 `C` 时,公式为:
21+
22+
.. math::
23+
b = A ^ {H} * X
24+
25+
.. note::
26+
27+
`lu` 和 `pivots` 由 ``paddle.linalg.lu`` 得到。
28+
29+
参数
30+
::::::::::::
31+
32+
- **b** (Tensor) - 输入的欲进行线性方程组求解的右值,类型为 Tensor。 ``b`` 的形状应为 ``[*, M, K]``,其中 ``*`` 为零或更大的批次维度,数据类型为 float32, float64。
33+
- **lu** (Tensor) - LU 分解结果矩阵,由 L、U 拼接组成,类型为 Tensor。 ``lu`` 的形状应为 ``[*, M, M]``,其中 ``*`` 为零或更大的批次维度。数据类型和 ``b`` 相同。
34+
- **pivots** (Tensor) - LU 分解结果的主元信息,类型为 Tensor。 ``pivots`` 的形状应为 ``[*, M]``,其中 ``*`` 为零或更大的批次维度。数据类型为 int32。
35+
- **trans** (str,可选) - 是否对 A 进行转置,该参数的合法值为 'N','T','C',默认值为 N。
36+
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
37+
38+
返回
39+
::::::::::::
40+
41+
- Tensor,这个(或这批)矩阵 ``lu`` 、 ``pivots`` 和 ``b`` 经过运算后的结果,数据类型及维度和输入 ``b`` 的一致。
42+
43+
代码示例
44+
::::::::::
45+
46+
COPY-FROM: paddle.linalg.lu_solve
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
## [ 输入参数类型不一致 ]torch.linalg.lu_solve
2+
3+
### [torch.linalg.lu_solve](https://pytorch.org/docs/stable/generated/torch.linalg.lu_solve.html#torch.linalg.lu_solve)
4+
5+
```python
6+
torch.linalg.lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None)
7+
```
8+
9+
### [paddle.linalg.lu_solve](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/linalg/lu_solve_cn.html)
10+
11+
```python
12+
paddle.linalg.lu_solve(b, lu, pivots, trans="N", name=None)
13+
```
14+
15+
PyTorch 相比 Paddle 支持更多其他参数,具体如下:
16+
17+
### 参数映射
18+
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------- | ------------ | ----------------------------------------------------- |
21+
| LU | lu | 表示 LU 分解结果矩阵,由 L、U 拼接组成,仅参数名不一致。 |
22+
| pivots | pivots | 表示 LU 分解结果的主元信息 Tensor 。 |
23+
| B | b | 表示欲进行线性方程组求解的右值 Tensor ,仅参数名不一致。 |
24+
| left | - | 表示系数矩阵 A 是否在左侧, Paddle 无此参数,需要转写。|
25+
| adjoint | trans | 表示是否使用转置 LU 分解结果, PyTorch 为 bool 类型,Paddle 为 str 类型,需要转写。|
26+
| out | - | 表示输出的 Tensor 元组 , Paddle 无此参数,需要转写。 |
27+
28+
### 转写示例
29+
30+
#### out:指定输出
31+
32+
```python
33+
# PyTorch 写法
34+
torch.linalg.lu_solve(LU, pivots, B, out=A)
35+
36+
# Paddle 写法
37+
y = paddle.linalg.lu_solve(B, LU, pivots)
38+
paddle.assign(y, A)
39+
```
40+
41+
#### left=True, adjoint=True
42+
```python
43+
# PyTorch 写法
44+
LU, pivots = torch.linalg.lu(A)
45+
torch.linalg.lu_solve(LU, pivots, B, left=True, adjoint=True)
46+
47+
# Paddle 写法
48+
LU, pivots = paddle.linalg.lu(A)
49+
paddle.linalg.lu_solve(B, LU, pivots, trans="C")
50+
```
51+
52+
#### left=True, adjoint=False
53+
```python
54+
# PyTorch 写法
55+
LU, pivots = torch.linalg.lu(A)
56+
torch.linalg.lu_solve(LU, pivots, B, left=True, adjoint=False)
57+
58+
# Paddle 写法
59+
LU, pivots = paddle.linalg.lu(A)
60+
paddle.linalg.lu_solve(B, LU, pivots, trans="N")
61+
```
62+
63+
#### left=False, adjoint=True
64+
```python
65+
# PyTorch 写法
66+
LU, pivots = torch.linalg.lu(A)
67+
torch.linalg.lu_solve(LU, pivots, B, left=False, adjoint=True)
68+
69+
# Paddle 写法
70+
LU, pivots = paddle.linalg.lu(A.T)
71+
paddle.linalg.lu_solve(B.T, LU, pivots, trans="C").T
72+
```
73+
74+
#### left=False, adjoint=False
75+
```python
76+
# PyTorch 写法
77+
LU, pivots = torch.linalg.lu(A)
78+
torch.linalg.lu_solve(LU, pivots, B, left=False, adjoint=False)
79+
80+
# Paddle 写法
81+
LU, pivots = paddle.linalg.lu(A.T)
82+
paddle.linalg.lu_solve(B.T, LU, pivots, trans="N").T
83+
```

0 commit comments

Comments
 (0)