Skip to content

Commit 897b05e

Browse files
committed
all variant
1 parent 9bc3a11 commit 897b05e

File tree

10 files changed

+58
-1398
lines changed

10 files changed

+58
-1398
lines changed

src/Compiler/Expr.idr

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ export
4242
empty : Env
4343
empty = MkEnv 0 []
4444

45+
export
46+
addNode1 : Expr -> Env -> (Nat, Env)
47+
addNode1 x (MkEnv n xs) = (n, MkEnv (S n) ((n, x) :: xs))
48+
4549
export
4650
addNode : Expr -> State Env Nat
4751
addNode expr = do

src/Tensor.idr

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -110,44 +110,66 @@ panicIO x = runEitherT x <&> \case
110110
||| **Note:**
111111
||| * Each call to `eval` will rebuild and execute the graph. Similarly, multiple calls to
112112
||| `eval` on different `Tensor`s in a computation will be treated entirely independently.
113-
||| `eval` does not store intermediate values. This is a known limitation, and may change in
114-
||| the future.
115-
||| * `eval` performs logging. You can disable this by adjusting the TensorFlow logging level
116-
||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
113+
||| `eval` does not store intermediate values. If you want to evaluate multiple tensors, use
114+
||| `Tuple.eval`.
115+
||| * `eval` performs logging. You can disable this by adjusting the logging level
116+
||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
117117
export partial
118118
eval : PrimitiveRW dtype ty => Graph (Tensor shape dtype) -> IO (Literal shape ty)
119119
eval $ MkGraph x =
120120
let (env, MkTensor root) = runState empty x
121121
in panicIO $ execute (MkFn [] root env) >>= read {dtype} []
122122

123-
namespace TensorList
124-
public export
125-
data TensorList : List (Shape, Type) -> Type where
126-
Nil : TensorList []
127-
(::) : PrimitiveRW dtype ty =>
128-
Tensor shape dtype ->
129-
TensorList sts ->
130-
TensorList ((shape, ty) :: sts)
131-
132123
namespace Tuple
133124
export partial
134-
eval : Graph (TensorList shapes) -> IO $ All (uncurry Literal) shapes
135-
eval $ MkGraph tensors = do
125+
eval : Graph (All2 Tensor shapes dtypes) ->
126+
All2 PrimitiveRW dtypes tys =>
127+
IO $ All2 Literal shapes tys
128+
eval @{prims} $ MkGraph tensors = do
136129
let graph = do ts <- tensors
137130
x <- addNode (Tuple $ nodes ts)
138131
pure (x, ts)
139132
(env, root, tensors) = runState empty graph
140-
panicIO $ execute (MkFn [] root env) >>= readAll tensors 0
133+
panicIO $ execute (MkFn [] root env) >>= readAll tensors prims 0
141134

142135
where
143136

144-
nodes : TensorList s -> List Nat
137+
nodes : All2 Tensor ss ds -> List Nat
145138
nodes [] = []
146139
nodes (MkTensor x :: xs) = x :: nodes xs
147140

148-
readAll : HasIO io => TensorList s -> Nat -> Literal -> io $ All (uncurry Literal) s
149-
readAll [] _ _ = pure []
150-
readAll (MkTensor {dtype} _ :: ts) n lit = [| read {dtype} [n] lit :: readAll ts (S n) lit |]
141+
readAll : HasIO io =>
142+
All2 Tensor ss ds ->
143+
All2 PrimitiveRW ds ts ->
144+
Nat ->
145+
Literal ->
146+
io $ All2 Literal ss ts
147+
readAll [] [] _ _ = pure []
148+
readAll (MkTensor {dtype} _ :: ts) (prim :: prims) n lit =
149+
[| read {dtype} [n] lit :: readAll ts prims (S n) lit |]
150+
151+
partial
152+
foo : IO Bool
153+
foo = do
154+
let x0 : Literal [] Double
155+
y0 := tensor {dtype = F64} x0
156+
[x0'] <- eval {tys = %search} (do pure [!y0])
157+
pure (x0 == x0')
158+
159+
namespace Example
160+
interface RW (a : Type) (b : Type) | a where
161+
162+
RW Nat Int32 where
163+
RW Bool Double where
164+
165+
eval : All2 Literal shape as -> All2 RW as bs => All2 Literal shape bs
166+
167+
eq : Bool
168+
eq = let xs : Literal [2] Int32
169+
xs' : Literal [2] Nat
170+
171+
[xs''] := eval [xs']
172+
in xs == xs''
151173
152174
||| A string representation of the graph used to define a `Tensor`, detailing all enqueued XLA
153175
||| operations.

src/Util.idr

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ namespace List
126126
impl _ [] = []
127127
impl i (x :: xs) = if elem i idxs then impl (S i) xs else x :: impl (S i) xs
128128

129+
namespace All2
130+
public export
131+
data All2 : (0 p : a -> b -> Type) -> List a -> List b -> Type where
132+
Nil : All2 p [] []
133+
(::) : forall xs, ys . p x y -> All2 p xs ys -> All2 p (x :: xs) (y :: ys)
134+
129135
||| A `Sorted f xs` proves that for all consecutive elements `x` and `y` in `xs`, `f x y` exists.
130136
||| For example, a `Sorted LT xs` proves that all `Nat`s in `xs` appear in increasing numerical
131137
||| order.

test.ipkg

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@ main = Main
1313
modules =
1414
Unit.Model.TestKernel,
1515

16-
Unit.TestTensor.Elementwise,
17-
Unit.TestTensor.HigherOrder,
18-
Unit.TestTensor.Sampling,
19-
Unit.TestTensor.Slice,
20-
Unit.TestTensor.Structure,
21-
2216
Unit.TestDistribution,
2317
Unit.TestLiteral,
2418
Unit.TestTensor,

test/Unit/TestTensor.idr

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@ limitations under the License.
1515
--}
1616
module Unit.TestTensor
1717

18-
import Unit.TestTensor.Elementwise
19-
import Unit.TestTensor.HigherOrder
20-
import Unit.TestTensor.Sampling
21-
import Unit.TestTensor.Slice
22-
import Unit.TestTensor.Structure
23-
2418
import Data.Nat
2519
import Data.Vect
2620
import System
@@ -68,9 +62,9 @@ evalTuple = property $ do
6862
y1 = tensor {dtype = S32} x1
6963
y2 = tensor {dtype = U64} x2
7064

71-
let [] = unsafePerformIO $ eval (pure [])
65+
-- let [] = unsafePerformIO $ eval {tys = []} (pure [])
7266

73-
let [x0'] = unsafePerformIO $ eval (do pure [!y0])
67+
let [x0'] = unsafePerformIO $ eval {tys = [_]} (do pure [!y0])
7468

7569
x0' ==~ x0
7670

@@ -88,14 +82,15 @@ evalTuple = property $ do
8882
partial
8983
evalTupleNonTrivial : Property
9084
evalTupleNonTrivial = property $ do
91-
let xs = do y0 <- tensor [1.0, -2.0, 0.4]
85+
let xs : Graph $ All2 Tensor [[], [2]] _ =
86+
do y0 <- tensor [1.0, -2.0, 0.4]
9287
y1 <- tensor 3.0
9388
u <- exp y0
9489
v <- slice [at 1] u + pure y1
9590
w <- slice [0.to 2] u
9691
pure [v, w]
9792

98-
[v, w] = unsafePerformIO $ eval xs
93+
[v, w] = unsafePerformIO $ eval {shapes = [[], [2]]} {tys = [_, _]} xs
9994

10095
v ==~ Scalar (exp (-2.0) + 3.0)
10196
w ==~ [| exp [1.0, -2.0] |]
@@ -469,10 +464,4 @@ group = MkGroup "Tensor" $ [
469464
, (#"(|\) and (/|) result and inverse"#, triangularSolveResultAndInverse)
470465
, (#"(|\) and (/|) ignore opposite elements"#, triangularSolveIgnoresOppositeElems)
471466
, ("trace", trace)
472-
] ++ concat (the (List _) [
473-
Unit.TestTensor.Elementwise.all
474-
, Unit.TestTensor.HigherOrder.all
475-
, Unit.TestTensor.Sampling.all
476-
, Unit.TestTensor.Slice.all
477-
, Unit.TestTensor.Structure.all
478-
])
467+
]

0 commit comments

Comments
 (0)