Skip to content

Commit 9bc3a11

Browse files
committed
remove LiteralVect
1 parent dd7fca1 commit 9bc3a11

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/Tensor.idr

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,26 +120,18 @@ eval $ MkGraph x =
120120
let (env, MkTensor root) = runState empty x
121121
in panicIO $ execute (MkFn [] root env) >>= read {dtype} []
122122

123-
namespace TensorVect
123+
namespace TensorList
124124
public export
125-
data TensorVect : List Shape -> List Type -> Type where
126-
Nil : TensorVect [] []
125+
data TensorList : List (Shape, Type) -> Type where
126+
Nil : TensorList []
127127
(::) : PrimitiveRW dtype ty =>
128128
Tensor shape dtype ->
129-
TensorVect shapes tys ->
130-
TensorVect (shape :: shapes) (ty :: tys)
131-
132-
namespace LiteralVect
133-
public export
134-
data LiteralVect : List Shape -> List Type -> Type where
135-
Nil : LiteralVect [] []
136-
(::) : Literal shape ty ->
137-
LiteralVect shapes tys ->
138-
LiteralVect (shape :: shapes) (ty :: tys)
129+
TensorList sts ->
130+
TensorList ((shape, ty) :: sts)
139131

140132
namespace Tuple
141133
export partial
142-
eval : Graph (TensorVect shapes tys) -> IO $ LiteralVect shapes tys
134+
eval : Graph (TensorList shapes) -> IO $ All (uncurry Literal) shapes
143135
eval $ MkGraph tensors = do
144136
let graph = do ts <- tensors
145137
x <- addNode (Tuple $ nodes ts)
@@ -149,11 +141,11 @@ namespace Tuple
149141

150142
where
151143

152-
nodes : TensorVect s t -> List Nat
144+
nodes : TensorList s -> List Nat
153145
nodes [] = []
154146
nodes (MkTensor x :: xs) = x :: nodes xs
155147

156-
readAll : HasIO io => TensorVect s t -> Nat -> Literal -> io $ LiteralVect s t
148+
readAll : HasIO io => TensorList s -> Nat -> Literal -> io $ All (uncurry Literal) s
157149
readAll [] _ _ = pure []
158150
readAll (MkTensor {dtype} _ :: ts) n lit = [| read {dtype} [n] lit :: readAll ts (S n) lit |]
159151

0 commit comments

Comments
 (0)