|
| 1 | +## [ 输入参数类型不一致 ]torch.linalg.lu_solve |
| 2 | + |
| 3 | +### [torch.linalg.lu_solve](https://pytorch.org/docs/stable/generated/torch.linalg.lu_solve.html#torch.linalg.lu_solve) |
| 4 | + |
| 5 | +```python |
| 6 | +torch.linalg.lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None) |
| 7 | +``` |
| 8 | + |
| 9 | +### [paddle.linalg.lu_solve](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/linalg/lu_solve_cn.html) |
| 10 | + |
| 11 | +```python |
| 12 | +paddle.linalg.lu_solve(b, lu, pivots, trans="N", name=None) |
| 13 | +``` |
| 14 | + |
| 15 | +PyTorch 相比 Paddle 支持更多其他参数,具体如下: |
| 16 | + |
| 17 | +### 参数映射 |
| 18 | + |
| 19 | +| PyTorch | PaddlePaddle | 备注 | |
| 20 | +| ------- | ------------ | ----------------------------------------------------- | |
| 21 | +| LU | lu | 表示 LU 分解结果矩阵,由 L、U 拼接组成,仅参数名不一致。 | |
| 22 | +| pivots | pivots | 表示 LU 分解结果的主元信息 Tensor 。 | |
| 23 | +| B | b | 表示欲进行线性方程组求解的右值 Tensor ,仅参数名不一致。 | |
| 24 | +| left | - | 表示系数矩阵 A 是否在左侧, Paddle 无此参数,需要转写。| |
| 25 | +| adjoint | trans | 表示是否使用转置 LU 分解结果, PyTorch 为 bool 类型,Paddle 为 str 类型,需要转写。| |
| 26 | +| out | - | 表示输出的 Tensor 元组 , Paddle 无此参数,需要转写。 | |
| 27 | + |
| 28 | +### 转写示例 |
| 29 | + |
| 30 | +#### out:指定输出 |
| 31 | + |
| 32 | +```python |
| 33 | +# PyTorch 写法 |
| 34 | +torch.linalg.lu_solve(LU, pivots, B, out=A) |
| 35 | + |
| 36 | +# Paddle 写法 |
| 37 | +y = paddle.linalg.lu_solve(B, LU, pivots) |
| 38 | +paddle.assign(y, A) |
| 39 | +``` |
| 40 | + |
| 41 | +#### left=True, adjoint=True |
| 42 | +```python |
| 43 | +# PyTorch 写法 |
| 44 | +LU, pivots = torch.linalg.lu(A) |
| 45 | +torch.linalg.lu_solve(LU, pivots, B, left=True, adjoint=True) |
| 46 | + |
| 47 | +# Paddle 写法 |
| 48 | +LU, pivots = paddle.linalg.lu(A) |
| 49 | +paddle.linalg.lu_solve(B, LU, pivots, trans="C") |
| 50 | +``` |
| 51 | + |
| 52 | +#### left=True, adjoint=False |
| 53 | +```python |
| 54 | +# PyTorch 写法 |
| 55 | +LU, pivots = torch.linalg.lu(A) |
| 56 | +torch.linalg.lu_solve(LU, pivots, B, left=True, adjoint=False) |
| 57 | + |
| 58 | +# Paddle 写法 |
| 59 | +LU, pivots = paddle.linalg.lu(A) |
| 60 | +paddle.linalg.lu_solve(B, LU, pivots, trans="N") |
| 61 | +``` |
| 62 | + |
| 63 | +#### left=False, adjoint=True |
| 64 | +```python |
| 65 | +# PyTorch 写法 |
| 66 | +LU, pivots = torch.linalg.lu(A) |
| 67 | +torch.linalg.lu_solve(LU, pivots, B, left=False, adjoint=True) |
| 68 | + |
| 69 | +# Paddle 写法 |
| 70 | +LU, pivots = paddle.linalg.lu(A.T) |
| 71 | +paddle.linalg.lu_solve(B.T, LU, pivots, trans="C").T |
| 72 | +``` |
| 73 | + |
| 74 | +#### left=False, adjoint=False |
| 75 | +```python |
| 76 | +# PyTorch 写法 |
| 77 | +LU, pivots = torch.linalg.lu(A) |
| 78 | +torch.linalg.lu_solve(LU, pivots, B, left=False, adjoint=False) |
| 79 | + |
| 80 | +# Paddle 写法 |
| 81 | +LU, pivots = paddle.linalg.lu(A.T) |
| 82 | +paddle.linalg.lu_solve(B.T, LU, pivots, trans="N").T |
| 83 | +``` |
0 commit comments