@@ -120,26 +120,18 @@ eval $ MkGraph x =
120
120
let (env, MkTensor root) = runState empty x
121
121
in panicIO $ execute (MkFn [] root env) >>= read {dtype} []
122
122
123
- namespace TensorVect
123
+ namespace TensorList
124
124
public export
125
- data TensorVect : List Shape -> List Type -> Type where
126
- Nil : TensorVect [] []
125
+ data TensorList : List ( Shape, Type ) -> Type where
126
+ Nil : TensorList []
127
127
(:: ) : PrimitiveRW dtype ty =>
128
128
Tensor shape dtype ->
129
- TensorVect shapes tys ->
130
- TensorVect (shape :: shapes) (ty :: tys)
131
-
132
- namespace LiteralVect
133
- public export
134
- data LiteralVect : List Shape -> List Type -> Type where
135
- Nil : LiteralVect [] []
136
- (:: ) : Literal shape ty ->
137
- LiteralVect shapes tys ->
138
- LiteralVect (shape :: shapes) (ty :: tys)
129
+ TensorList sts ->
130
+ TensorList ((shape, ty) :: sts)
139
131
140
132
namespace Tuple
141
133
export partial
142
- eval : Graph (TensorVect shapes tys ) -> IO $ LiteralVect shapes tys
134
+ eval : Graph (TensorList shapes) -> IO $ All (uncurry Literal) shapes
143
135
eval $ MkGraph tensors = do
144
136
let graph = do ts <- tensors
145
137
x <- addNode (Tuple $ nodes ts)
@@ -149,11 +141,11 @@ namespace Tuple
149
141
150
142
where
151
143
152
- nodes : TensorVect s t -> List Nat
144
+ nodes : TensorList s -> List Nat
153
145
nodes [] = []
154
146
nodes (MkTensor x :: xs) = x :: nodes xs
155
147
156
- readAll : HasIO io => TensorVect s t -> Nat -> Literal -> io $ LiteralVect s t
148
+ readAll : HasIO io => TensorList s -> Nat -> Literal -> io $ All (uncurry Literal) s
157
149
readAll [] _ _ = pure []
158
150
readAll (MkTensor {dtype} _ :: ts) n lit = [| read {dtype} [n] lit :: readAll ts (S n) lit | ]
159
151
0 commit comments