Skip to content

Commit a6143dd

Browse files
[PyLayer] pylayer add api (PaddlePaddle#5148)
* pylayer add api
1 parent 54ac226 commit a6143dd

File tree

1 file changed

+156
-0
lines changed

1 file changed

+156
-0
lines changed

docs/api/paddle/autograd/PyLayerContext_cn.rst

+156
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,159 @@ saved_tensor(self, *tensors)
107107
y, = ctx.saved_tensor()
108108
grad = dy * (1 - paddle.square(y))
109109
return grad
110+
111+
112+
mark_not_inplace(self, *tensors)
113+
'''''''''
114+
115+
标记一些输入是不需要 inplace 的。
116+
如果 ``forward`` 的输入输出是同一个 ``Tensor`` ,并且这个 ``Tensor`` 被标记为 not_inplace 的。Paddle 会替用户创建一个新的 Tensor 作为输出。
117+
这样可以防止输入的 ``Tensor`` 的 auto grad 信息被错误的篡改。
118+
119+
.. note::
120+
这个函数最多只能在 ``forward`` 调用一次,并且所有的参数必须是 ``forward`` 输入的 ``Tensor`` 。
121+
122+
**参数**
123+
124+
- **tensors** (list of Tensor) - 需要标记 not inplace 的 ``Tensor``
125+
126+
**返回**
127+
128+
None
129+
130+
**代码示例**
131+
132+
.. code-block:: python
133+
134+
import paddle
135+
136+
class Exp(paddle.autograd.PyLayer):
137+
@staticmethod
138+
def forward(ctx, x):
139+
ctx.mark_not_inplace(x)
140+
return x
141+
142+
@staticmethod
143+
def backward(ctx, grad_output):
144+
out = grad_output.exp()
145+
return out
146+
147+
x = paddle.randn((1, 1))
148+
x.stop_gradient = False
149+
attn_layers = []
150+
for idx in range(0, 2):
151+
attn_layers.append(Exp())
152+
153+
for step in range(0, 2):
154+
a = x
155+
for j in range(0,2):
156+
a = attn_layers[j].apply(x)
157+
a.backward()
158+
159+
160+
mark_non_differentiable(self, *tensors)
161+
'''''''''
162+
163+
标记一些输出是不需要反向的。
164+
如果 ``forward`` 的输入输出是同一个 ``Tensor`` ,并且这个 ``Tensor`` 被标记为 not_inplace 的。Paddle 会替用户创建一个新的 Tensor 作为输出。
165+
将不需要反向的 ``Tensor`` 标记为 non-differentiable,可以提升反向的性能。但是你在 ``backward`` 函数的输入参数中,仍要为其留有反向梯度的位置。
166+
只是这个反向梯度是 1 个全为 0 的、shape 和 ``forward`` 的输出一样的 ``Tensor`` .
167+
168+
.. note::
169+
这个函数最多只能在 ``forward`` 调用一次,并且所有的参数必须是 ``forward`` 输出的 ``Tensor`` 。
170+
171+
**参数**
172+
173+
- **tensors** (list of Tensor) - 需要标记不需要反向的 ``Tensor``
174+
175+
176+
**返回**
177+
178+
None
179+
180+
**代码示例**
181+
182+
.. code-block:: python
183+
184+
import os
185+
os.environ['FLAGS_enable_eager_mode'] = '1'
186+
import paddle
187+
from paddle.autograd import PyLayer
188+
import numpy as np
189+
190+
class Tanh(PyLayer):
191+
@staticmethod
192+
def forward(ctx, x):
193+
a = x + x
194+
b = x + x + x
195+
ctx.mark_non_differentiable(a)
196+
return a, b
197+
198+
@staticmethod
199+
def backward(ctx, grad_a, grad_b):
200+
assert np.equal(grad_a.numpy(), paddle.zeros([1]).numpy())
201+
assert np.equal(grad_b.numpy(), paddle.ones([1], dtype="float64").numpy())
202+
return grad_b
203+
204+
x = paddle.ones([1], dtype="float64")
205+
x.stop_gradient = False
206+
a, b = Tanh.apply(x)
207+
b.sum().backward()
208+
209+
set_materialize_grads(self, value)
210+
'''''''''
211+
212+
设置是否要框架来初始化未初始化的反向梯度。默认是 True。
213+
如果设置为 True,框架会将未初始化的反向梯度数据初始化为 0,然后再调用 ``backward`` 函数。
214+
如果设置为 False,框架会将未初始化的反向梯度以 None 向 ``backward`` 函数传递。
215+
216+
.. note::
217+
这个函数最多只能在 ``forward`` 中调用。
218+
219+
**参数**
220+
221+
- **value** (bool) - 是否要框架来初始化未初始化的反向梯度
222+
223+
224+
**返回**
225+
226+
None
227+
228+
**代码示例**
229+
230+
.. code-block:: python
231+
232+
import os
233+
os.environ['FLAGS_enable_eager_mode'] = '1'
234+
import paddle
235+
from paddle.autograd import PyLayer
236+
import numpy as np
237+
238+
class Tanh(PyLayer):
239+
@staticmethod
240+
def forward(ctx, x):
241+
return x+x+x, x+x
242+
243+
@staticmethod
244+
def backward(ctx, grad, grad2):
245+
assert np.equal(grad2.numpy(), paddle.zeros([1]).numpy())
246+
return grad
247+
248+
class Tanh2(PyLayer):
249+
@staticmethod
250+
def forward(ctx, x):
251+
ctx.set_materialize_grads(False)
252+
return x+x+x, x+x
253+
254+
@staticmethod
255+
def backward(ctx, grad, grad2):
256+
assert grad2==None
257+
return grad
258+
259+
x = paddle.ones([1], dtype="float64")
260+
x.stop_gradient = False
261+
Tanh.apply(x)[0].backward()
262+
263+
x2 = paddle.ones([1], dtype="float64")
264+
x2.stop_gradient = False
265+
Tanh2.apply(x2)[0].backward()

0 commit comments

Comments
 (0)