Skip to content

Commit 161a0ae

Browse files
committed
Add transposition rules for + and *
1 parent 6427a9e commit 161a0ae

File tree

4 files changed

+42
-62
lines changed

4 files changed

+42
-62
lines changed

examples/ad-tests.dx

Lines changed: 21 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,27 @@
1-
f = lam x. x * x * x
21

32

4-
:p jvp f 1.0 1.0
3+
:p f :: Real --o Real
4+
f = llam x. x
5+
tlinear f 2.0
6+
> 2.0
57

6-
> 3.0
8+
:p f :: Real --o Real
9+
f = llam x. y = x; y
10+
tlinear f 2.0
11+
> 2.0
712

13+
:p f :: Real --o Real
14+
f = llam x. x + x
15+
tlinear f 2.0
16+
> 4.0
817

9-
:p jvp (lam x. jvp f x 1.0) 1.0 1.0
18+
:p f :: Real --o Real
19+
f = llam x. y = 2.0 * x
20+
3.0 * y + x
21+
tlinear f 1.0
22+
> 7.0
1023

11-
> 6.0
12-
13-
14-
:p grad f 1.0
15-
16-
> 3.0
17-
18-
19-
_, Nx = unpack range 3
20-
21-
22-
g x = for i::Nx. 3.0 * x * x
23-
24-
25-
:p jvp g 2.0 1.0
26-
27-
> [12.0, 12.0, 12.0]
28-
29-
30-
g2 (x, y) = x * y
31-
32-
33-
:p grad g2 (1.0, 2.0)
34-
35-
> (2.0, 1.0)
36-
37-
38-
xs = for i::Nx. real iota.i * 1.0
39-
40-
41-
arrFun c = for i::Nx. c
42-
43-
44-
:p let (_, pullback) = vjp arrFun 2.0
45-
in pullback xs
46-
47-
> 3.0
48-
49-
50-
:p (transpose vsum 1.5) :: Nx=>Real
51-
52-
> [1.5, 1.5, 1.5]
53-
54-
55-
:p jvp vsum xs xs
56-
57-
> 3.0
58-
59-
60-
:p transpose (lam x. for i. x.i) xs
61-
62-
> [0.0, 1.0, 2.0]
24+
:p f :: Real --o Real
25+
f = llam x. (2.0 + 3.0) * x
26+
tlinear f 1.0
27+
> 5.0

makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ tests: quine-tests quine-tests-interp repl-test stack-tests
2929

3030
quine-tests: $(quine-test-targets)
3131

32-
quine-tests-interp: runinterp-eval-tests runinterp-interp-tests
32+
quine-tests-interp: runinterp-eval-tests runinterp-ad-tests runinterp-interp-tests
3333

3434
run-%: examples/%.dx
3535
misc/check-quine $^ $(dex) script --lit --allow-errors

prelude.dx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,12 @@ linearize f x0 = %linearize(lam x. f x, x0)
225225
jvp :: (a -> b) -> a -> a --o b
226226
jvp f x = llam t. snd (linearize f x) t
227227

228-
linearTranspose :: (a --o b) --o b --o a
229-
linearTranspose = llam f ct. %linearTranspose(llam t. f t, ct)
228+
tlinear :: (a --o b) --o b --o a
229+
tlinear = llam f ct. %linearTranspose(llam t. f t, ct)
230230

231231
vjp :: (a -> b) -> a -> (b, b --o a)
232232
vjp f x = (y, df) = linearize f x
233-
(y, linearTranspose df)
233+
(y, tlinear df)
234234

235235
grad :: (a -> Real) -> a -> a
236236
grad f x = (_, pullback) = vjp f x

src/lib/Interpreter.hs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,25 @@ transpose :: Val -> Expr -> CotangentVals
198198
transpose ct expr = case expr of
199199
Lit _ -> mempty
200200
Var v _ -> MonMap $ M.singleton v [ct]
201-
PrimOp op ts xs -> error "todo"
201+
PrimOp op ts xs -> transposeOp op ct ts xs
202202
Decl (LetMono p rhs) body
203203
| hasFVs rhs -> cts <> transpose ct' rhs
204204
where (ct', cts) = sepCotangents p $ transpose ct body
205-
App (Lam _ p body) e2 -> transpose ct (Decl (LetMono p e2) body)
205+
App e1 e2
206+
| hasFVs e2 -> cts <> transpose ct' e2
207+
where
208+
(Lam _ p body) = reduce e1
209+
(ct', cts) = sepCotangents p $ transpose ct body
206210
_ -> error $ "Can't transpose in interpreter: " ++ pprint expr
207211

212+
transposeOp :: Builtin -> Val -> [Type] -> [Val] -> CotangentVals
213+
transposeOp op ct ts xs = case (op, ts, xs) of
214+
(FAdd, _, ~[x1, x2]) -> transpose ct x1 <> transpose ct x2
215+
(FMul, _, ~[x1, x2]) | hasFVs x2 -> let ct' = mul ct (reduce x1)
216+
in transpose ct' x2
217+
| otherwise -> let ct' = mul ct (reduce x2)
218+
in transpose ct' x1
219+
208220
hasFVs :: Expr -> Bool
209221
hasFVs expr = not $ null $ envNames $ freeVars expr
210222

@@ -220,6 +232,9 @@ sepCotangents p vs = (recTreeToVal tree, cts)
220232
put s'
221233
return x
222234

235+
mul :: Val -> Val -> Val
236+
mul x y = realBinOp (*) [x, y]
237+
223238
recTreeToVal :: RecTree Val -> Val
224239
recTreeToVal (RecLeaf v) = v
225240
recTreeToVal (RecTree r) = RecCon Cart $ fmap recTreeToVal r

0 commit comments

Comments
 (0)