@@ -101,6 +101,7 @@ module Language.Egison.Types
101
101
, liftError
102
102
-- * Monads
103
103
, EgisonM (.. )
104
+ , parallelMapM
104
105
, runEgisonM
105
106
, liftEgisonM
106
107
, fromEgisonM
@@ -146,6 +147,7 @@ module Language.Egison.Types
146
147
import Prelude hiding (foldr , mappend , mconcat )
147
148
148
149
import Control.Exception
150
+ import Control.Parallel
149
151
import Data.Typeable
150
152
151
153
import Control.Applicative
@@ -155,7 +157,6 @@ import Control.Monad.Reader (ReaderT)
155
157
import Control.Monad.Writer (WriterT )
156
158
import Control.Monad.Identity
157
159
import Control.Monad.Trans.Maybe
158
- import qualified Control.Monad.Parallel as MP
159
160
160
161
import Data.Monoid (Monoid )
161
162
import qualified Data.HashMap.Lazy as HL
@@ -837,7 +838,7 @@ tTranspose' is t@(Tensor ns xs js) = do
837
838
838
839
tMap :: HasTensor a => (a -> EgisonM a ) -> (Tensor a ) -> EgisonM (Tensor a )
839
840
tMap f (Tensor ns xs js) = do
840
- xs' <- MP. mapM f (V. toList xs) >>= return . V. fromList
841
+ xs' <- parallelMapM f (V. toList xs) >>= return . V. fromList
841
842
t <- toTensor (V. head xs')
842
843
case t of
843
844
(Tensor ns1 _ js1) ->
@@ -847,9 +848,9 @@ tMap f (Scalar x) = f x >>= return . Scalar
847
848
848
849
tMapN :: HasTensor a => ([a ] -> EgisonM a ) -> [Tensor a ] -> EgisonM (Tensor a )
849
850
tMapN f ts@ ((Tensor ns xs js): _) = do
850
- xs' <- MP. mapM (\ is -> mapM (tIntRef is) ts >>= mapM fromTensor >>= f) (enumTensorIndices ns)
851
+ xs' <- parallelMapM (\ is -> mapM (tIntRef is) ts >>= mapM fromTensor >>= f) (enumTensorIndices ns)
851
852
return $ Tensor ns (V. fromList xs') js
852
- tMapN f xs = MP. mapM fromTensor xs >>= f >>= return . Scalar
853
+ tMapN f xs = parallelMapM fromTensor xs >>= f >>= return . Scalar
853
854
854
855
tMap2 :: HasTensor a => (a -> a -> EgisonM a ) -> Tensor a -> Tensor a -> EgisonM (Tensor a )
855
856
tMap2 f t1@ (Tensor ns1 xs1 js1) t2@ (Tensor ns2 xs2 js2) = do
@@ -859,7 +860,7 @@ tMap2 f t1@(Tensor ns1 xs1 js1) t2@(Tensor ns2 xs2 js2) = do
859
860
let cns = take (length cjs) (tSize t1')
860
861
rts1 <- mapM (flip tIntRef t1') (enumTensorIndices cns)
861
862
rts2 <- mapM (flip tIntRef t2') (enumTensorIndices cns)
862
- rts' <- MP. mapM (\ (t1, t2) -> tProduct f t1 t2) (zip rts1 rts2)
863
+ rts' <- parallelMapM (\ (t1, t2) -> tProduct f t1 t2) (zip rts1 rts2)
863
864
let ret = Tensor (cns ++ (tSize (head rts'))) (V. concat (map tToVector rts')) (cjs ++ tIndex (head rts'))
864
865
tTranspose (uniq (tDiagIndex (js1 ++ js2))) ret
865
866
where
@@ -1554,8 +1555,23 @@ liftError = either throwError return
1554
1555
1555
1556
newtype EgisonM a = EgisonM {
1556
1557
unEgisonM :: (ExceptT EgisonError (FreshT IO ) a )
1557
- } deriving (Functor , Applicative , Monad , MonadIO , MonadError EgisonError , MonadFresh , MP.MonadParallel )
1558
- -- } deriving (Functor, Applicative, Monad, MonadIO, MonadError EgisonError, MonadFresh)
1558
+ } deriving (Functor , Applicative , Monad , MonadIO , MonadError EgisonError , MonadFresh )
1559
+
1560
+ parallelMapM :: (a -> EgisonM b ) -> [a ] -> EgisonM [b ]
1561
+ parallelMapM f [] = return []
1562
+ parallelMapM f (x: xs) = do
1563
+ let y = unsafePerformEgison (0 ,1 ) $ f x
1564
+ let ys = unsafePerformEgison (0 ,1 ) $ parallelMapM f xs
1565
+ y `par` (ys `pseq` return (y: ys))
1566
+
1567
+ unsafePerformEgison :: (Int , Int ) -> EgisonM a -> a
1568
+ unsafePerformEgison (x, y) ma =
1569
+ let ((Right ret), _) = unsafePerformIO $ runFreshT (x, y + 1 ) $ runEgisonM ma in
1570
+ ret
1571
+ -- f' :: (Either EgisonError a) -> (Either EgisonError b) -> EgisonM c
1572
+ -- f' (Right x) (Right y) = f x y
1573
+ -- f' (Left e) _ = liftError (Left e)
1574
+ -- f' _ (Left e) = liftError (Left e)
1559
1575
1560
1576
runEgisonM :: EgisonM a -> FreshT IO (Either EgisonError a )
1561
1577
runEgisonM = runExceptT . unEgisonM
@@ -1587,8 +1603,7 @@ modifyCounter m = do
1587
1603
return result
1588
1604
1589
1605
newtype FreshT m a = FreshT { unFreshT :: StateT (Int , Int ) m a }
1590
- deriving (Functor , Applicative , Monad , MonadState (Int , Int ), MonadTrans , MP. MonadParallel )
1591
- -- deriving (Functor, Applicative, Monad, MonadState Int, MonadTrans)
1606
+ deriving (Functor , Applicative , Monad , MonadState (Int , Int ), MonadTrans )
1592
1607
1593
1608
type Fresh = FreshT Identity
1594
1609
@@ -1622,8 +1637,6 @@ instance (MonadFresh m, Monoid e) => MonadFresh (WriterT e m) where
1622
1637
instance MonadIO (FreshT IO ) where
1623
1638
liftIO = lift
1624
1639
1625
- instance (MP. MonadParallel m ) => MP. MonadParallel (StateT s m )
1626
-
1627
1640
runFreshT :: Monad m => (Int , Int ) -> FreshT m a -> m (a , (Int , Int ))
1628
1641
runFreshT seed = flip (runStateT . unFreshT) seed
1629
1642
0 commit comments