diff --git a/lib/core/cardano-wallet-core.cabal b/lib/core/cardano-wallet-core.cabal index eda8ba9888f..af506933dbf 100644 --- a/lib/core/cardano-wallet-core.cabal +++ b/lib/core/cardano-wallet-core.cabal @@ -355,7 +355,7 @@ test-suite unit , network , network-uri , nothunks - , persistent + , persistent >=2.13 && <2.14 , persistent-sqlite >=2.13 && <2.14 , plutus-ledger-api , pretty-simple diff --git a/lib/core/src/Cardano/DB/Sqlite.hs b/lib/core/src/Cardano/DB/Sqlite.hs index 273c445230c..7b3cc8255cd 100644 --- a/lib/core/src/Cardano/DB/Sqlite.hs +++ b/lib/core/src/Cardano/DB/Sqlite.hs @@ -42,7 +42,7 @@ module Cardano.DB.Sqlite -- * Manual Migration , ManualMigration (..) , MigrationError (..) - , DBField(..) + , DBField (..) , tableName , fieldName , fieldType diff --git a/lib/core/src/Cardano/Wallet/DB/Sqlite/Migration.hs b/lib/core/src/Cardano/Wallet/DB/Sqlite/Migration.hs index 2862ad574fb..d537e2fd7c4 100644 --- a/lib/core/src/Cardano/Wallet/DB/Sqlite/Migration.hs +++ b/lib/core/src/Cardano/Wallet/DB/Sqlite/Migration.hs @@ -1,4 +1,8 @@ +{-# LANGUAGE BlockArguments #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeApplications #-} @@ -13,6 +17,9 @@ module Cardano.Wallet.DB.Sqlite.Migration ( DefaultFieldValues (..) , migrateManually + , SchemaVersion (..) + , currentSchemaVersion + , InvalidDatabaseSchemaVersion (..) ) where @@ -52,6 +59,10 @@ import Database.Persist.Class ( toPersistValue ) import Database.Persist.Types ( PersistValue (..), fromPersistValueText ) +import Numeric.Natural + ( Natural ) +import UnliftIO.Exception + ( Exception, throwIO, throwString ) import qualified Cardano.Wallet.Primitive.AddressDerivation as W import qualified Cardano.Wallet.Primitive.AddressDiscovery.Sequential as Seq @@ -83,6 +94,23 @@ data SqlColumnStatus | ColumnPresent deriving Eq +data TableCreationResult + = TableCreated + | TableExisted + +newtype SchemaVersion = SchemaVersion Natural + deriving newtype (Eq, Ord, Read, Show ) + +data InvalidDatabaseSchemaVersion + = InvalidDatabaseSchemaVersion + { expectedVersion :: SchemaVersion + , actualVersion :: SchemaVersion + } + deriving (Show, Eq, Exception) + +currentSchemaVersion :: SchemaVersion +currentSchemaVersion = SchemaVersion 1 + -- | Executes any manual database migration steps that may be required on -- startup. migrateManually @@ -93,7 +121,8 @@ migrateManually -> [ManualMigration] migrateManually tr proxy defaultFieldValues = ManualMigration <$> - [ cleanupCheckpointTable + [ initializeSchemaVersionTable + , cleanupCheckpointTable , assignDefaultPassphraseScheme , addDesiredPoolNumberIfMissing , addMinimumUTxOValueIfMissing @@ -116,6 +145,51 @@ migrateManually tr proxy defaultFieldValues = , cleanupSeqStateTable ] where + initializeSchemaVersionTable :: Sqlite.Connection -> IO () + initializeSchemaVersionTable conn = + createSchemaVersionTableIfMissing conn >>= \case + TableCreated -> putSchemaVersion conn currentSchemaVersion + TableExisted -> do + schemaVersion <- getSchemaVersion conn + case compare schemaVersion currentSchemaVersion of + GT -> throwIO InvalidDatabaseSchemaVersion + { expectedVersion = currentSchemaVersion + , actualVersion = schemaVersion + } + LT -> putSchemaVersion conn currentSchemaVersion + EQ -> pure () + + createSchemaVersionTableIfMissing :: + Sqlite.Connection -> IO TableCreationResult + createSchemaVersionTableIfMissing conn = do + res <- runSql conn + "SELECT name FROM sqlite_master \ + \WHERE type='table' AND name='database_schema_version'" + case res of + [] -> TableCreated <$ runSql conn + "CREATE TABLE database_schema_version\ + \( name TEXT PRIMARY KEY \ + \, version INTEGER NOT NULL \ + \)" + _ -> pure TableExisted + + putSchemaVersion :: Sqlite.Connection -> SchemaVersion -> IO () + putSchemaVersion conn schemaVersion = void $ runSql conn $ T.unwords + [ "INSERT INTO database_schema_version (name, version)" + , "VALUES ('schema'," + , version + , ") ON CONFLICT (name) DO UPDATE SET version =" + , version + ] + where + version = T.pack $ show schemaVersion + + getSchemaVersion :: Sqlite.Connection -> IO SchemaVersion + getSchemaVersion conn = + runSql conn "SELECT version FROM database_schema_version" >>= \case + [[PersistInt64 i]] | i >= 0 -> pure $ SchemaVersion $ fromIntegral i + _ -> throwString "Database metadata table is corrupt" + -- NOTE -- We originally stored script pool gap inside sequential state in the 'SeqState' table, -- represented by 'seqStateScriptGap' field. We introduce separate shared wallet state @@ -144,14 +218,10 @@ migrateManually tr proxy defaultFieldValues = return () dropTable :: Text -> Text - dropTable table = mconcat - [ "DROP TABLE IF EXISTS " <> table <> ";" - ] + dropTable table = "DROP TABLE IF EXISTS " <> table <> ";" getTableInfo :: Text -> Text - getTableInfo table = mconcat - [ "PRAGMA table_info(", table, ");" - ] + getTableInfo table = "PRAGMA table_info(" <> table <> ");" filterColumn :: [Text] -> [PersistValue] -> Maybe [PersistValue] filterColumn excluding = \case diff --git a/lib/core/src/Cardano/Wallet/DB/Sqlite/Types.hs b/lib/core/src/Cardano/Wallet/DB/Sqlite/Types.hs index b1b5206042a..8eb7e750d4f 100644 --- a/lib/core/src/Cardano/Wallet/DB/Sqlite/Types.hs +++ b/lib/core/src/Cardano/Wallet/DB/Sqlite/Types.hs @@ -332,7 +332,7 @@ hashOfNoParent :: Hash "BlockHeader" hashOfNoParent = Hash . BS.pack $ replicate 32 0 fromMaybeHash :: Maybe (Hash "BlockHeader") -> BlockId -fromMaybeHash = BlockId . fromMaybe hashOfNoParent +fromMaybeHash = BlockId . fromMaybe hashOfNoParent toMaybeHash :: BlockId -> Maybe (Hash "BlockHeader") toMaybeHash (BlockId h) = if h == hashOfNoParent then Nothing else Just h @@ -850,7 +850,6 @@ instance PersistField POSIXTime where instance PersistFieldSql POSIXTime where sqlType _ = sqlType (Proxy @Text) - -- | Newtype to get a MonadFail instance for @Either Text@. -- -- We need it to use @parseTimeM@. diff --git a/lib/core/test/unit/Cardano/Wallet/DB/SqliteSpec.hs b/lib/core/test/unit/Cardano/Wallet/DB/SqliteSpec.hs index 1dead9dbea2..dc4c768055c 100644 --- a/lib/core/test/unit/Cardano/Wallet/DB/SqliteSpec.hs +++ b/lib/core/test/unit/Cardano/Wallet/DB/SqliteSpec.hs @@ -66,6 +66,11 @@ import Cardano.Wallet.DB.Sqlite , withDBLayer , withDBLayerInMemory ) +import Cardano.Wallet.DB.Sqlite.Migration + ( InvalidDatabaseSchemaVersion (..) + , SchemaVersion (..) + , currentSchemaVersion + ) import Cardano.Wallet.DB.StateMachine ( TestConstraints, prop_parallel, prop_sequential, validateGenerators ) import Cardano.Wallet.DummyTarget.Primitive.Types @@ -272,6 +277,8 @@ import qualified Data.List as L import qualified Data.Set as Set import qualified Data.Text as T import qualified Data.Text.Encoding as T +import qualified Database.Persist.Sql as Sql +import qualified Database.Persist.Sqlite as Sqlite import qualified UnliftIO.STM as STM spec :: Spec @@ -1021,6 +1028,12 @@ manualMigrationsSpec = describe "Manual migrations" $ do ) ] + it "'migrate' db to create metadata table when it doesn't exist" + testCreateMetadataTable + + it "'migrate' db never modifies database with newer version" + testNewerDatabaseIsNeverModified + testMigrationTxMetaFee :: forall k s. ( s ~ SeqState 'Mainnet k @@ -1244,6 +1257,40 @@ testMigrationPassphraseScheme = do Right walOldScheme = fromText "4a6279cd71d5993a288b2c5879daa7c42cebb73d" Right walNoPassphrase = fromText "ba74a7d2c1157ea7f32a93f255dac30e9ebca62b" +testCreateMetadataTable :: + forall s k. (k ~ ShelleyKey, s ~ SeqState 'Mainnet k) => IO () +testCreateMetadataTable = withSystemTempFile "db.sql" $ \path _ -> do + let noop _ = pure () + tr = nullTracer + withDBLayer @s @k tr defaultFieldValues path dummyTimeInterpreter noop + actualVersion <- Sqlite.runSqlite (T.pack path) $ do + [Sqlite.Single (version :: Int)] <- Sqlite.rawSql + "SELECT version FROM database_schema_version \ + \WHERE name = 'schema'" [] + pure $ SchemaVersion $ fromIntegral version + actualVersion `shouldBe` currentSchemaVersion + +testNewerDatabaseIsNeverModified :: + forall s k. (k ~ ShelleyKey, s ~ SeqState 'Mainnet k) => IO () +testNewerDatabaseIsNeverModified = withSystemTempFile "db.sql" $ \path _ -> do + let newerVersion = SchemaVersion 100 + currentVersion = SchemaVersion 1 + _ <- Sqlite.runSqlite (T.pack path) $ do + Sqlite.rawExecute + "CREATE TABLE database_schema_version (name, version)" [] + Sqlite.rawExecute ( + "INSERT INTO database_schema_version \ + \VALUES ('schema', " <> T.pack (show newerVersion) <> ")" + ) [] + let noop _ = pure () + tr = nullTracer + withDBLayer @s @k tr defaultFieldValues path dummyTimeInterpreter noop + `shouldThrow` \case + InvalidDatabaseSchemaVersion {..} + | expectedVersion == currentVersion + && actualVersion == newerVersion -> True + _ -> False + {------------------------------------------------------------------------------- Test data -------------------------------------------------------------------------------}