diff --git a/src/Tensor.idr b/src/Tensor.idr index 8cf212c06..8c1f32fc4 100644 --- a/src/Tensor.idr +++ b/src/Tensor.idr @@ -110,26 +110,16 @@ try x = runEitherT x <&> \case ||| **Note:** ||| * Each call to `eval` will rebuild and execute the graph; multiple calls to `eval` on different ||| tensors, even if they are in the same computation, will be treated entirely independently. -||| To efficiently evaluate multiple tensors at once, use `TensorList.eval`. -||| * `eval` performs logging. You can disable this by adjusting the TensorFlow logging level -||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`. +||| To efficiently evaluate multiple tensors at once, use `All2.eval`. +||| * `eval` performs logging. You can disable this by adjusting the logging level +||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`. export partial eval : PrimitiveRW dtype ty => Graph (Tensor shape dtype) -> IO (Literal shape ty) eval $ MkGraph x = let (env, MkTensor root) = runState empty x in try $ execute (MkFn [] root env) >>= read {dtype} [] -namespace TensorList - ||| A list of `Tensor`s, along with the conversions needed to evaluate them to `Literal`s. - ||| The list is parametrized by the shapes and types of the resulting `Literal`s. - public export - data TensorList : List Shape -> List Type -> Type where - Nil : TensorList [] [] - (::) : PrimitiveRW dtype ty => - Tensor shape dtype -> - TensorList shapes tys -> - TensorList (shape :: shapes) (ty :: tys) - +namespace All2 ||| Evaluate a list of `Tensor`s as a list of `Literal`s. Tensors in the list can have different ||| shapes and element types. For example, ||| ``` @@ -149,21 +139,31 @@ namespace TensorList ||| * `eval` performs logging. You can disable this by adjusting the TensorFlow logging level ||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`. export partial - eval : Graph (TensorList shapes tys) -> IO (All2 Literal shapes tys) - eval $ MkGraph xs = - let (env, xs) = runState empty xs - (env, root) = runState env (addNode $ Tuple $ nodes xs) - in try $ execute (MkFn [] root env) >>= readAll xs 0 + eval : Graph (All2 Tensor shapes dtypes) -> + All2 PrimitiveRW dtypes tys => + IO $ All2 Literal shapes tys + eval @{prims} $ MkGraph xs = + let graph = do addNode (Tuple $ nodes !xs) + (env, root) = runState empty graph + in try $ execute (MkFn [] root env) >>= readAll prims 0 where - nodes : TensorList s t -> List Nat + nodes : All2 Tensor ss ds -> List Nat nodes [] = [] nodes (MkTensor x :: xs) = x :: nodes xs - readAll : HasIO io => TensorList s t -> Nat -> Literal -> io $ All2 Literal s t + readAll : HasIO io => All2 PrimitiveRW ds ts -> Nat -> Literal -> io $ All2 Literal ss ts readAll [] _ _ = pure [] - readAll (MkTensor {dtype} _ :: ts) n lit = [| read {dtype} [n] lit :: readAll ts (S n) lit |] + readAll ((::) {x} _ prims) n lit = [| read {dtype = x} [n] lit :: readAll ts prims (S n) lit |] + +partial +foo : IO Bool +foo = do + let x0 : Literal [] Double + y0 := tensor {dtype = F64} x0 + [x0'] <- eval {tys = %search} (do pure [!y0]) + pure (x0 == x0') ||| A string representation of the graph used to define a `Tensor`, detailing all enqueued XLA ||| operations. diff --git a/test/Unit/TestTensor.idr b/test/Unit/TestTensor.idr index 8d81c8ed4..19f45562d 100644 --- a/test/Unit/TestTensor.idr +++ b/test/Unit/TestTensor.idr @@ -15,12 +15,6 @@ limitations under the License. --} module Unit.TestTensor -import Unit.TestTensor.Elementwise -import Unit.TestTensor.HigherOrder -import Unit.TestTensor.Sampling -import Unit.TestTensor.Slice -import Unit.TestTensor.Structure - import Data.Nat import Data.Vect import System @@ -33,6 +27,12 @@ import Utils.Comparison import Utils.Cases import Utils.Proof +import Unit.TestTensor.Elementwise +import Unit.TestTensor.HigherOrder +import Unit.TestTensor.Sampling +import Unit.TestTensor.Slice +import Unit.TestTensor.Structure + partial tensorThenEval : Property tensorThenEval = property $ do @@ -68,18 +68,18 @@ evalTuple = property $ do y1 = tensor {dtype = S32} x1 y2 = tensor {dtype = U64} x2 - let [] = unsafePerformIO $ eval (pure []) + let [] = unsafePerformIO $ eval {tys = []} (pure []) - let [x0'] = unsafePerformIO $ eval (do pure [!y0]) + let [x0'] = unsafePerformIO $ eval {tys = [_]} (do pure [!y0]) x0' ==~ x0 - let [x0', x1'] = unsafePerformIO $ eval (do pure [!y0, !y1]) + let [x0', x1'] = unsafePerformIO $ eval {tys = [_, _]} (do pure [!y0, !y1]) x0' ==~ x0 x1' === x1 - let [x0', x1', x2'] = unsafePerformIO $ eval (do pure [!y0, !y1, !y2]) + let [x0', x1', x2'] = unsafePerformIO $ eval {tys = [_, _, _]} (do pure [!y0, !y1, !y2]) x0' ==~ x0 x1' === x1 @@ -95,7 +95,7 @@ evalTupleNonTrivial = property $ do w <- slice [0.to 2] u pure [v, w] - [v, w] = unsafePerformIO $ eval xs + [v, w] = unsafePerformIO $ eval {tys = [_, _]} xs v ==~ Scalar (exp (-2.0) + 3.0) w ==~ [| exp [1.0, -2.0] |]