Skip to content

Commit c596ccf

Browse files
committed
enable to use function symbol in tensor
1 parent 9cad798 commit c596ccf

File tree

6 files changed

+60
-21
lines changed

6 files changed

+60
-21
lines changed

hs-src/Language/Egison/Core.hs

+20-13
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,21 @@ evalExpr env (VectorExpr exprs) = do
223223
whnfs <- mapM (evalExpr env) exprs
224224
case whnfs of
225225
((Intermediate (ITensor (Tensor _ _ _))):_) -> do
226-
ret <- mapM toTensor whnfs >>= tConcat' >>= fromTensor
226+
ret <- mapM toTensor (map f $ zip whnfs [1..(length exprs + 1)]) >>= tConcat' >>= fromTensor
227227
return ret
228-
_ -> do
229-
fromTensor (Tensor [fromIntegral (length whnfs)] (V.fromList whnfs) [])
228+
_ -> fromTensor (Tensor [fromIntegral $ length whnfs] (V.fromList whnfs) [])
229+
where
230+
f ((Intermediate (ITensor (Tensor ns xs indices))), i) =
231+
Intermediate $ ITensor $ Tensor ns (V.fromList $ map g $ zip (V.toList xs) $ map (\ms -> map toEgison $ (toInteger i):ms) $ enumTensorIndices ns) indices
232+
f (x, _) = x
233+
g (Value (ScalarData (Div (Plus [Term 1 [(FunctionData fn argnames args js, 1)]]) p)), ms) =
234+
let Env _ maybe_vwi = env in
235+
let fn' = case maybe_vwi of
236+
Nothing -> fn
237+
Just (VarWithIndices nameString indexList) ->
238+
Just $ symbolScalarData "" $ show $ VarWithIndices nameString $ changeIndexList indexList ms in
239+
Value $ ScalarData $ Div (Plus [Term 1 [(FunctionData fn' argnames args js, 1)]]) p
240+
g (x, _) = x
230241

231242
evalExpr env (TensorExpr nsExpr xsExpr supExpr subExpr) = do
232243
nsWhnf <- evalExpr env nsExpr
@@ -690,11 +701,6 @@ evalExpr env (GenerateTensorExpr fnExpr sizeExpr) = do
690701
applyFunc env fn (Value (makeTuple ms)))
691702
(map (\ms -> map toEgison ms) (enumTensorIndices ns))
692703
fromTensor (Tensor ns (V.fromList xs) [])
693-
where
694-
changeIndexList :: [Index String] -> [EgisonValue] -> [Index String]
695-
changeIndexList idxlist ms = map (\(i, m) -> case i of
696-
Superscript s -> Superscript (s ++ m)
697-
Subscript s -> Subscript (s ++ m)) $ zip idxlist (map show ms)
698704

699705
evalExpr env (TensorContractExpr fnExpr tExpr) = do
700706
fn <- evalExpr env fnExpr
@@ -1007,6 +1013,7 @@ recursiveBind env bindings = do
10071013
let (names, exprs) = unzip bindings
10081014
refs <- replicateM (length bindings) $ newObjectRef nullEnv UndefinedExpr
10091015
let env' = extendEnv env $ makeBindings names refs
1016+
let Env frame _ = env'
10101017
zipWithM_ (\ref (name,expr) -> do
10111018
case expr of
10121019
MemoizedLambdaExpr names body -> do
@@ -1021,14 +1028,14 @@ recursiveBind env bindings = do
10211028
case whnf of
10221029
(Value (CFunc _ env arg body)) -> liftIO . writeIORef ref . WHNF $ (Value (CFunc (Just name) env arg body))
10231030
FunctionExpr args -> do
1024-
let Env frame _ = env'
10251031
liftIO . writeIORef ref . Thunk $ evalExpr (Env frame (Just $ varToVarWithIndices name)) $ FunctionExpr args
1026-
GenerateTensorExpr _ _ -> do
1027-
let Env frame _ = env'
1028-
liftIO . writeIORef ref . Thunk $ evalExpr (Env frame (Just $ varToVarWithIndices name)) $ expr
1029-
_ -> liftIO . writeIORef ref . Thunk $ evalExpr env' expr)
1032+
_ | isVarWithIndices name -> liftIO . writeIORef ref . Thunk $ evalExpr (Env frame (Just $ varToVarWithIndices name)) expr
1033+
| otherwise -> liftIO . writeIORef ref . Thunk $ evalExpr env' expr)
10301034
refs bindings
10311035
return env'
1036+
where
1037+
isVarWithIndices :: Var -> Bool
1038+
isVarWithIndices (Var _ xs) = not $ null xs
10321039

10331040
recursiveRebind :: Env -> (Var, EgisonExpr) -> EgisonM Env
10341041
recursiveRebind env (name, expr) = do

hs-src/Language/Egison/ParserNonS.hs

+8-4
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ term' :: Parser EgisonExpr
230230
term' = matchExpr
231231
<|> matchAllExpr
232232
<|> matchLambdaExpr
233+
<|> matchAllLambdaExpr
233234
<|> matcherExpr
234235
<|> matcherDFSExpr
235236
<|> functionWithArgExpr
@@ -318,19 +319,22 @@ quoteSymbolExpr :: Parser EgisonExpr
318319
quoteSymbolExpr = char '`' >> QuoteSymbolExpr <$> expr
319320

320321
matchAllExpr :: Parser EgisonExpr
321-
matchAllExpr = keywordMatchAll >> MatchAllExpr <$> expr <* (inSpaces $ string "as") <*> expr <*> matchClauses
322+
matchAllExpr = keywordMatchAll >> MatchAllExpr <$> expr <* keywordAs <*> expr <*> matchClauses
322323

323324
matchExpr :: Parser EgisonExpr
324-
matchExpr = keywordMatch >> MatchExpr <$> expr <* (inSpaces $ string "as") <*> expr <*> matchClauses
325+
matchExpr = keywordMatch >> MatchExpr <$> expr <* keywordAs <*> expr <*> matchClauses
325326

326327
matchLambdaExpr :: Parser EgisonExpr
327-
matchLambdaExpr = keywordMatchLambda >> MatchLambdaExpr <$ (inSpaces $ string "as") <*> expr <*> matchClauses
328+
matchLambdaExpr = keywordMatchLambda >> MatchLambdaExpr <$ keywordAs <*> expr <*> matchClauses
329+
330+
matchAllLambdaExpr :: Parser EgisonExpr
331+
matchAllLambdaExpr = keywordMatchAllLambda >> MatchAllLambdaExpr <$ keywordAs <*> expr <*> matchClauses
328332

329333
matchClauses :: Parser [MatchClause]
330334
matchClauses = many1 matchClause
331335

332336
matchClause :: Parser MatchClause
333-
matchClause = inSpaces (string "|") >> (,) <$> pattern <* (reservedOp "->") <*> expr
337+
matchClause = try $ inSpaces (string "|") >> (,) <$> pattern <* (reservedOp "->") <*> expr
334338

335339
matcherExpr :: Parser EgisonExpr
336340
matcherExpr = keywordMatcher >> MatcherExpr <$> ppMatchClauses

hs-src/Language/Egison/Types.hs

+6
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ module Language.Egison.Types
4343
, tIndex
4444
, tref
4545
, enumTensorIndices
46+
, changeIndexList
4647
, tTranspose
4748
, tTranspose'
4849
, tFlipIndices
@@ -879,6 +880,11 @@ enumTensorIndices :: [Integer] -> [[Integer]]
879880
enumTensorIndices [] = [[]]
880881
enumTensorIndices (n:ns) = concatMap (\i -> (map (\is -> i:is) (enumTensorIndices ns))) [1..n]
881882

883+
changeIndexList :: [Index String] -> [EgisonValue] -> [Index String]
884+
changeIndexList idxlist ms = map (\(i, m) -> case i of
885+
Superscript s -> Superscript (s ++ m)
886+
Subscript s -> Subscript (s ++ m)) $ zip idxlist (map show ms)
887+
882888
transIndex :: [Index EgisonValue] -> [Index EgisonValue] -> [Integer] -> EgisonM [Integer]
883889
transIndex [] [] is = return is
884890
transIndex (j1:js1) js2 is = do

nons-test/test/lib/math/tensor.egi

+5
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,8 @@ assertEqual("generate_tensor by using function expr",
6565
[3, 3]) in
6666
show(withSymbols {i, j} d/d(g_i_j, x)),
6767
"[| [| g_1_1|x 0 0 |] [| 0 g_2_2|x 0 |] [| 0 0 g_3_3|x |] |]")
68+
69+
assertEqual("define tensor having value of function expr",
70+
letrec g__ = [| [| function(x, y, z), 0, 0 |], [| 0, function(x, y, z), 0 |], [| 0, 0, function(x, y, z) |] |] in
71+
show(withSymbols {i, j} d/d(g_i_j, x)),
72+
"[| [| g_1_1|x 0 0 |] [| 0 g_2_2|x 0 |] [| 0 0 g_3_3|x |] |]")

nons-test/test/syntax.egi

+16-4
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,31 @@ assertEqual("match-all",
104104
| $x<:>$xs -> [x, xs],
105105
[[1, [2, 3]]])
106106

107+
assertEqual("match-all-multi",
108+
matchAll [1, 2, 3] as multiset(integer)
109+
| $x <:> (x + 1) <:> _ -> [x, x + 1]
110+
| $x <:> (x + 2) <:> _ -> [x, x + 2],
111+
[[1, 2], [2, 3], [1, 3]])
112+
107113
assertEqual("match-lambda",
108-
letrec count = $l ->
109-
match l as list(something)
114+
letrec count =
115+
matchLambda as list(something)
110116
| <nil> -> 0
111117
| _<:>$xs -> count(xs)+1 in
112118
count([1, 2, 3]),
113119
3)
114120

115121
assertEqual("match-all-lambda",
116-
($l -> matchAll l as list(something)
117-
| _ <++> $x <:> _ -> x)([1, 2, 3]),
122+
(matchAllLambda as list(something)
123+
| _ <++> $x <:> _ -> x)([1, 2, 3]),
118124
[1, 2, 3])
119125

126+
assertEqual("match-all-lambda-multi",
127+
(matchAllLambda as multiset(something)
128+
| $x <:> (x + 1) <:> _ -> [x, x + 1]
129+
| $x <:> (x + 2) <:> _ -> [x, x + 2])([1, 2, 3]),
130+
[[1, 2], [2, 3], [1, 3]])
131+
120132
assertEqual("pattern variable",
121133
match 1 as something
122134
| $x -> x,

test/lib/math/tensor.egi

+5
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,8 @@
7878
{3 3})]}
7979
(show (with-symbols {i j} (d/d g_i_j x))))
8080
"[| [| g_1_1|x 0 0 |] [| 0 g_2_2|x 0 |] [| 0 0 g_3_3|x |] |]")
81+
82+
(assert-equal "define tensor having value of function expr"
83+
(letrec {[$g__ [| [| (function [x y z]) 0 0 |] [| 0 (function [x y z]) 0 |] [| 0 0 (function [x y z]) |] |]]}
84+
(show (with-symbols {i j} (d/d g_i_j x))))
85+
"[| [| g_1_1|x 0 0 |] [| 0 g_2_2|x 0 |] [| 0 0 g_3_3|x |] |]")

0 commit comments

Comments
 (0)