diff --git a/what4/src/What4/Expr/App.hs b/what4/src/What4/Expr/App.hs index b0ffbf24..7ec2a132 100644 --- a/what4/src/What4/Expr/App.hs +++ b/what4/src/What4/Expr/App.hs @@ -1722,6 +1722,8 @@ testExprSymFnEq f g = testEquality (symFnId f) (symFnId g) instance IsSymFn (ExprSymFn t) where fnArgTypes = symFnArgTypes fnReturnType = symFnReturnType + fnTestEquality = testExprSymFnEq + fnCompare f g = compareF (symFnId f) (symFnId g) ------------------------------------------------------------------------------- diff --git a/what4/src/What4/Expr/Builder.hs b/what4/src/What4/Expr/Builder.hs index eae5a317..c41ca600 100644 --- a/what4/src/What4/Expr/Builder.hs +++ b/what4/src/What4/Expr/Builder.hs @@ -134,6 +134,8 @@ module What4.Expr.Builder , SymFnInfo(..) , symFnArgTypes , symFnReturnType + , SomeExprSymFn(..) + , ExprSymFnWrapper(..) -- * SymbolVarBimap , SymbolVarBimap @@ -247,19 +249,19 @@ import What4.Utils.StringLiteral toDouble :: Rational -> Double toDouble = fromRational -cachedEval :: (HashableF k, TestEquality k) +cachedEval :: (HashableF k, TestEquality k, MonadIO m) => PH.HashTable RealWorld k a -> k tp - -> IO (a tp) - -> IO (a tp) + -> m (a tp) + -> m (a tp) cachedEval tbl k action = do - mr <- stToIO $ PH.lookup tbl k + mr <- liftIO $ stToIO $ PH.lookup tbl k case mr of Just r -> return r Nothing -> do r <- action seq r $ do - stToIO $ PH.insert tbl k r + liftIO $ stToIO $ PH.insert tbl k r return r ------------------------------------------------------------------------ @@ -319,7 +321,19 @@ instance HashableF (MatlabFnWrapper t) where data ExprSymFnWrapper t c = forall a r . (c ~ (a ::> r)) => ExprSymFnWrapper (ExprSymFn t a r) -data SomeSymFn sym = forall args ret . SomeSymFn (SymFn sym args ret) +data SomeExprSymFn t = forall args ret . SomeExprSymFn (ExprSymFn t args ret) + +instance Eq (SomeExprSymFn t) where + (SomeExprSymFn fn1) == (SomeExprSymFn fn2) = + isJust $ fnTestEquality fn1 fn2 + +instance Ord (SomeExprSymFn t) where + compare (SomeExprSymFn fn1) (SomeExprSymFn fn2) = + toOrdering $ fnCompare fn1 fn2 + +instance Show (SomeExprSymFn t) where + show (SomeExprSymFn f) = show f + ------------------------------------------------------------------------ -- ExprBuilder @@ -751,12 +765,12 @@ betaReduce sym f args = -- -- It is used when an action may modify a value, and we only want to run a -- second action if the value changed. -runIfChanged :: Eq e +runIfChanged :: (Eq e, Monad m) => e - -> (e -> IO e) -- ^ First action to run + -> (e -> m e) -- ^ First action to run -> r -- ^ Result if no change. - -> (e -> IO r) -- ^ Second action to run - -> IO r + -> (e -> m r) -- ^ Second action to run + -> m r runIfChanged x f unChanged onChange = do y <- f x if x == y then @@ -807,11 +821,14 @@ evalSimpleFn :: EvalHashTables t -> ExprBuilder t st fs -> ExprSymFn t idx ret -> IO (Bool,ExprSymFn t idx ret) -evalSimpleFn tbl sym f = +evalSimpleFn tbl sym f = do + let n = symFnId f case symFnInfo f of - UninterpFnInfo{} -> return (False, f) + UninterpFnInfo{} -> do + CachedSymFn changed f' <- cachedEval (fnTable tbl) n $ + return $! CachedSymFn False f + return (changed, f') DefinedFnInfo vars e evalFn -> do - let n = symFnId f let nm = symFnName f CachedSymFn changed f' <- cachedEval (fnTable tbl) n $ do @@ -4028,6 +4045,30 @@ instance IsSymExprBuilder (ExprBuilder t st fs) where evalMatlabSolverFn f sym args _ -> sbNonceExpr sym $! FnApp fn args + substituteBoundVars sym subst e = do + tbls <- stToIO $ do + expr_tbl <- PH.newSized $ PM.size subst + fn_tbl <- PH.new + PM.traverseWithKey_ (PH.insert expr_tbl . BoundVarExpr) subst + return $ EvalHashTables + { exprTable = expr_tbl + , fnTable = fn_tbl + } + evalBoundVars' tbls sym e + + substituteSymFns sym subst e = do + tbls <- stToIO $ do + expr_tbl <- PH.new + fn_tbl <- PH.newSized $ PM.size subst + PM.traverseWithKey_ + (\(SymFnWrapper f) (SymFnWrapper g) -> PH.insert fn_tbl (symFnId f) (CachedSymFn True g)) + subst + return $ EvalHashTables + { exprTable = expr_tbl + , fnTable = fn_tbl + } + evalBoundVars' tbls sym e + instance IsInterpretedFloatExprBuilder (ExprBuilder t st fs) => IsInterpretedFloatSymExprBuilder (ExprBuilder t st fs) diff --git a/what4/src/What4/Interface.hs b/what4/src/What4/Interface.hs index 0b98533f..a99c853d 100644 --- a/what4/src/What4/Interface.hs +++ b/what4/src/What4/Interface.hs @@ -38,6 +38,12 @@ provide several type family definitions and class instances for @sym@: [@instance 'HashableF' ('SymExpr' sym)@] + [@instance 'OrdF' ('BoundVar' sym)@] + + [@instance 'TestEquality' ('BoundVar' sym)@] + + [@instance 'HashableF' ('BoundVar' sym)@] + The canonical implementation of these interface classes is found in "What4.Expr.Builder". -} {-# LANGUAGE CPP #-} @@ -72,6 +78,8 @@ module What4.Interface -- ** Expression recognizers , IsExpr(..) , IsSymFn(..) + , SomeSymFn(..) + , SymFnWrapper(..) , UnfoldPolicy(..) , shouldUnfold @@ -202,6 +210,7 @@ import Data.Parameterized.Classes import qualified Data.Parameterized.Context as Ctx import Data.Parameterized.Ctx import Data.Parameterized.Utils.Endian (Endian(..)) +import Data.Parameterized.Map (MapF) import Data.Parameterized.NatRepr import Data.Parameterized.TraversableFC import qualified Data.Parameterized.Vector as Vector @@ -587,7 +596,7 @@ instance (HashableF (SymExpr sym), TestEquality (SymExpr sym)) => Hashable (SymN -- of an undefined function is _not_ guaranteed to be equivalant to a free -- constant, and no guarantees are made about what properties such values -- will satisfy. -class ( IsExpr (SymExpr sym), HashableF (SymExpr sym) +class ( IsExpr (SymExpr sym), HashableF (SymExpr sym), HashableF (BoundVar sym) , TestEquality (SymAnnotation sym), OrdF (SymAnnotation sym) , HashableF (SymAnnotation sym) ) => IsExprBuilder sym where @@ -2718,14 +2727,42 @@ iteList ite sym ((mp,mx):xs) def = -- 'IsSymExprBuilder'. type family SymFn sym :: Ctx BaseType -> BaseType -> Type +data SomeSymFn sym = forall args ret . SomeSymFn (SymFn sym args ret) + +-- | Wrapper for `SymFn` that concatenates the arguments and the return types. +-- +-- This is useful for implementing `TestEquality` and `OrdF` instances for +-- `SymFn`, and for using `SymFn` as a key or a value in a `MapF`. +data SymFnWrapper sym ctx where + SymFnWrapper :: forall sym args ret . SymFn sym args ret -> SymFnWrapper sym (args ::> ret) + +instance IsSymFn (SymFn sym) => TestEquality (SymFnWrapper sym) where + testEquality (SymFnWrapper fn1) (SymFnWrapper fn2) = fnTestEquality fn1 fn2 + +instance IsSymFn (SymFn sym) => OrdF (SymFnWrapper sym) where + compareF (SymFnWrapper fn1) (SymFnWrapper fn2) = fnCompare fn1 fn2 + -- | A class for extracting type representatives from symbolic functions -class IsSymFn fn where +class IsSymFn (fn :: Ctx BaseType -> BaseType -> Type) where -- | Get the argument types of a function. fnArgTypes :: fn args ret -> Ctx.Assignment BaseTypeRepr args -- | Get the return type of a function. fnReturnType :: fn args ret -> BaseTypeRepr ret + -- | Test whether two functions are equal. + -- + -- The implementation may be incomplete, that is, if it returns `Just` then + -- the functions are equal, while if it returns `Nothing` then the functions + -- may or may not be equal. The result of `freshTotalUninterpFn` or + -- `definedFn` tests equal with itself. + fnTestEquality :: fn args1 ret1 -> fn args2 ret2 -> Maybe ((args1 ::> ret1) :~: (args2 ::> ret2)) + + -- | Compare two functions for ordering. + -- + -- The underlying equality test is provided by `fnTestEquality`. + fnCompare :: fn args1 ret1 -> fn args2 ret2 -> OrderingF (args1 ::> ret1) (args2 ::> ret2) + -- | Describes when we unfold the body of defined functions. data UnfoldPolicy @@ -2769,6 +2806,7 @@ instance Show InvalidRange where class ( IsExprBuilder sym , IsSymFn (SymFn sym) , OrdF (SymExpr sym) + , OrdF (BoundVar sym) ) => IsSymExprBuilder sym where ---------------------------------------------------------------------- @@ -2914,6 +2952,22 @@ class ( IsExprBuilder sym -- ^ Arguments to function -> IO (SymExpr sym ret) + -- | Apply a variable substitution (variable to symbolic expression mapping) + -- to a symbolic expression. + substituteBoundVars :: + sym -> + MapF (BoundVar sym) (SymExpr sym) -> + SymExpr sym tp -> + IO (SymExpr sym tp) + + -- | Apply a function substitution (function to function mapping) to a + -- symbolic expression. + substituteSymFns :: + sym -> + MapF (SymFnWrapper sym) (SymFnWrapper sym) -> + SymExpr sym tp -> + IO (SymExpr sym tp) + -- | This returns true if the value corresponds to a concrete value. baseIsConcrete :: forall e bt . IsExpr e diff --git a/what4/src/What4/Protocol/SMTLib2.hs b/what4/src/What4/Protocol/SMTLib2.hs index 56d9c35f..4b61f531 100644 --- a/what4/src/What4/Protocol/SMTLib2.hs +++ b/what4/src/What4/Protocol/SMTLib2.hs @@ -15,9 +15,12 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternGuards #-} @@ -40,6 +43,7 @@ module What4.Protocol.SMTLib2 , writeGetValue , writeGetAbduct , writeGetAbductNext + , writeCheckSynth , runCheckSat , runGetAbducts , asSMT2Type @@ -51,10 +55,13 @@ module What4.Protocol.SMTLib2 , setProduceModels , smtLibEvalFuns , smtlib2Options + , parseFnModel + , parseFnValues -- * Logic , SMT2.Logic(..) , SMT2.qf_bv , SMT2.allSupported + , SMT2.hornLogic , all_supported , setLogic -- * Type @@ -69,6 +76,7 @@ module What4.Protocol.SMTLib2 , Session(..) , SMTLib2GenericSolver(..) , writeDefaultSMT2 + , defaultFileWriter , startSolver , shutdownSolver , smtAckResult @@ -93,14 +101,21 @@ import Control.Monad.Fail( MonadFail ) import Control.Applicative import Control.Exception -import Control.Monad.State.Strict +import Control.Monad.Except +import Control.Monad.Reader +import qualified Data.Bimap as Bimap import qualified Data.BitVector.Sized as BV import Data.Char (digitToInt, isAscii) +import Data.HashMap.Lazy (HashMap) +import qualified Data.HashMap.Lazy as HashMap import Data.IORef import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map import Data.Monoid +import Data.Parameterized.Classes import qualified Data.Parameterized.Context as Ctx +import Data.Parameterized.Map (MapF) +import qualified Data.Parameterized.Map as MapF import Data.Parameterized.NatRepr import Data.Parameterized.Pair import Data.Parameterized.Some @@ -124,6 +139,7 @@ import qualified System.IO.Streams as Streams import Data.Versions (Version(..)) import qualified Data.Versions as Versions import qualified Prettyprinter as PP +import Text.Printf (printf) import LibBF( bfToBits ) import Prelude hiding (writeFile) @@ -701,6 +717,11 @@ instance SMTLib2Tweaks a => SMTWriter (Writer a) where let resolveArg (var, Some tp) = (var, asSMT2Type @a tp) in SMT2.defineFun f (resolveArg <$> args) (asSMT2Type @a return_type) e + synthFunCommand _proxy f args ret_tp = + SMT2.synthFun f (map (\(var, Some tp) -> (var, asSMT2Type @a tp)) args) (asSMT2Type @a ret_tp) + declareVarCommand _proxy v tp = SMT2.declareVar v (asSMT2Type @a tp) + constraintCommand _proxy e = SMT2.constraint e + stringTerm str = smtlib2StringTerm @a str stringLength x = smtlib2StringLength @a x stringAppend xs = smtlib2StringAppend @a xs @@ -769,6 +790,10 @@ writeGetAbduct w nm p = addCommandNoAck w $ SMT2.getAbduct nm p writeGetAbductNext :: SMTLib2Tweaks a => WriterConn t (Writer a) -> IO () writeGetAbductNext w = addCommandNoAck w SMT2.getAbductNext +-- | Write check-synth command +writeCheckSynth :: SMTLib2Tweaks a => WriterConn t (Writer a) -> IO () +writeCheckSynth w = addCommandNoAck w SMT2.checkSynth + parseBoolSolverValue :: MonadFail m => SExp -> m Bool parseBoolSolverValue (SAtom "true") = return True parseBoolSolverValue (SAtom "false") = return False @@ -776,6 +801,16 @@ parseBoolSolverValue s = do v <- parseBvSolverValue (knownNat @1) s return (if v == BV.zero knownNat then False else True) +parseIntSolverValue :: MonadFail m => SExp -> m Integer +parseIntSolverValue = \case + SAtom v + | [(i, "")] <- readDec (Text.unpack v) -> + return i + SApp ["-", x] -> + negate <$> parseIntSolverValue x + s -> + fail $ "Could not parse solver value: " ++ show s + parseRealSolverValue :: MonadFail m => SExp -> m Rational parseRealSolverValue (SAtom v) | Just (r,"") <- readDecimal (Text.unpack v) = return r @@ -792,10 +827,11 @@ parseRealSolverValue s = fail $ "Could not parse solver value: " ++ show s -- of the variable. parseBvSolverValue :: MonadFail m => NatRepr w -> SExp -> m (BV.BV w) parseBvSolverValue w s - | Pair w' bv <- parseBVLitHelper s = case w' `compareNat` w of + | Just (Pair w' bv) <- parseBVLitHelper s = case w' `compareNat` w of NatLT zw -> return (BV.zext (addNat w' (addNat zw knownNat)) bv) NatEQ -> return bv NatGT _ -> return (BV.trunc w bv) + | otherwise = fail $ "Could not parse bitvector solver value: " ++ show s natBV :: Natural -- ^ width @@ -806,15 +842,14 @@ natBV wNatural x = case mkNatRepr wNatural of Some w -> Pair w (BV.mkBV w x) -- | Parse an s-expression and return a bitvector and its width -parseBVLitHelper :: SExp -> Pair NatRepr BV.BV +parseBVLitHelper :: SExp -> Maybe (Pair NatRepr BV.BV) parseBVLitHelper (SAtom (Text.unpack -> ('#' : 'b' : n_str))) | [(n, "")] <- readBin n_str = - natBV (fromIntegral (length n_str)) n + Just $ natBV (fromIntegral (length n_str)) n parseBVLitHelper (SAtom (Text.unpack -> ('#' : 'x' : n_str))) | [(n, "")] <- readHex n_str = - natBV (fromIntegral (length n_str * 4)) n + Just $ natBV (fromIntegral (length n_str * 4)) n parseBVLitHelper (SApp ["_", SAtom (Text.unpack -> ('b' : 'v' : n_str)), SAtom (Text.unpack -> w_str)]) - | [(n, "")] <- readDec n_str, [(w, "")] <- readDec w_str = natBV w n --- BGS: Is this correct? -parseBVLitHelper _ = natBV 0 0 + | [(n, "")] <- readDec n_str, [(w, "")] <- readDec w_str = Just $ natBV w n +parseBVLitHelper _ = Nothing parseStringSolverValue :: MonadFail m => SExp -> m Text parseStringSolverValue (SString t) | Just t' <- unescapeText t = return t' @@ -845,10 +880,10 @@ data ParsedFloatResult = forall eb sb . ParsedFloatResult parseFloatLitHelper :: MonadFail m => SExp -> m ParsedFloatResult parseFloatLitHelper (SApp ["fp", sign_s, expt_s, scand_s]) - | Pair sign_w sign <- parseBVLitHelper sign_s + | Just (Pair sign_w sign) <- parseBVLitHelper sign_s , Just Refl <- sign_w `testEquality` (knownNat @1) - , Pair eb expt <- parseBVLitHelper expt_s - , Pair sb scand <- parseBVLitHelper scand_s + , Just (Pair eb expt) <- parseBVLitHelper expt_s + , Just (Pair sb scand) <- parseBVLitHelper scand_s = return $ ParsedFloatResult sign eb expt sb scand parseFloatLitHelper s@(SApp ["_", SAtom (Text.unpack -> nm), SAtom (Text.unpack -> eb_s), SAtom (Text.unpack -> sb_s)]) @@ -884,6 +919,341 @@ parseBvArraySolverValue w v (SApp ["store", arr, idx, val]) = do _ -> return Nothing parseBvArraySolverValue _ _ _ = return Nothing +parseFnModel :: + sym ~ B.ExprBuilder t st fs => + sym -> + WriterConn t h -> + [I.SomeSymFn sym] -> + SExp -> + IO (MapF (I.SymFnWrapper sym) (I.SymFnWrapper sym)) +parseFnModel = parseFns parseDefineFun + +parseFnValues :: + sym ~ B.ExprBuilder t st fs => + sym -> + WriterConn t h -> + [I.SomeSymFn sym] -> + SExp -> + IO (MapF (I.SymFnWrapper sym) (I.SymFnWrapper sym)) +parseFnValues = parseFns parseLambda + +parseFns :: + sym ~ B.ExprBuilder t st fs => + (sym -> SExp -> IO (Text, I.SomeSymFn sym)) -> + sym -> + WriterConn t h -> + [I.SomeSymFn sym] -> + SExp -> + IO (MapF (I.SymFnWrapper sym) (I.SymFnWrapper sym)) +parseFns parse_model_fn sym conn uninterp_fns sexp = do + fn_name_bimap <- cacheLookupFnNameBimap conn $ map (\(I.SomeSymFn fn) -> B.SomeExprSymFn fn) uninterp_fns + defined_fns <- case sexp of + SApp sexps -> Map.fromList <$> mapM (parse_model_fn sym) sexps + _ -> fail $ "Could not parse model response: " ++ show sexp + MapF.fromList <$> mapM + (\(I.SomeSymFn uninterp_fn) -> if + | Just nm <- Bimap.lookup (B.SomeExprSymFn uninterp_fn) fn_name_bimap + , Just (I.SomeSymFn defined_fn) <- Map.lookup nm defined_fns + , Just Refl <- testEquality (I.fnArgTypes uninterp_fn) (I.fnArgTypes defined_fn) + , Just Refl <- testEquality (I.fnReturnType uninterp_fn) (I.fnReturnType defined_fn) -> + return $ MapF.Pair (I.SymFnWrapper uninterp_fn) (I.SymFnWrapper defined_fn) + | otherwise -> fail $ "Could not find model for function: " ++ show uninterp_fn) + uninterp_fns + +parseDefineFun :: I.IsSymExprBuilder sym => sym -> SExp -> IO (Text, I.SomeSymFn sym) +parseDefineFun sym sexp = case sexp of + SApp ["define-fun", SAtom nm, SApp params_sexp, _ret_type_sexp , body_sexp] -> do + fn <- parseFn sym nm params_sexp body_sexp + return (nm, fn) + _ -> fail $ "unexpected sexp, expected define-fun, found " ++ show sexp + +parseLambda :: I.IsSymExprBuilder sym => sym -> SExp -> IO (Text, I.SomeSymFn sym) +parseLambda sym sexp = case sexp of + SApp [SAtom nm, SApp ["lambda", SApp params_sexp, body_sexp]] -> do + fn <- parseFn sym nm params_sexp body_sexp + return (nm, fn) + _ -> fail $ "unexpected sexp, expected lambda, found " ++ show sexp + +parseFn :: I.IsSymExprBuilder sym => sym -> Text -> [SExp] -> SExp -> IO (I.SomeSymFn sym) +parseFn sym nm params_sexp body_sexp = do + (nms, vars) <- unzip <$> mapM (parseVar sym) params_sexp + case Ctx.fromList vars of + Some vars_assign -> do + let let_env = HashMap.fromList $ zip nms $ map (mapSome $ I.varExpr sym) vars + proc_res <- runProcessor (ProcessorEnv { procSym = sym, procLetEnv = let_env }) $ parseExpr sym body_sexp + Some body_expr <- either fail return proc_res + I.SomeSymFn <$> I.definedFn sym (I.safeSymbol $ Text.unpack nm) vars_assign body_expr I.NeverUnfold + +parseVar :: I.IsSymExprBuilder sym => sym -> SExp -> IO (Text, Some (I.BoundVar sym)) +parseVar sym sexp = case sexp of + SApp [SAtom nm, tp_sexp] -> do + Some tp <- parseType tp_sexp + var <- liftIO $ I.freshBoundVar sym (I.safeSymbol $ Text.unpack nm) tp + return (nm, Some var) + _ -> fail $ "unexpected variable " ++ show sexp + +parseType :: SExp -> IO (Some BaseTypeRepr) +parseType sexp = case sexp of + "Bool" -> return $ Some BaseBoolRepr + "Int" -> return $ Some BaseIntegerRepr + "Real" -> return $ Some BaseRealRepr + SApp ["_", "BitVec", SAtom (Text.unpack -> m_str)] + | [(m_n, "")] <- readDec m_str + , Some m <- mkNatRepr m_n + , Just LeqProof <- testLeq (knownNat @1) m -> + return $ Some $ BaseBVRepr m + SApp ["_", "FloatingPoint", SAtom (Text.unpack -> eb_str), SAtom (Text.unpack -> sb_str)] + | [(eb_n, "")] <- readDec eb_str + , Some eb <- mkNatRepr eb_n + , Just LeqProof <- testLeq (knownNat @2) eb + , [(sb_n, "")] <- readDec sb_str + , Some sb <- mkNatRepr sb_n + , Just LeqProof <- testLeq (knownNat @2) sb -> + return $ Some $ BaseFloatRepr $ FloatingPointPrecisionRepr eb sb + SApp ["Array", idx_tp_sexp, val_tp_sexp] -> do + Some idx_tp <- parseType idx_tp_sexp + Some val_tp <- parseType val_tp_sexp + return $ Some $ BaseArrayRepr (Ctx.singleton idx_tp) val_tp + _ -> fail $ "unexpected type " ++ show sexp + + +-- | Stores a NatRepr along with proof that its type parameter is a bitvector of +-- that length. Used for easy pattern matching on the LHS of a binding in a +-- do-expression to extract the proof. +data BVProof tp where + BVProof :: forall n . (1 <= n) => NatRepr n -> BVProof (BaseBVType n) + +-- | Given an expression, monadically either returns proof that it is a +-- bitvector or throws an error. +getBVProof :: (I.IsExpr ex, MonadError String m) => ex tp -> m (BVProof tp) +getBVProof expr = case I.exprType expr of + BaseBVRepr n -> return $ BVProof n + t -> throwError $ "expected BV, found " ++ show t + +-- | Operator type descriptions for parsing s-expression of +-- the form @(operator operands ...)@. +-- +-- Code is copy-pasted and adapted from `What4.Serialize.Parser`, see +-- +data Op sym where + -- | Generic unary operator description. + Op1 :: + Ctx.Assignment BaseTypeRepr (Ctx.EmptyCtx Ctx.::> arg1) -> + (sym -> I.SymExpr sym arg1 -> IO (I.SymExpr sym ret)) -> + Op sym + -- | Generic binary operator description. + Op2 :: + Ctx.Assignment BaseTypeRepr (Ctx.EmptyCtx Ctx.::> arg1 Ctx.::> arg2) -> + Maybe Assoc -> + (sym -> I.SymExpr sym arg1 -> I.SymExpr sym arg2 -> IO (I.SymExpr sym ret)) -> + Op sym + -- | Encapsulating type for a unary operation that takes one bitvector and + -- returns another (in IO). + BVOp1 :: + (forall w . (1 <= w) => sym -> I.SymBV sym w -> IO (I.SymBV sym w)) -> + Op sym + -- | Binop with a bitvector return type, e.g., addition or bitwise operations. + BVOp2 :: + Maybe Assoc -> + (forall w . (1 <= w) => sym -> I.SymBV sym w -> I.SymBV sym w -> IO (I.SymBV sym w)) -> + Op sym + -- | Bitvector binop with a boolean return type, i.e., comparison operators. + BVComp2 :: + (forall w . (1 <= w) => sym -> I.SymBV sym w -> I.SymBV sym w -> IO (I.Pred sym)) -> + Op sym + +data Assoc = RightAssoc | LeftAssoc + +newtype Processor sym a = Processor (ExceptT String (ReaderT (ProcessorEnv sym) IO) a) + deriving (Functor, Applicative, Monad, MonadIO, MonadError String, MonadReader (ProcessorEnv sym)) + +data ProcessorEnv sym = ProcessorEnv + { procSym :: sym + , procLetEnv :: HashMap Text (Some (I.SymExpr sym)) + } + +runProcessor :: ProcessorEnv sym -> Processor sym a -> IO (Either String a) +runProcessor env (Processor action) = runReaderT (runExceptT action) env + +opTable :: I.IsSymExprBuilder sym => HashMap Text (Op sym) +opTable = HashMap.fromList + -- Boolean ops + [ ("not", Op1 knownRepr I.notPred) + , ("=>", Op2 knownRepr (Just RightAssoc) I.impliesPred) + , ("and", Op2 knownRepr (Just LeftAssoc) I.andPred) + , ("or", Op2 knownRepr (Just LeftAssoc) I.orPred) + , ("xor", Op2 knownRepr (Just LeftAssoc) I.xorPred) + -- Integer ops + , ("-", Op2 knownRepr (Just LeftAssoc) I.intSub) + , ("+", Op2 knownRepr (Just LeftAssoc) I.intAdd) + , ("*", Op2 knownRepr (Just LeftAssoc) I.intMul) + , ("div", Op2 knownRepr (Just LeftAssoc) I.intDiv) + , ("mod", Op2 knownRepr Nothing I.intMod) + , ("abs", Op1 knownRepr I.intAbs) + , ("<=", Op2 knownRepr Nothing I.intLe) + , ("<", Op2 knownRepr Nothing I.intLt) + , (">=", Op2 knownRepr Nothing $ \sym arg1 arg2 -> I.intLe sym arg2 arg1) + , (">", Op2 knownRepr Nothing $ \sym arg1 arg2 -> I.intLt sym arg2 arg1) + -- Bitvector ops + , ("bvnot", BVOp1 I.bvNotBits) + , ("bvneg", BVOp1 I.bvNeg) + , ("bvand", BVOp2 (Just LeftAssoc) I.bvAndBits) + , ("bvor", BVOp2 (Just LeftAssoc) I.bvOrBits) + , ("bvxor", BVOp2 (Just LeftAssoc) I.bvXorBits) + , ("bvadd", BVOp2 (Just LeftAssoc) I.bvAdd) + , ("bvsub", BVOp2 (Just LeftAssoc) I.bvSub) + , ("bvmul", BVOp2 (Just LeftAssoc) I.bvMul) + , ("bvudiv", BVOp2 Nothing I.bvUdiv) + , ("bvurem", BVOp2 Nothing I.bvUrem) + , ("bvshl", BVOp2 Nothing I.bvShl) + , ("bvlshr", BVOp2 Nothing I.bvLshr) + , ("bvsdiv", BVOp2 Nothing I.bvSdiv) + , ("bvsrem", BVOp2 Nothing I.bvSrem) + , ("bvashr", BVOp2 Nothing I.bvAshr) + , ("bvult", BVComp2 I.bvUlt) + , ("bvule", BVComp2 I.bvUle) + , ("bvugt", BVComp2 I.bvUgt) + , ("bvuge", BVComp2 I.bvUge) + , ("bvslt", BVComp2 I.bvSlt) + , ("bvsle", BVComp2 I.bvSle) + , ("bvsgt", BVComp2 I.bvSgt) + , ("bvsge", BVComp2 I.bvSge) + ] + +parseExpr :: + forall sym . I.IsSymExprBuilder sym => sym -> SExp -> Processor sym (Some (I.SymExpr sym)) +parseExpr sym sexp = case sexp of + "true" -> return $ Some $ I.truePred sym + "false" -> return $ Some $ I.falsePred sym + _ | Just i <- parseIntSolverValue sexp -> + liftIO $ Some <$> I.intLit sym i + | Just (Pair w bv) <- parseBVLitHelper sexp + , Just LeqProof <- testLeq (knownNat @1) w -> + liftIO $ Some <$> I.bvLit sym w bv + SAtom nm -> do + env <- asks procLetEnv + case HashMap.lookup nm env of + Just expr -> return $ expr + Nothing -> throwError "" + SApp ["let", SApp bindings_sexp, body_sexp] -> do + let_env <- HashMap.fromList <$> mapM + (\case + SApp [SAtom nm, expr_sexp] -> do + Some expr <- parseExpr sym expr_sexp + return (nm, Some expr) + _ -> throwError "") + bindings_sexp + local (\prov_env -> prov_env { procLetEnv = HashMap.union let_env (procLetEnv prov_env) }) $ + parseExpr sym body_sexp + SApp ["=", arg1, arg2] -> do + Some arg1_expr <- parseExpr sym arg1 + Some arg2_expr <- parseExpr sym arg2 + case testEquality (I.exprType arg1_expr) (I.exprType arg2_expr) of + Just Refl -> liftIO (Some <$> I.isEq sym arg1_expr arg2_expr) + Nothing -> throwError "" + SApp ["ite", arg1, arg2, arg3] -> do + Some arg1_expr <- parseExpr sym arg1 + Some arg2_expr <- parseExpr sym arg2 + Some arg3_expr <- parseExpr sym arg3 + case I.exprType arg1_expr of + I.BaseBoolRepr -> case testEquality (I.exprType arg2_expr) (I.exprType arg3_expr) of + Just Refl -> liftIO (Some <$> I.baseTypeIte sym arg1_expr arg2_expr arg3_expr) + Nothing -> throwError "" + _ -> throwError "" + SApp ["concat", arg1, arg2] -> do + Some arg1_expr <- parseExpr sym arg1 + Some arg2_expr <- parseExpr sym arg2 + BVProof{} <- getBVProof arg1_expr + BVProof{} <- getBVProof arg2_expr + liftIO $ Some <$> I.bvConcat sym arg1_expr arg2_expr + SApp ((SAtom operator) : operands) -> case HashMap.lookup operator (opTable @sym) of + Just (Op1 arg_types fn) -> do + args <- mapM (parseExpr sym) operands + exprAssignment arg_types args >>= \case + Ctx.Empty Ctx.:> arg1 -> + liftIO (Some <$> fn sym arg1) + Just (Op2 arg_types _ fn) -> do + args <- mapM (parseExpr sym) operands + exprAssignment arg_types args >>= \case + Ctx.Empty Ctx.:> arg1 Ctx.:> arg2 -> + liftIO (Some <$> fn sym arg1 arg2) + Just (BVOp1 op) -> do + Some arg_expr <- readOneArg sym operands + BVProof{} <- getBVProof arg_expr + liftIO $ Some <$> op sym arg_expr + Just (BVOp2 _ op) -> do + (Some arg1, Some arg2) <- readTwoArgs sym operands + BVProof m <- prefixError "in arg 1: " $ getBVProof arg1 + BVProof n <- prefixError "in arg 2: " $ getBVProof arg2 + case testEquality m n of + Just Refl -> liftIO (Some <$> op sym arg1 arg2) + Nothing -> throwError $ printf "arguments to %s must be the same length, \ + \but arg 1 has length %s \ + \and arg 2 has length %s" + operator + (show m) + (show n) + Just (BVComp2 op) -> do + (Some arg1, Some arg2) <- readTwoArgs sym operands + BVProof m <- prefixError "in arg 1: " $ getBVProof arg1 + BVProof n <- prefixError "in arg 2: " $ getBVProof arg2 + case testEquality m n of + Just Refl -> liftIO (Some <$> op sym arg1 arg2) + Nothing -> throwError $ printf "arguments to %s must be the same length, \ + \but arg 1 has length %s \ + \and arg 2 has length %s" + operator + (show m) + (show n) + _ -> throwError "" + _ -> throwError "" +-- | Verify a list of arguments has a single argument and +-- return it, else raise an error. +readOneArg :: + I.IsSymExprBuilder sym + => sym + -> [SExp] + -> Processor sym (Some (I.SymExpr sym)) +readOneArg sym operands = do + args <- mapM (parseExpr sym) operands + case args of + [arg] -> return arg + _ -> throwError $ printf "expecting 1 argument, got %d" (length args) + +-- | Verify a list of arguments has two arguments and return +-- it, else raise an error. +readTwoArgs :: + I.IsSymExprBuilder sym + => sym + ->[SExp] + -> Processor sym (Some (I.SymExpr sym), Some (I.SymExpr sym)) +readTwoArgs sym operands = do + args <- mapM (parseExpr sym) operands + case args of + [arg1, arg2] -> return (arg1, arg2) + _ -> throwError $ printf "expecting 2 arguments, got %d" (length args) + +exprAssignment :: + forall sym ctx ex . (I.IsSymExprBuilder sym, I.IsExpr ex) + => Ctx.Assignment BaseTypeRepr ctx + -> [Some ex] + -> Processor sym (Ctx.Assignment ex ctx) +exprAssignment tpAssns exs = do + Some exsAsn <- return $ Ctx.fromList exs + exsRepr <- return $ fmapFC I.exprType exsAsn + case testEquality exsRepr tpAssns of + Just Refl -> return exsAsn + Nothing -> throwError $ + "Unexpected expression types for " -- ++ show exsAsn + ++ "\nExpected: " ++ show tpAssns + ++ "\nGot: " ++ show exsRepr + +-- | Utility function for contextualizing errors. Prepends the given prefix +-- whenever an error is thrown. +prefixError :: (Monoid e, MonadError e m) => e -> m a -> m a +prefixError prefix act = catchError act (throwError . mappend prefix) + + ------------------------------------------------------------------------ -- Session @@ -1145,16 +1515,28 @@ writeDefaultSMT2 :: SMTLib2Tweaks a -> [B.BoolExpr t] -> IO () writeDefaultSMT2 a nm feat strictOpt sym h ps = do + c <- defaultFileWriter a nm feat strictOpt sym h + setProduceModels c True + forM_ ps (SMTWriter.assume c) + writeCheckSat c + writeExit c + +defaultFileWriter :: + SMTLib2Tweaks a => + a -> + String -> + ProblemFeatures -> + Maybe (CFG.ConfigOption I.BaseBoolType) -> + B.ExprBuilder t st fs -> + IO.Handle -> + IO (WriterConn t (Writer a)) +defaultFileWriter a nm feat strictOpt sym h = do bindings <- B.getSymbolVarBimap sym str <- Streams.encodeUtf8 =<< Streams.handleToOutputStream h null_in <- Streams.nullInput let cfg = I.getConfiguration sym strictness <- parserStrictness strictOpt strictSMTParsing cfg - c <- newWriter a str null_in nullAcknowledgementAction strictness nm True feat True bindings - setProduceModels c True - forM_ ps (SMTWriter.assume c) - writeCheckSat c - writeExit c + newWriter a str null_in nullAcknowledgementAction strictness nm True feat True bindings -- n.b. commonly used for the startSolverProcess method of the -- OnlineSolver class, so it's helpful for the type suffixes to align diff --git a/what4/src/What4/Protocol/SMTLib2/Response.hs b/what4/src/What4/Protocol/SMTLib2/Response.hs index b09bdc1a..c52921b1 100644 --- a/what4/src/What4/Protocol/SMTLib2/Response.hs +++ b/what4/src/What4/Protocol/SMTLib2/Response.hs @@ -71,6 +71,8 @@ data SMTResponse = AckSuccess | AckSat | AckUnsat | AckUnknown + | AckInfeasible -- SyGuS response + | AckFail -- SyGuS response | RspName Text | RspVersion Text | RspErrBehavior Text @@ -144,10 +146,12 @@ rspParser strictness = parens p = AT.char '(' *> p <* AT.char ')' errParser = parens $ lexeme (AT.string "error") *> (AckError <$> lexeme parseSMTLib2String) - specific_success_response = check_sat_response <|> get_info_response + specific_success_response = check_sat_response <|> check_synth_response <|> get_info_response check_sat_response = (AckSat <$ AT.string "sat") <|> (AckUnsat <$ AT.string "unsat") <|> (AckUnknown <$ AT.string "unknown") + check_synth_response = (AckInfeasible <$ AT.string "infeasible") + <|> (AckFail <$ AT.string "fail") get_info_response = parens info_response info_response = errBhvParser <|> nameParser diff --git a/what4/src/What4/Protocol/SMTLib2/Syntax.hs b/what4/src/What4/Protocol/SMTLib2/Syntax.hs index c86c8ec9..64615905 100644 --- a/what4/src/What4/Protocol/SMTLib2/Syntax.hs +++ b/what4/src/What4/Protocol/SMTLib2/Syntax.hs @@ -51,11 +51,17 @@ module What4.Protocol.SMTLib2.Syntax , getUnsatCore , getAbduct , getAbductNext + -- * SyGuS + , synthFun + , declareVar + , constraint + , checkSynth -- * Logic , Logic(..) , qf_bv , allSupported , allLogic + , hornLogic -- * Sort , Sort(..) , boolSort @@ -192,7 +198,11 @@ allSupported = Logic "ALL_SUPPORTED" -- | Set the logic to all supported logics. allLogic :: Logic allLogic = Logic "ALL" - + +-- | Use the Horn logic +hornLogic :: Logic +hornLogic = Logic "HORN" + ------------------------------------------------------------------------ -- Symbol @@ -836,6 +846,29 @@ getAbduct nm p = Cmd $ "(get-abduct " <> Builder.fromText nm <> " " <> renderTer getAbductNext :: Command getAbductNext = Cmd "(get-abduct-next)" +-- | Declare a SyGuS function to synthesize with the given name, arguments, and +-- return type. +synthFun :: Text -> [(Text, Sort)] -> Sort -> Command +synthFun f args ret_tp = Cmd $ app "synth-fun" + [ Builder.fromText f + , builder_list $ map (\(var, tp) -> app (Builder.fromText var) [unSort tp]) args + , unSort ret_tp + ] + +-- | Declare a SyGuS variable with the given name and type. +declareVar :: Text -> Sort -> Command +declareVar v tp = Cmd $ app "declare-var" [Builder.fromText v, unSort tp] + +-- | Add the SyGuS constraint to the current synthesis problem. +constraint :: Term -> Command +constraint p = Cmd $ app "constraint" [renderTerm p] + +-- | Ask the SyGuS solver to find a solution for the synthesis problem +-- corresponding to the current functions-to-synthesize, variables and +-- constraints. +checkSynth :: Command +checkSynth = Cmd "(check-synth)\n" + -- | Get the values associated with the terms from the last call to @check-sat@. getValue :: [Term] -> Command getValue values = Cmd $ app "get-value" [builder_list (renderTerm <$> values)] diff --git a/what4/src/What4/Protocol/SMTWriter.hs b/what4/src/What4/Protocol/SMTWriter.hs index 77b9cd46..bb41de6f 100644 --- a/what4/src/What4/Protocol/SMTWriter.hs +++ b/what4/src/What4/Protocol/SMTWriter.hs @@ -69,6 +69,7 @@ module What4.Protocol.SMTWriter , entryStackHeight , pushEntryStack , popEntryStack + , cacheLookupFnNameBimap , Command , addCommand , addCommandNoAck @@ -86,6 +87,10 @@ module What4.Protocol.SMTWriter , ResponseStrictness(..) , parserStrictness , nullAcknowledgementAction + -- * SyGuS + , addSynthFun + , addDeclareVar + , addConstraint -- * SMTWriter operations , assume , mkSMTTerm @@ -111,6 +116,8 @@ import Control.Monad.Reader import Control.Monad.ST import Control.Monad.State.Strict import Control.Monad.Trans.Maybe +import Data.Bimap (Bimap) +import qualified Data.Bimap as Bimap import qualified Data.BitVector.Sized as BV import qualified Data.Bits as Bits import Data.IORef @@ -820,6 +827,15 @@ cacheValueFn cacheValueFn conn n lifetime value = cacheValue conn lifetime $ \entry -> stToIO $ PH.insert (symFnCache entry) n value +cacheLookupFnNameBimap :: WriterConn t h -> [SomeExprSymFn t] -> IO (Bimap (SomeExprSymFn t) Text) +cacheLookupFnNameBimap conn fns = Bimap.fromList <$> mapM + (\some_fn@(SomeExprSymFn fn) -> do + maybe_smt_sym_fn <- cacheLookupFn conn $ symFnId fn + case maybe_smt_sym_fn of + Just (SMTSymFn nm _ _) -> return (some_fn, nm) + Nothing -> fail $ "Could not find function in cache: " ++ show fn) + fns + -- | Run state with handle. withWriterState :: WriterConn t h -> State WriterState a -> IO a withWriterState c m = do @@ -931,6 +947,23 @@ class (SupportTermOps (Term h)) => SMTWriter h where -> Term h -> Command h + -- | Declare a new SyGuS function to synthesize with the given name, + -- arguments, and result type. + synthFunCommand :: f h + -> Text + -> [(Text, Some TypeMap)] + -> TypeMap tp + -> Command h + + -- | Declare a new SyGuS universal variables with the given name and type. + declareVarCommand :: f h + -> Text + -> TypeMap tp + -> Command h + + -- | Add a SyGuS formula to the set of synthesis constraints. + constraintCommand :: f h -> Term h -> Command h + -- | Declare a struct datatype if is has not been already given the number of -- arguments in the struct. declareStructDatatype :: WriterConn t h -> Ctx.Assignment TypeMap args -> IO () @@ -1054,6 +1087,63 @@ assumeFormulaWithFreshName conn p = assumeFormulaWithName conn p var return var +addSynthFun :: + SMTWriter h => + WriterConn t h -> + ExprSymFn t args ret -> + IO () +addSynthFun conn fn = + cacheLookupFn conn (symFnId fn) >>= \case + Just{} -> + fail $ "Internal error in SMTLIB exporter: function already declared." + ++ show (symFnId fn) ++ " declared at " + ++ show (plSourceLoc (symFnLoc fn)) ++ "." + Nothing -> case symFnInfo fn of + UninterpFnInfo arg_types ret_type -> do + nm <- getSymbolName conn (FnSymbolBinding fn) + let fn_source = fnSource (symFnName fn) (symFnLoc fn) + smt_arg_types <- traverseFC (evalFirstClassTypeRepr conn fn_source) arg_types + checkArgumentTypes conn smt_arg_types + smt_ret_type <- evalFirstClassTypeRepr conn fn_source ret_type + traverseFC_ (declareTypes conn) smt_arg_types + declareTypes conn smt_ret_type + smt_args <- mapM + (\(Some tp) -> do + var <- withWriterState conn $ freshVarName + return (var, Some tp)) + (toListFC Some smt_arg_types) + addCommand conn $ synthFunCommand conn nm smt_args smt_ret_type + cacheValueFn conn (symFnId fn) DeleteNever $! SMTSymFn nm smt_arg_types smt_ret_type + DefinedFnInfo{} -> + fail $ "Internal error in SMTLIB exporter: defined functions cannot be synthesized." + MatlabSolverFnInfo{} -> + fail $ "Internal error in SMTLIB exporter: MatlabSolver functions cannot be synthesized." + +addDeclareVar :: + SMTWriter h => + WriterConn t h -> + ExprBoundVar t tp -> + IO () +addDeclareVar conn var = + cacheLookupExpr conn (bvarId var) >>= \case + Just{} -> + fail $ "Internal error in SMTLIB exporter: variable already declared." + ++ show (bvarId var) ++ " declared at " + ++ show (plSourceLoc (bvarLoc var)) ++ "." + Nothing -> do + nm <- getSymbolName conn (VarSymbolBinding var) + let fn_source = fnSource (bvarName var) (bvarLoc var) + smt_type <- evalFirstClassTypeRepr conn fn_source $ bvarType var + declareTypes conn smt_type + addCommand conn $ declareVarCommand conn nm smt_type + cacheValueExpr conn (bvarId var) DeleteNever $! SMTName smt_type nm + +addConstraint :: SMTWriter h => WriterConn t h -> BoolExpr t -> IO () +addConstraint conn p = do + f <- mkFormula conn p + updateProgramLoc conn (exprLoc p) + addCommand conn $ constraintCommand conn f + -- | Perform any necessary declarations to ensure that the mentioned type map -- sorts exist in the solver environment. declareTypes :: diff --git a/what4/src/What4/Solver/CVC5.hs b/what4/src/What4/Solver/CVC5.hs index 0e55c33a..1c0d0ea8 100644 --- a/what4/src/What4/Solver/CVC5.hs +++ b/what4/src/What4/Solver/CVC5.hs @@ -11,6 +11,7 @@ ------------------------------------------------------------------------ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -26,6 +27,9 @@ module What4.Solver.CVC5 , withCVC5 , writeCVC5SMT2File , writeMultiAsmpCVC5SMT2File + , runCVC5SyGuS + , withCVC5_SyGuS + , writeCVC5SyFile ) where import Control.Monad (forM_, when) @@ -34,6 +38,8 @@ import Data.String import System.IO import qualified System.IO.Streams as Streams +import Data.Parameterized.Map (MapF) +import Data.Parameterized.Some import What4.BaseTypes import What4.Concrete import What4.Config @@ -244,3 +250,130 @@ instance OnlineSolver (SMT2.Writer CVC5) where timeout feat (Just cvc5StrictParsing) mbIOh sym shutdownSolverProcess = SMT2.shutdownSolver CVC5 + + +-- | `CVC5_SyGuS` implements a `SMT2.SMTLib2GenericSolver` instance that is +-- different from `CVC5` in that it provides SyGuS specific implementations for +-- `defaultSolverArgs` and `setDefaultLogicAndOptions`. +data CVC5_SyGuS = CVC5_SyGuS deriving Show + +instance SMT2.SMTLib2Tweaks CVC5_SyGuS where + smtlib2tweaks = CVC5_SyGuS + + smtlib2arrayType = SMT2.smtlib2arrayType @CVC5 + + smtlib2arrayConstant = SMT2.smtlib2arrayConstant @CVC5 + smtlib2arraySelect = SMT2.smtlib2arraySelect @CVC5 + smtlib2arrayUpdate = SMT2.smtlib2arrayUpdate @CVC5 + + smtlib2declareStructCmd = SMT2.smtlib2declareStructCmd @CVC5 + smtlib2StructSort = SMT2.smtlib2StructSort @CVC5 + smtlib2StructCtor = SMT2.smtlib2StructCtor @CVC5 + smtlib2StructProj = SMT2.smtlib2StructProj @CVC5 + +instance SMT2.SMTLib2GenericSolver CVC5_SyGuS where + defaultSolverPath _ = SMT2.defaultSolverPath CVC5 + + defaultSolverArgs _ sym = do + let cfg = getConfiguration sym + timeout <- getOption =<< getOptionSetting cvc5Timeout cfg + let extraOpts = case timeout of + Just (ConcreteInteger n) | n > 0 -> ["--tlimit-per=" ++ show n] + _ -> [] + return $ ["--sygus", "--lang", "sygus2", "--strings-exp", "--fp-exp"] ++ extraOpts + + getErrorBehavior _ = SMT2.queryErrorBehavior + + defaultFeatures _ = SMT2.defaultFeatures CVC5 + + supportsResetAssertions _ = SMT2.supportsResetAssertions CVC5 + + setDefaultLogicAndOptions writer = do + -- Tell cvc5 to use all supported logics. + SMT2.setLogic writer Syntax.allLogic + +-- | Find a solution to a Syntax-Guided Synthesis (SyGuS) problem. +-- +-- For more information, see the [SyGuS standard](https://sygus.org/). +runCVC5SyGuS :: + sym ~ ExprBuilder t st fs => + sym -> + LogData -> + [SomeSymFn sym] -> + [BoolExpr t] -> + IO (SatResult (MapF (SymFnWrapper sym) (SymFnWrapper sym)) ()) +runCVC5SyGuS sym log_data synth_fns constraints = do + logSolverEvent sym + (SolverStartSATQuery $ SolverStartSATQueryRec + { satQuerySolverName = show CVC5_SyGuS + , satQueryReason = logReason log_data + }) + + path <- SMT2.defaultSolverPath CVC5_SyGuS sym + withCVC5_SyGuS sym path (log_data { logVerbosity = 2 }) $ \session -> do + writeSyGuSProblem sym (SMT2.sessionWriter session) synth_fns constraints + result <- RSP.getLimitedSolverResponse "check-synth" + (\case + RSP.AckSuccessSExp sexp -> Just $ Sat sexp + RSP.AckInfeasible -> Just $ Unsat () + RSP.AckFail -> Just Unknown + _ -> Nothing) + (SMT2.sessionWriter session) + Syntax.checkSynth + + logSolverEvent sym + (SolverEndSATQuery $ SolverEndSATQueryRec + { satQueryResult = forgetModelAndCore result + , satQueryError = Nothing + }) + + traverseSatResult + (\sexp -> SMT2.parseFnModel sym (SMT2.sessionWriter session) synth_fns sexp) + return + result + +-- | Run CVC5 SyGuS in a session, with the default configuration. +withCVC5_SyGuS :: + ExprBuilder t st fs -> + FilePath -> + LogData -> + (SMT2.Session t CVC5_SyGuS -> IO a) -> + IO a +withCVC5_SyGuS = + SMT2.withSolver + CVC5_SyGuS + nullAcknowledgementAction + (SMT2.defaultFeatures CVC5_SyGuS) + (Just cvc5StrictParsing) + +writeCVC5SyFile :: + sym ~ ExprBuilder t st fs => + sym -> + Handle -> + [SomeSymFn sym] -> + [BoolExpr t] -> + IO () +writeCVC5SyFile sym h synth_fns constraints = do + writer <- SMT2.defaultFileWriter + CVC5_SyGuS + (show CVC5_SyGuS) + (SMT2.defaultFeatures CVC5_SyGuS) + (Just cvc5StrictParsing) + sym + h + SMT2.setDefaultLogicAndOptions writer + writeSyGuSProblem sym writer synth_fns constraints + SMT2.writeExit writer + +writeSyGuSProblem :: + sym ~ ExprBuilder t st fs => + sym -> + WriterConn t (SMT2.Writer CVC5_SyGuS) -> + [SomeSymFn sym] -> + [BoolExpr t] -> + IO () +writeSyGuSProblem sym writer synth_fns constraints = do + mapM_ (\(SomeSymFn fn) -> addSynthFun writer fn) synth_fns + mapM_ (viewSome $ addDeclareVar writer) $ foldMap (exprUninterpConstants sym) constraints + mapM_ (addConstraint writer) constraints + SMT2.writeCheckSynth writer diff --git a/what4/src/What4/Solver/Yices.hs b/what4/src/What4/Solver/Yices.hs index c19f7546..ea15a38e 100644 --- a/what4/src/What4/Solver/Yices.hs +++ b/what4/src/What4/Solver/Yices.hs @@ -528,6 +528,10 @@ instance SMTWriter Connection where , renderTerm (yicesLambda args t) ] + synthFunCommand _ _ _ _ = unsupportedFeature "SyGuS" + declareVarCommand _ _ _ = unsupportedFeature "SyGuS" + constraintCommand _ _ = unsupportedFeature "SyGuS" + resetDeclaredStructs conn = resetUnitType conn structProj _n i s = term_app "select" [s, fromIntegral (Ctx.indexVal i + 1)] diff --git a/what4/src/What4/Solver/Z3.hs b/what4/src/What4/Solver/Z3.hs index bbb34611..84db2037 100644 --- a/what4/src/What4/Solver/Z3.hs +++ b/what4/src/What4/Solver/Z3.hs @@ -10,6 +10,7 @@ -- Z3-specific tweaks to the basic SMTLib2 solver interface. ------------------------------------------------------------------------ {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -27,15 +28,21 @@ module What4.Solver.Z3 , runZ3InOverride , withZ3 , writeZ3SMT2File + , runZ3Horn + , writeZ3HornSMT2File ) where import Control.Monad ( when ) +import qualified Data.Bimap as Bimap import Data.Bits +import Data.Foldable import Data.String import Data.Text (Text) import qualified Data.Text as T import System.IO +import Data.Parameterized.Map (MapF) +import Data.Parameterized.Some import What4.BaseTypes import What4.Concrete import What4.Config @@ -46,7 +53,8 @@ import What4.ProblemFeatures import What4.Protocol.Online import qualified What4.Protocol.SMTLib2 as SMT2 import What4.Protocol.SMTLib2.Response ( strictSMTParseOpt ) -import qualified What4.Protocol.SMTLib2.Syntax as SMT2Syntax +import qualified What4.Protocol.SMTLib2.Response as RSP +import qualified What4.Protocol.SMTLib2.Syntax as Syntax import What4.Protocol.SMTWriter import What4.SatResult import What4.Solver.Adapter @@ -139,7 +147,7 @@ instance SMT2.SMTLib2Tweaks Z3 where fields = field_def <$> [1..n] decl = app tp [app ctor fields] decls = "(" <> decl <> ")" - in SMT2Syntax.Cmd $ app "declare-datatypes" [ params, decls ] + in Syntax.Cmd $ app "declare-datatypes" [ params, decls ] z3Features :: ProblemFeatures z3Features = useNonlinearArithmetic @@ -235,3 +243,87 @@ instance OnlineSolver (SMT2.Writer Z3) where timeout feat (Just z3StrictParsing) mbIOh sym shutdownSolverProcess = SMT2.shutdownSolver Z3 + +-- | Check the satisfiability of a set of constrained Horn clauses (CHCs). +-- +-- CHCs are represented as pure SMT-LIB2 implications. For more information, see +-- the [Z3 guide](https://microsoft.github.io/z3guide/docs/fixedpoints/intro/). +runZ3Horn :: + sym ~ ExprBuilder t st fs => + sym -> + LogData -> + [SomeSymFn sym] -> + [BoolExpr t] -> + IO (SatResult (MapF (SymFnWrapper sym) (SymFnWrapper sym)) ()) +runZ3Horn sym log_data inv_fns horn_clauses = do + logSolverEvent sym + (SolverStartSATQuery $ SolverStartSATQueryRec + { satQuerySolverName = show Z3 + , satQueryReason = logReason log_data + }) + + path <- SMT2.defaultSolverPath Z3 sym + withZ3 sym path (log_data { logVerbosity = 2 }) $ \session -> do + writeHornProblem sym (SMT2.sessionWriter session) inv_fns horn_clauses + result <- RSP.getLimitedSolverResponse "check-sat" + (\case + RSP.AckSat -> Just $ Sat () + RSP.AckUnsat -> Just $ Unsat () + RSP.AckUnknown -> Just Unknown + _ -> Nothing) + (SMT2.sessionWriter session) + Syntax.checkSat + + logSolverEvent sym + (SolverEndSATQuery $ SolverEndSATQueryRec + { satQueryResult = result + , satQueryError = Nothing + }) + + traverseSatResult + (\() -> do + sexp <- RSP.getLimitedSolverResponse "get-value" + (\case + RSP.AckSuccessSExp sexp -> Just sexp + _ -> Nothing) + (SMT2.sessionWriter session) + (Syntax.getValue []) + SMT2.parseFnValues sym (SMT2.sessionWriter session) inv_fns sexp) + return + result + +writeZ3HornSMT2File :: + sym ~ ExprBuilder t st fs => + sym -> + Handle -> + [SomeSymFn sym] -> + [BoolExpr t] -> + IO () +writeZ3HornSMT2File sym h inv_fns horn_clauses = do + writer <- SMT2.defaultFileWriter + Z3 + (show Z3) + (SMT2.defaultFeatures Z3) + (Just z3StrictParsing) + sym + h + SMT2.setDefaultLogicAndOptions writer + writeHornProblem sym writer inv_fns horn_clauses + SMT2.writeExit writer + +writeHornProblem :: + sym ~ ExprBuilder t st fs => + sym -> + WriterConn t (SMT2.Writer Z3) -> + [SomeSymFn sym] -> + [BoolExpr t] -> + IO () +writeHornProblem sym writer inv_fns horn_clauses = do + SMT2.setLogic writer Syntax.hornLogic + implications <- mapM + (\clause -> foldrM (viewSome $ forallPred sym) clause $ exprUninterpConstants sym clause) + horn_clauses + mapM_ (SMT2.assume writer) implications + SMT2.writeCheckSat writer + fn_name_bimap <- cacheLookupFnNameBimap writer $ map (\(SomeSymFn fn) -> SomeExprSymFn fn) inv_fns + SMT2.writeGetValue writer $ map fromText $ Bimap.elems fn_name_bimap diff --git a/what4/test/InvariantSynthesis.hs b/what4/test/InvariantSynthesis.hs new file mode 100644 index 00000000..32b995fd --- /dev/null +++ b/what4/test/InvariantSynthesis.hs @@ -0,0 +1,152 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeApplications #-} + +import ProbeSolvers +import Test.Tasty +import Test.Tasty.ExpectedFailure +import Test.Tasty.HUnit + +import Data.Maybe +import System.Environment + +import qualified Data.BitVector.Sized as BV +import Data.Parameterized.Context +import Data.Parameterized.Map (MapF) +import Data.Parameterized.Nonce + +import What4.Config +import What4.Expr +import What4.Interface +import What4.SatResult +import What4.Solver.Adapter +import qualified What4.Solver.CVC5 as CVC5 +import qualified What4.Solver.Z3 as Z3 + +type SimpleExprBuilder t fs = ExprBuilder t EmptyExprBuilderState fs + +logData :: LogData +logData = defaultLogData { logCallbackVerbose = (\_ -> putStrLn) } + +withSym :: FloatModeRepr fm -> (forall t . SimpleExprBuilder t (Flags fm) -> IO a) -> IO a +withSym float_mode action = withIONonceGenerator $ \gen -> do + sym <- newExprBuilder float_mode EmptyExprBuilderState gen + extendConfig CVC5.cvc5Options (getConfiguration sym) + extendConfig Z3.z3Options (getConfiguration sym) + action sym + +intProblem :: IsSymExprBuilder sym => sym -> IO ([SomeSymFn sym], [Pred sym], Pred sym) +intProblem sym = do + inv <- freshTotalUninterpFn sym (safeSymbol "inv") knownRepr knownRepr + i <- freshConstant sym (safeSymbol "i") knownRepr + n <- freshConstant sym (safeSymbol "n") knownRepr + zero <- intLit sym 0 + one <- intLit sym 1 + lt_1_n <- intLt sym one n + inv_0_n <- applySymFn sym inv $ Empty :> zero :> n + -- 1 < n ==> inv(0, n) + impl0 <- impliesPred sym lt_1_n inv_0_n + inv_i_n <- applySymFn sym inv $ Empty :> i :> n + add_i_1 <- intAdd sym i one + lt_add_i_1_n <- intLt sym add_i_1 n + conj0 <- andPred sym inv_i_n lt_add_i_1_n + inv_add_i_1_n <- applySymFn sym inv $ Empty :> add_i_1 :> n + -- inv(i, n) /\ i+1 < n ==> inv(i+1, n) + impl1 <- impliesPred sym conj0 inv_add_i_1_n + le_0_i <- intLe sym zero i + lt_i_n <- intLt sym i n + conj1 <- andPred sym le_0_i lt_i_n + -- inv(i, n) ==> 0 <= i /\ i < n + impl2 <- impliesPred sym inv_i_n conj1 + + -- inv(i, n) /\ not (i + 1 < n) ==> i + 1 == n + not_lt_add_i_1_n <- notPred sym lt_add_i_1_n + conj2 <- andPred sym inv_i_n not_lt_add_i_1_n + eq_add_i_1_n <- intEq sym add_i_1 n + impl3 <- notPred sym =<< impliesPred sym conj2 eq_add_i_1_n + + return ([SomeSymFn inv], [impl0, impl1, impl2], impl3) + +bvProblem :: IsSymExprBuilder sym => sym -> IO ([SomeSymFn sym], [Pred sym], Pred sym) +bvProblem sym = do + inv <- freshTotalUninterpFn sym (safeSymbol "inv") knownRepr knownRepr + i <- freshConstant sym (safeSymbol "i") $ BaseBVRepr $ knownNat @64 + n <- freshConstant sym (safeSymbol "n") knownRepr + zero <- bvLit sym knownNat $ BV.zero knownNat + one <- bvLit sym knownNat $ BV.one knownNat + ult_1_n <- bvUlt sym one n + inv_0_n <- applySymFn sym inv $ Empty :> zero :> n + -- 1 < n ==> inv(0, n) + impl0 <- impliesPred sym ult_1_n inv_0_n + inv_i_n <- applySymFn sym inv $ Empty :> i :> n + add_i_1 <- bvAdd sym i one + ult_add_i_1_n <- bvUlt sym add_i_1 n + conj0 <- andPred sym inv_i_n ult_add_i_1_n + inv_add_i_1_n <- applySymFn sym inv $ Empty :> add_i_1 :> n + -- inv(i, n) /\ i+1 < n ==> inv(i+1, n) + impl1 <- impliesPred sym conj0 inv_add_i_1_n + ule_0_i <- bvUle sym zero i -- trivially true, here for similarity with int test + ult_i_n <- bvUlt sym i n + conj1 <- andPred sym ule_0_i ult_i_n + -- inv(i, n) ==> 0 <= i /\ i < n + impl2 <- impliesPred sym inv_i_n conj1 + + -- inv(i, n) /\ not (i + 1 < n) ==> i + 1 == n + not_ult_add_i_1_n <- notPred sym ult_add_i_1_n + conj2 <- andPred sym inv_i_n not_ult_add_i_1_n + eq_add_i_1_n <- bvEq sym add_i_1 n + impl3 <- notPred sym =<< impliesPred sym conj2 eq_add_i_1_n + + return ([SomeSymFn inv], [impl0, impl1, impl2], impl3) + +synthesis_test :: + String -> + (forall sym . IsSymExprBuilder sym => sym -> IO ([SomeSymFn sym], [Pred sym], Pred sym)) -> + String -> + (forall sym t fs . + sym ~ SimpleExprBuilder t fs => + sym -> + LogData -> + [SomeSymFn sym] -> + [BoolExpr t] -> + IO (SatResult (MapF (SymFnWrapper sym) (SymFnWrapper sym)) ())) -> + (forall t fs a . + SimpleExprBuilder t fs -> + LogData -> + [BoolExpr t] -> + (SatResult (GroundEvalFn t, Maybe (ExprRangeBindings t)) () -> IO a) -> + IO a) -> + TestTree +synthesis_test test_name synthesis_problem solver_name run_solver_synthesis run_solver_in_override = + testCase (test_name ++ " " ++ solver_name ++ " test") $ withSym FloatIEEERepr $ \sym -> do + (synth_fns, constraints, goal) <- synthesis_problem sym + + run_solver_in_override sym logData [goal] $ \res -> isSat res @? "sat" + + subst <- run_solver_synthesis sym logData synth_fns constraints >>= \case + Sat res -> return res + Unsat{} -> fail "Infeasible" + Unknown -> fail "Fail" + + goal' <- substituteSymFns sym subst goal + run_solver_in_override sym logData [goal'] $ \res -> isUnsat res @? "unsat" + +main :: IO () +main = do + testLevel <- TestLevel . fromMaybe "0" <$> lookupEnv "CI_TEST_LEVEL" + let solverNames = map SolverName [ "cvc5", "z3" ] + solvers <- reportSolverVersions testLevel id + =<< (zip solverNames <$> mapM getSolverVersion solverNames) + let skipPre4_8_9 why = + let shouldSkip = case lookup (SolverName "z3") solvers of + Just (SolverVersion v) -> any (`elem` [ "4.8.8" ]) $ words v + Nothing -> True + in if shouldSkip then expectFailBecause why else id + failureZ3 = "failure with older Z3 versions; upgrade to at least 4.8.9" + defaultMain $ testGroup "Tests" $ + [ synthesis_test "int" intProblem "cvc5" CVC5.runCVC5SyGuS CVC5.runCVC5InOverride + , skipPre4_8_9 failureZ3 $ synthesis_test "int" intProblem "z3" Z3.runZ3Horn Z3.runZ3InOverride + , synthesis_test "bv" bvProblem "cvc5" CVC5.runCVC5SyGuS CVC5.runCVC5InOverride + ] diff --git a/what4/what4.cabal b/what4/what4.cabal index ca17fb48..95d25952 100644 --- a/what4/what4.cabal +++ b/what4/what4.cabal @@ -447,3 +447,17 @@ test-suite what4-serialize-tests , async , directory , ordered-containers + +test-suite invariant-synthesis + import: bldflags, testdefs-hunit + type: exitcode-stdio-1.0 + + main-is: InvariantSynthesis.hs + + other-modules: ProbeSolvers + + build-depends: + bv-sized, + process, + tasty-expected-failure >= 0.12 && < 0.13 +