Skip to content

Commit 3ed541b

Browse files
jbertholdgithub-actions
and
github-actions
authored
Only send unevaluated functions to LLVM (#4018)
This change steps away from maximising the terms we send to the LLVM backend for evaluation, and only sends terms with an unevaluated function call at the top. It addresses a particular constellation that has been observed in downstream semantics/proofs: * a configuration containing large data structures in cells that remain unchanged * a long-running sequence of rewrites that keep inserting concrete function applications into cells required for the next rewrite step, * and therefore require frequent configuration simplifications Each time the configuration is simplified, the large unchanged data structure is sent to the LLVM backend, which incurs considerable overhead. On the flip side, we will send each function that needs to be evaluated individually. However, rewrites or equations will usually not create more than a handful of new unevaluated function calls at one time, and we will still send _nested_ unevaluated expressions together. Also includes a logging call to measure the duration of LLVM interactions (including argument/result conversion), and necessary code restructuring. --------- Co-authored-by: github-actions <[email protected]>
1 parent def3190 commit 3ed541b

File tree

5 files changed

+72
-31
lines changed

5 files changed

+72
-31
lines changed

booster/library/Booster/LLVM.hs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,36 @@ module Booster.LLVM (
66
) where
77

88
import Control.Monad.IO.Class (MonadIO (..))
9+
import Data.Aeson
910
import Data.Binary.Get
1011
import Data.ByteString (fromStrict)
12+
import Data.ByteString.Char8 qualified as BS
1113
import Data.Map qualified as Map
1214
import Data.Set qualified as Set
15+
import Data.Text qualified as T
1316

1417
import Booster.Definition.Base
1518
import Booster.LLVM.Internal qualified as Internal
19+
import Booster.Log
1620
import Booster.Pattern.Base
1721
import Booster.Pattern.Binary
1822
import Booster.Pattern.Util
19-
import Data.ByteString.Char8 qualified as BS
23+
import Booster.Util (secWithUnit, timed)
2024

21-
simplifyBool :: MonadIO io => Internal.API -> Term -> io (Either Internal.LlvmError Bool)
22-
simplifyBool api trm = liftIO $ Internal.runLLVM api $ do
25+
simplifyBool :: LoggerMIO io => Internal.API -> Term -> io (Either Internal.LlvmError Bool)
26+
simplifyBool api trm = ioWithTiming $ Internal.runLLVM api $ do
2327
kore <- Internal.ask
2428
trmPtr <- Internal.marshallTerm trm
2529
liftIO $ kore.simplifyBool trmPtr
2630

2731
simplifyTerm ::
28-
MonadIO io => Internal.API -> KoreDefinition -> Term -> Sort -> io (Either Internal.LlvmError Term)
29-
simplifyTerm api def trm sort = liftIO $ Internal.runLLVM api $ do
32+
LoggerMIO io =>
33+
Internal.API ->
34+
KoreDefinition ->
35+
Term ->
36+
Sort ->
37+
io (Either Internal.LlvmError Term)
38+
simplifyTerm api def trm sort = ioWithTiming $ Internal.runLLVM api $ do
3039
kore <- Internal.ask
3140
trmPtr <- Internal.marshallTerm trm
3241
sortPtr <- Internal.marshallSort sort
@@ -56,3 +65,11 @@ simplifyTerm api def trm sort = liftIO $ Internal.runLLVM api $ do
5665
sortName (SortApp name _) = name
5766
sortName (SortVar name) = name
5867
subsorts = maybe Set.empty snd $ Map.lookup (sortName sort) def.sorts
68+
69+
ioWithTiming :: LoggerMIO io => IO a -> io a
70+
ioWithTiming action = do
71+
(result, time) <- liftIO $ timed action
72+
withContext CtxTiming . logMessage $
73+
WithJsonMessage (object ["time" .= time]) $
74+
"Performed LLVM call in " <> T.pack (secWithUnit time)
75+
pure result

booster/library/Booster/Pattern/ApplyEquations.hs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,9 @@ llvmSimplify term = do
395395
where
396396
evalLlvm definition api cb t@(Term attributes _)
397397
| attributes.isEvaluated = pure t
398-
| isConcrete t && attributes.canBeEvaluated = withContext CtxLlvm . withTermContext t $ do
398+
| isConcrete t
399+
, attributes.canBeEvaluated
400+
, isFunctionApp t = withContext CtxLlvm . withTermContext t $ do
399401
LLVM.simplifyTerm api definition t (sortOfTerm t)
400402
>>= \case
401403
Left (LlvmError e) -> do
@@ -413,6 +415,10 @@ llvmSimplify term = do
413415
| otherwise =
414416
cb t
415417

418+
isFunctionApp :: Term -> Bool
419+
isFunctionApp (SymbolApplication sym _ _) = isFunctionSymbol sym
420+
isFunctionApp _ = False
421+
416422
----------------------------------------
417423
-- Interface functions
418424

@@ -1065,11 +1071,11 @@ simplifyConstraint' :: LoggerMIO io => Bool -> Term -> EquationT io Term
10651071
-- evaluateTerm.
10661072
simplifyConstraint' recurseIntoEvalBool = \case
10671073
t@(Term TermAttributes{canBeEvaluated} _)
1068-
| isConcrete t && canBeEvaluated -> withTermContext t $ do
1074+
| isConcrete t && canBeEvaluated -> do
10691075
mbApi <- (.llvmApi) <$> getConfig
10701076
case mbApi of
10711077
Just api ->
1072-
withContext CtxLlvm $
1078+
withContext CtxLlvm . withTermContext t $
10731079
LLVM.simplifyBool api t >>= \case
10741080
Left (LlvmError e) -> do
10751081
withContext CtxAbort $
@@ -1086,11 +1092,10 @@ simplifyConstraint' recurseIntoEvalBool = \case
10861092
pure result
10871093
Nothing -> if recurseIntoEvalBool then evalBool t else pure t
10881094
| otherwise ->
1089-
withTermContext t $
1090-
if recurseIntoEvalBool then evalBool t else pure t
1095+
if recurseIntoEvalBool then evalBool t else pure t
10911096
where
10921097
evalBool :: LoggerMIO io => Term -> EquationT io Term
1093-
evalBool t = do
1098+
evalBool t = withTermContext t $ do
10941099
prior <- getState -- save prior state so we can revert
10951100
eqState $ put prior{termStack = mempty, changed = False}
10961101
result <- iterateEquations BottomUp PreferFunctions t

booster/library/Booster/Util.hs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@ module Booster.Util (
1414
newTimeCache,
1515
pattern PrettyTimestamps,
1616
pattern NoPrettyTimestamps,
17+
timed,
18+
secWithUnit,
1719
) where
1820

1921
import Control.AutoUpdate (defaultUpdateSettings, mkAutoUpdate, updateAction, updateFreq)
2022
import Control.DeepSeq (NFData (..))
2123
import Control.Exception (bracket, catch, throwIO)
24+
import Control.Monad.IO.Class (MonadIO (liftIO))
2225
import Data.ByteString (ByteString)
2326
import Data.ByteString.Char8 qualified as BS
2427
import Data.Coerce (coerce)
@@ -31,6 +34,7 @@ import Data.Time.Clock.System (SystemTime (..), getSystemTime, systemToUTCTime)
3134
import Data.Time.Format
3235
import GHC.Generics (Generic)
3336
import Language.Haskell.TH.Syntax (Lift)
37+
import System.Clock
3438
import System.Directory (removeFile)
3539
import System.IO.Error (isDoesNotExistError)
3640
import System.Log.FastLogger (
@@ -42,6 +46,7 @@ import System.Log.FastLogger (
4246
newTimedFastLogger,
4347
)
4448
import System.Log.FastLogger.Types (FormattedTime)
49+
import Text.Printf
4550

4651
newtype Flag (name :: k) = Flag Bool
4752
deriving stock (Eq, Ord, Show, Generic, Data, Lift)
@@ -185,14 +190,32 @@ pattern NoPrettyTimestamps = Flag False
185190
-- | Format time either as a human-readable date and time or as nanoseconds
186191
formatSystemTime :: Flag "PrettyTimestamp" -> SystemTime -> ByteString
187192
formatSystemTime prettyTimestamp =
188-
let formatString = "%Y-%m-%dT%H:%M:%S%6Q"
193+
let formatStr = "%Y-%m-%dT%H:%M:%S%6Q"
189194
formatter =
190195
if coerce prettyTimestamp
191-
then formatTime defaultTimeLocale formatString . systemToUTCTime
196+
then formatTime defaultTimeLocale formatStr . systemToUTCTime
192197
else show . toNanoSeconds
193198
in BS.pack . formatter
194199
where
195200
toNanoSeconds :: SystemTime -> Integer
196201
toNanoSeconds MkSystemTime{systemSeconds, systemNanoseconds} =
197202
fromIntegral @_ @Integer systemSeconds * (10 :: Integer) ^ (9 :: Integer)
198203
+ fromIntegral @_ @Integer systemNanoseconds
204+
205+
------------------------------------------------------------
206+
-- helper for measuring durations
207+
208+
-- returns time taken by the given action (in seconds)
209+
timed :: MonadIO m => m a -> m (a, Double)
210+
timed action = do
211+
start <- liftIO $ getTime Monotonic
212+
result <- action
213+
stop <- liftIO $ getTime Monotonic
214+
let time = fromIntegral (toNanoSecs (diffTimeSpec stop start)) / 10 ** 9
215+
pure (result, time)
216+
217+
secWithUnit :: (Floating a, Ord a, PrintfArg a) => a -> String
218+
secWithUnit x
219+
| x > 0.1 = printf "%.2fs" x
220+
| x > 0.0001 = printf "%.3fms" $ x * 10 ** 3
221+
| otherwise = printf "%.1fμs" $ x * 10 ** 6

booster/test/llvm-integration/LLVM.hs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
{-# LANGUAGE PatternSynonyms #-}
22

3+
{-# OPTIONS -fno-warn-orphans #-}
4+
35
{- |
46
Copyright : (c) Runtime Verification, 2023
57
License : BSD-3-Clause
@@ -19,6 +21,7 @@ import Data.List (foldl1', isInfixOf, nub)
1921
import Data.Map (Map)
2022
import Data.Map qualified as Map
2123
import Data.Maybe (fromMaybe)
24+
import Data.Proxy
2225
import Data.Set (Set)
2326
import Data.Set qualified as Set
2427
import Data.Text (Text)
@@ -40,7 +43,9 @@ import Booster.Definition.Attributes.Base
4043
import Booster.Definition.Base
4144
import Booster.LLVM qualified as LLVM
4245
import Booster.LLVM.Internal qualified as Internal
46+
import Booster.Log
4347
import Booster.Pattern.Base
48+
import Booster.Pattern.Pretty
4449
import Booster.SMT.Base (SExpr (..), SMTId (..))
4550
import Booster.Syntax.Json.Externalise (externaliseTerm)
4651
import Booster.Syntax.Json.Internalise (pattern AllowAlias, pattern IgnoreSubsorts)
@@ -109,6 +114,11 @@ llvmSpec =
109114
--------------------------------------------------
110115
-- individual hedgehog property tests and helpers
111116

117+
instance LoggerMIO (PropertyT IO) where
118+
getLogger = pure $ Logger $ \_ -> pure ()
119+
getPrettyModifiers = pure $ ModifiersRep @'[] Proxy
120+
withLogger _ = id
121+
112122
boolsRemainProp
113123
, compareNumbersProp
114124
, simplifyComparisonProp ::

booster/tools/booster/Stats.hs

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ module Stats (
22
newStats,
33
addStats,
44
finaliseStats,
5-
timed,
6-
secWithUnit,
75
RequestStats (..),
86
StatsVar,
97
MethodTiming (..),
8+
-- re-export
9+
timed,
10+
secWithUnit,
1011
) where
1112

1213
import Control.Concurrent.MVar (MVar, modifyMVar_, newMVar, readMVar)
@@ -18,11 +19,11 @@ import Data.Text (pack)
1819
import Deriving.Aeson
1920
import GHC.Generics ()
2021
import Prettyprinter
21-
import System.Clock
2222
import Text.Printf
2323

2424
import Booster.Log
2525
import Booster.Prettyprinter
26+
import Booster.Util (secWithUnit, timed)
2627
import Kore.JsonRpc.Types (APIMethod)
2728

2829
-- | Statistics for duration measurement time series (in seconds)
@@ -62,12 +63,6 @@ instance (Floating a, PrintfArg a, Ord a) => Pretty (RequestStats a) where
6263
where
6364
withUnit = pretty . secWithUnit
6465

65-
secWithUnit :: (Floating a, Ord a, PrintfArg a) => a -> String
66-
secWithUnit x
67-
| x > 0.1 = printf "%.2fs" x
68-
| x > 0.0001 = printf "%.3fms" $ x * 10 ** 3
69-
| otherwise = printf "%.1fμs" $ x * 10 ** 6
70-
7166
-- internal helper type
7267
-- all values are in seconds
7368
data Stats' = Stats'
@@ -138,15 +133,6 @@ addStats statVar MethodTiming{method, time, koreTime} =
138133
newStats :: MonadIO m => m (MVar (Map APIMethod Stats'))
139134
newStats = liftIO $ newMVar Map.empty
140135

141-
-- returns time taken by the given action (in seconds)
142-
timed :: MonadIO m => m a -> m (a, Double)
143-
timed action = do
144-
start <- liftIO $ getTime Monotonic
145-
result <- action
146-
stop <- liftIO $ getTime Monotonic
147-
let time = fromIntegral (toNanoSecs (diffTimeSpec stop start)) / 10 ** 9
148-
pure (result, time)
149-
150136
newtype FinalStats = FinalStats (Map APIMethod (RequestStats Double))
151137
deriving stock (Eq, Show)
152138
deriving newtype (FromJSON, ToJSON)

0 commit comments

Comments
 (0)