Skip to content

separate Tensor and PrimitiveRW in n-ary eval #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,16 @@ try x = runEitherT x <&> \case
||| **Note:**
||| * Each call to `eval` will rebuild and execute the graph; multiple calls to `eval` on different
||| tensors, even if they are in the same computation, will be treated entirely independently.
||| To efficiently evaluate multiple tensors at once, use `TensorList.eval`.
||| * `eval` performs logging. You can disable this by adjusting the TensorFlow logging level
||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
||| To efficiently evaluate multiple tensors at once, use `All2.eval`.
||| * `eval` performs logging. You can disable this by adjusting the logging level
||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
export partial
eval : PrimitiveRW dtype ty => Graph (Tensor shape dtype) -> IO (Literal shape ty)
eval $ MkGraph x =
let (env, MkTensor root) = runState empty x
in try $ execute (MkFn [] root env) >>= read {dtype} []

namespace TensorList
||| A list of `Tensor`s, along with the conversions needed to evaluate them to `Literal`s.
||| The list is parametrized by the shapes and types of the resulting `Literal`s.
public export
data TensorList : List Shape -> List Type -> Type where
Nil : TensorList [] []
(::) : PrimitiveRW dtype ty =>
Tensor shape dtype ->
TensorList shapes tys ->
TensorList (shape :: shapes) (ty :: tys)

namespace All2
||| Evaluate a list of `Tensor`s as a list of `Literal`s. Tensors in the list can have different
||| shapes and element types. For example,
||| ```
Expand All @@ -149,21 +139,31 @@ namespace TensorList
||| * `eval` performs logging. You can disable this by adjusting the TensorFlow logging level
||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
export partial
eval : Graph (TensorList shapes tys) -> IO (All2 Literal shapes tys)
eval $ MkGraph xs =
let (env, xs) = runState empty xs
(env, root) = runState env (addNode $ Tuple $ nodes xs)
in try $ execute (MkFn [] root env) >>= readAll xs 0
eval : Graph (All2 Tensor shapes dtypes) ->
All2 PrimitiveRW dtypes tys =>
IO $ All2 Literal shapes tys
eval @{prims} $ MkGraph xs =
let graph = do addNode (Tuple $ nodes !xs)
(env, root) = runState empty graph
in try $ execute (MkFn [] root env) >>= readAll prims 0

where

nodes : TensorList s t -> List Nat
nodes : All2 Tensor ss ds -> List Nat
nodes [] = []
nodes (MkTensor x :: xs) = x :: nodes xs

readAll : HasIO io => TensorList s t -> Nat -> Literal -> io $ All2 Literal s t
readAll : HasIO io => All2 PrimitiveRW ds ts -> Nat -> Literal -> io $ All2 Literal ss ts
readAll [] _ _ = pure []
readAll (MkTensor {dtype} _ :: ts) n lit = [| read {dtype} [n] lit :: readAll ts (S n) lit |]
readAll ((::) {x} _ prims) n lit = [| read {dtype = x} [n] lit :: readAll ts prims (S n) lit |]

partial
foo : IO Bool
foo = do
let x0 : Literal [] Double
y0 := tensor {dtype = F64} x0
[x0'] <- eval {tys = %search} (do pure [!y0])
pure (x0 == x0')

||| A string representation of the graph used to define a `Tensor`, detailing all enqueued XLA
||| operations.
Expand Down
22 changes: 11 additions & 11 deletions test/Unit/TestTensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ limitations under the License.
--}
module Unit.TestTensor

import Unit.TestTensor.Elementwise
import Unit.TestTensor.HigherOrder
import Unit.TestTensor.Sampling
import Unit.TestTensor.Slice
import Unit.TestTensor.Structure

import Data.Nat
import Data.Vect
import System
Expand All @@ -33,6 +27,12 @@ import Utils.Comparison
import Utils.Cases
import Utils.Proof

import Unit.TestTensor.Elementwise
import Unit.TestTensor.HigherOrder
import Unit.TestTensor.Sampling
import Unit.TestTensor.Slice
import Unit.TestTensor.Structure

partial
tensorThenEval : Property
tensorThenEval = property $ do
Expand Down Expand Up @@ -68,18 +68,18 @@ evalTuple = property $ do
y1 = tensor {dtype = S32} x1
y2 = tensor {dtype = U64} x2

let [] = unsafePerformIO $ eval (pure [])
let [] = unsafePerformIO $ eval {tys = []} (pure [])

let [x0'] = unsafePerformIO $ eval (do pure [!y0])
let [x0'] = unsafePerformIO $ eval {tys = [_]} (do pure [!y0])

x0' ==~ x0

let [x0', x1'] = unsafePerformIO $ eval (do pure [!y0, !y1])
let [x0', x1'] = unsafePerformIO $ eval {tys = [_, _]} (do pure [!y0, !y1])

x0' ==~ x0
x1' === x1

let [x0', x1', x2'] = unsafePerformIO $ eval (do pure [!y0, !y1, !y2])
let [x0', x1', x2'] = unsafePerformIO $ eval {tys = [_, _, _]} (do pure [!y0, !y1, !y2])

x0' ==~ x0
x1' === x1
Expand All @@ -95,7 +95,7 @@ evalTupleNonTrivial = property $ do
w <- slice [0.to 2] u
pure [v, w]

[v, w] = unsafePerformIO $ eval xs
[v, w] = unsafePerformIO $ eval {tys = [_, _]} xs

v ==~ Scalar (exp (-2.0) + 3.0)
w ==~ [| exp [1.0, -2.0] |]
Expand Down