diff --git a/config/config.go b/config/config.go index fa12d4ac..38563439 100644 --- a/config/config.go +++ b/config/config.go @@ -11,9 +11,9 @@ import ( "github.com/ethereum/go-ethereum/crypto" ) -// // ******* Parsed config ******* // +// type Config struct { DeveloperMode bool @@ -36,7 +36,7 @@ type Config struct { RPCServer RPCServerConfig PartnerPlugin PartnerPluginConfig Tracing TracingConfig - DB DBConfig + DB SQLiteDBConfig Matrix MatrixConfig } @@ -60,15 +60,15 @@ type MatrixConfig struct { Store string } -// +type SQLiteDBConfig struct { + Common UnparsedSQLiteDBConfig + Scheduler UnparsedSQLiteDBConfig + ChequeHandler UnparsedSQLiteDBConfig +} + // ******* Common ******* // - -type DBConfig struct { - DBPath string `mapstructure:"path"` - DBName string `mapstructure:"name"` - MigrationsPath string `mapstructure:"migrations_path"` -} +// type RPCServerConfig struct { Enabled bool `mapstructure:"enabled"` @@ -78,9 +78,9 @@ type RPCServerConfig struct { ServerKeyFile string `mapstructure:"key_file"` } -// // ******* Unparsed config ******* // +// type UnparsedConfig struct { DeveloperMode bool `mapstructure:"developer_mode"` @@ -104,8 +104,8 @@ type UnparsedConfig struct { Tracing UnparsedTracingConfig `mapstructure:"tracing"` Matrix UnparsedMatrixConfig `mapstructure:"matrix"` - RPCServer RPCServerConfig `mapstructure:"rpc_server"` - DB DBConfig `mapstructure:"db"` + RPCServer RPCServerConfig `mapstructure:"rpc_server"` + DB UnparsedSQLiteDBConfig `mapstructure:"db"` } type UnparsedTracingConfig struct { @@ -128,9 +128,14 @@ type UnparsedMatrixConfig struct { Store string `mapstructure:"store"` } +type UnparsedSQLiteDBConfig struct { + DBPath string `mapstructure:"path"` + MigrationsPath string `mapstructure:"migrations_path"` +} + func (cfg *Config) unparse() *UnparsedConfig { return &UnparsedConfig{ - DB: cfg.DB, + DB: cfg.DB.Common, RPCServer: cfg.RPCServer, Tracing: UnparsedTracingConfig{ Enabled: cfg.Tracing.Enabled, diff --git a/config/config_reader.go b/config/config_reader.go index 20bd1dee..9ebc93e9 100644 --- a/config/config_reader.go +++ b/config/config_reader.go @@ -21,7 +21,7 @@ const envPrefix = "CMB" var ( _ Reader = (*reader)(nil) - errInvalidConfig = errors.New("invalid config") + errInvalidRawConfig = errors.New("invalid raw config") ) type Reader interface { @@ -82,7 +82,7 @@ func (cr *reader) ReadConfig() (*Config, error) { parsedCfg, err := cr.parseConfig(cfg) if err != nil { - return nil, fmt.Errorf("%w: %w", errInvalidConfig, err) + return nil, fmt.Errorf("%w: %w", errInvalidRawConfig, err) } return parsedCfg, nil @@ -120,7 +120,17 @@ func (cr *reader) parseConfig(cfg *UnparsedConfig) (*Config, error) { } return &Config{ - DB: cfg.DB, + DB: SQLiteDBConfig{ + Common: cfg.DB, + Scheduler: UnparsedSQLiteDBConfig{ + DBPath: cfg.DB.DBPath + "/scheduler", + MigrationsPath: cfg.DB.MigrationsPath + "/scheduler", + }, + ChequeHandler: UnparsedSQLiteDBConfig{ + DBPath: cfg.DB.DBPath + "/cheque_handler", + MigrationsPath: cfg.DB.MigrationsPath + "/cheque_handler", + }, + }, RPCServer: cfg.RPCServer, Tracing: TracingConfig{ Enabled: cfg.Tracing.Enabled, diff --git a/config/config_reader_test.go b/config/config_reader_test.go index bfeecdd2..d29b6025 100644 --- a/config/config_reader_test.go +++ b/config/config_reader_test.go @@ -35,7 +35,7 @@ func TestReadConfig(t *testing.T) { cr.viper.Set(flagKeyConfig, nonExistingConfigPath) }, flags: Flags(), - expectedErr: errInvalidConfig, // empty bot key + expectedErr: errInvalidRawConfig, // empty bot key }, "from file": { prepare: func(_ *testing.T, cr *reader) { diff --git a/config/flags.go b/config/flags.go index d0056445..5c79cbc8 100644 --- a/config/flags.go +++ b/config/flags.go @@ -29,8 +29,7 @@ func Flags() *pflag.FlagSet { flags.Int64("response_timeout", 3000, "The messenger timeout (in milliseconds).") // DB config flags - flags.String("db.name", "camino_messenger_bot", "Database name.") - flags.String("db.path", "cmb.db", "Path to database.") + flags.String("db.path", "cmb-db", "Path to database dir.") flags.String("db.migrations_path", "file://./migrations", "Path to migration scripts.") // Tracing config flags diff --git a/config/test_config.yaml b/config/test_config.yaml index 7430a124..20ed8948 100644 --- a/config/test_config.yaml +++ b/config/test_config.yaml @@ -6,8 +6,7 @@ cheque_expiration_time: 18144000 cm_account_address: 0xe55E387F5474a012D1b048155E25ea78C7DBfBBC db: migrations_path: file://./migrations - name: camino_messenger_bot - path: supplier-bot.db + path: supplier-bot-db developer_mode: true matrix: host: messenger.chain4travel.com diff --git a/examples/config/camino-messenger-bot-distributor-camino.yaml b/examples/config/camino-messenger-bot-distributor-camino.yaml index 49dc1ba0..72b8e0fd 100644 --- a/examples/config/camino-messenger-bot-distributor-camino.yaml +++ b/examples/config/camino-messenger-bot-distributor-camino.yaml @@ -51,11 +51,8 @@ db: # Path to migrations dir with sql up/down scripts. Schema is mandatory. migrations_path: file://./migrations - # Database name. - name: camino_messenger_bot - - # Path to database file. - path: distributor-bot.db + # Path to database dir. + path: distributor-bot-db diff --git a/examples/config/camino-messenger-bot-distributor-columbus.yaml b/examples/config/camino-messenger-bot-distributor-columbus.yaml index 6818c812..5332a7ae 100644 --- a/examples/config/camino-messenger-bot-distributor-columbus.yaml +++ b/examples/config/camino-messenger-bot-distributor-columbus.yaml @@ -51,11 +51,8 @@ db: # Path to migrations dir with sql up/down scripts. Schema is mandatory. migrations_path: file://./migrations - # Database name. - name: camino_messenger_bot - - # Path to database file. - path: distributor-bot.db + # Path to database dir. + path: distributor-bot-db diff --git a/examples/config/camino-messenger-bot-supplier-camino.yaml b/examples/config/camino-messenger-bot-supplier-camino.yaml index e19bc63d..ddea1dc8 100644 --- a/examples/config/camino-messenger-bot-supplier-camino.yaml +++ b/examples/config/camino-messenger-bot-supplier-camino.yaml @@ -51,11 +51,8 @@ db: # Path to migrations dir with sql up/down scripts. Schema is mandatory. migrations_path: file://./migrations - # Database name. - name: camino_messenger_bot - - # Path to database file. - path: supplier-bot.db + # Path to database dir. + path: supplier-bot-db diff --git a/examples/config/camino-messenger-bot-supplier-columbus.yaml b/examples/config/camino-messenger-bot-supplier-columbus.yaml index 8b29eb90..3d63a7db 100644 --- a/examples/config/camino-messenger-bot-supplier-columbus.yaml +++ b/examples/config/camino-messenger-bot-supplier-columbus.yaml @@ -51,13 +51,8 @@ db: # Path to migrations dir with sql up/down scripts. Schema is mandatory. migrations_path: file://./migrations - # Database name. - name: camino_messenger_bot - - # Path to database file. - path: supplier-bot.db - - + # Path to database dir. + path: supplier-bot-db ### Matrix matrix: diff --git a/go.mod b/go.mod index c659d420..6d2b5221 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jmoiron/sqlx v1.4.0 + github.com/jonboulle/clockwork v0.4.0 github.com/klauspost/compress v1.17.10 github.com/mattn/go-sqlite3 v1.14.23 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index 830c68a4..f4e11c93 100644 --- a/go.sum +++ b/go.sum @@ -160,6 +160,8 @@ github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7Bd github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= +github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= +github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/klauspost/compress v1.17.10 h1:oXAz+Vh0PMUvJczoi+flxpnBEPxoER1IaAnU/NMPtT0= github.com/klauspost/compress v1.17.10/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/app/app.go b/internal/app/app.go index 5dd61ab6..df0f7566 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -3,10 +3,12 @@ package app import ( "context" "fmt" + "time" "github.com/chain4travel/camino-messenger-bot/config" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethclient" + "github.com/jonboulle/clockwork" "maunium.net/go/mautrix/id" "github.com/chain4travel/camino-messenger-bot/internal/compression" @@ -14,16 +16,22 @@ import ( "github.com/chain4travel/camino-messenger-bot/internal/messaging" "github.com/chain4travel/camino-messenger-bot/internal/rpc/client" "github.com/chain4travel/camino-messenger-bot/internal/rpc/server" - "github.com/chain4travel/camino-messenger-bot/internal/scheduler" - "github.com/chain4travel/camino-messenger-bot/internal/storage" "github.com/chain4travel/camino-messenger-bot/internal/tracing" + "github.com/chain4travel/camino-messenger-bot/pkg/chequehandler" + chequeHandlerStorage "github.com/chain4travel/camino-messenger-bot/pkg/chequehandler/storage/sqlite" + cmaccounts "github.com/chain4travel/camino-messenger-bot/pkg/cm_accounts" + "github.com/chain4travel/camino-messenger-bot/pkg/database/sqlite" + "github.com/chain4travel/camino-messenger-bot/pkg/scheduler" + scheduler_storage "github.com/chain4travel/camino-messenger-bot/pkg/scheduler/storage/sqlite" "go.uber.org/zap" "golang.org/x/sync/errgroup" ) const ( - cashInJobName = "cash_in" - appName = "camino-messenger-bot" + cashInJobName = "cash_in" + appName = "camino-messenger-bot" + cmAccountsCacheSize = 100 + cashInTxIssueTimeout = 10 * time.Second ) func NewApp(ctx context.Context, cfg *config.Config, logger *zap.SugaredLogger) (*App, error) { @@ -55,13 +63,6 @@ func NewApp(ctx context.Context, cfg *config.Config, logger *zap.SugaredLogger) return nil, err } - // database - storage, err := storage.New(ctx, logger, cfg.DB) - if err != nil { - logger.Errorf("Failed to create storage: %v", err) - return nil, err - } - // partner-plugin rpc client rpcClient, err := client.NewClient(cfg.PartnerPlugin, logger) if err != nil { @@ -95,27 +96,37 @@ func NewApp(ctx context.Context, cfg *config.Config, logger *zap.SugaredLogger) return nil, err } - identificationHandler, err := messaging.NewIdentificationHandler( + cmAccounts, err := cmaccounts.NewService( + logger, + cmAccountsCacheSize, evmClient, + ) + if err != nil { + logger.Errorf("Failed to create cm accounts service: %v", err) + return nil, err + } + + chequeHandlerStorage, err := chequeHandlerStorage.New( + ctx, logger, - cfg.CMAccountAddress, - cfg.Matrix.HostURL, + sqlite.DBConfig(cfg.DB.ChequeHandler), ) if err != nil { - logger.Errorf("Failed to create identification handler: %v", err) + logger.Errorf("Failed to create cheque handler storage: %v", err) return nil, err } - chequeHandler, err := messaging.NewChequeHandler( + chequeHandler, err := chequehandler.NewChequeHandler( logger, evmClient, cfg.BotKey, cfg.CMAccountAddress, chainID, - storage, - serviceRegistry, + chequeHandlerStorage, + cmAccounts, cfg.MinChequeDurationUntilExpiration, cfg.ChequeExpirationTime, + cashInTxIssueTimeout, ) if err != nil { logger.Errorf("Failed to create cheque handler: %v", err) @@ -137,13 +148,12 @@ func NewApp(ctx context.Context, cfg *config.Config, logger *zap.SugaredLogger) cfg.NetworkFeeRecipientCMAccountAddress, serviceRegistry, responseHandler, - identificationHandler, chequeHandler, messaging.NewCompressor(compression.MaxChunkSize), + cmAccounts, ) // rpc server for incoming requests - // TODO@ disable if we don't have port provided, e.g. its supplier bot? rpcServer, err := server.NewServer( cfg.RPCServer, logger, @@ -157,7 +167,14 @@ func NewApp(ctx context.Context, cfg *config.Config, logger *zap.SugaredLogger) } // scheduler for periodic tasks (e.g. cheques cash-in) - scheduler := scheduler.New(ctx, logger, storage) + + storage, err := scheduler_storage.New(ctx, logger, sqlite.DBConfig(cfg.DB.Scheduler)) + if err != nil { + logger.Errorf("Failed to create storage: %v", err) + return nil, err + } + + scheduler := scheduler.New(logger, storage, clockwork.NewRealClock()) scheduler.RegisterJobHandler(cashInJobName, func() { _ = chequeHandler.CashIn(context.Background()) }) @@ -181,7 +198,7 @@ type App struct { logger *zap.SugaredLogger tracer tracing.Tracer scheduler scheduler.Scheduler - chequeHandler messaging.ChequeHandler + chequeHandler chequehandler.ChequeHandler rpcClient *client.RPCClient rpcServer server.Server messageProcessor messaging.Processor diff --git a/internal/messaging/identification.go b/internal/messaging/identification.go deleted file mode 100644 index 22993da2..00000000 --- a/internal/messaging/identification.go +++ /dev/null @@ -1,112 +0,0 @@ -package messaging - -import ( - "context" - "math/big" - "net/url" - "strings" - - "github.com/chain4travel/camino-messenger-contracts/go/contracts/cmaccount" - "github.com/ethereum/go-ethereum/accounts/abi/bind" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethclient" - lru "github.com/hashicorp/golang-lru/v2" - "go.uber.org/zap" - "maunium.net/go/mautrix/id" -) - -const cmAccountsCacheSize = 100 - -var _ IdentificationHandler = (*evmIdentificationHandler)(nil) - -var roleHash = crypto.Keccak256Hash([]byte("CHEQUE_OPERATOR_ROLE")) - -type evmIdentificationHandler struct { - ethClient *ethclient.Client - matrixHost string - myCMAccountAddress common.Address - cmAccounts *lru.Cache[common.Address, *cmaccount.Cmaccount] - logger *zap.SugaredLogger -} - -type IdentificationHandler interface { - getFirstBotUserIDFromCMAccountAddress(common.Address) (id.UserID, error) -} - -func NewIdentificationHandler( - ethClient *ethclient.Client, - logger *zap.SugaredLogger, - cmAccountAddress common.Address, - matrixHost url.URL, -) (IdentificationHandler, error) { - cmAccountsCache, err := lru.New[common.Address, *cmaccount.Cmaccount](cmAccountsCacheSize) - if err != nil { - return nil, err - } - - return &evmIdentificationHandler{ - ethClient: ethClient, - matrixHost: matrixHost.String(), - myCMAccountAddress: cmAccountAddress, - cmAccounts: cmAccountsCache, - logger: logger, - }, nil -} - -func (ih *evmIdentificationHandler) getFirstBotUserIDFromCMAccountAddress(cmAccountAddress common.Address) (id.UserID, error) { - botAddress, err := ih.getFirstBotFromCMAccountAddress(cmAccountAddress) - if err != nil { - return "", err - } - - return UserIDFromAddress(botAddress, ih.matrixHost), nil -} - -func (ih *evmIdentificationHandler) getFirstBotFromCMAccountAddress(cmAccountAddress common.Address) (common.Address, error) { - bots, err := ih.getAllBotAddressesFromCMAccountAddress(cmAccountAddress) - if err != nil { - return common.Address{}, err - } - return bots[0], nil -} - -func (ih *evmIdentificationHandler) getAllBotAddressesFromCMAccountAddress(cmAccountAddress common.Address) ([]common.Address, error) { - cmAccount, ok := ih.cmAccounts.Get(cmAccountAddress) - if !ok { - var err error - cmAccount, err = cmaccount.NewCmaccount(cmAccountAddress, ih.ethClient) - if err != nil { - ih.logger.Errorf("Failed to get cm Account: %v", err) - return nil, err - } - ih.cmAccounts.Add(cmAccountAddress, cmAccount) - } - - countBig, err := cmAccount.GetRoleMemberCount(&bind.CallOpts{Context: context.TODO()}, roleHash) - if err != nil { - ih.logger.Errorf("Failed to call contract function: %v", err) - return nil, err - } - - count := countBig.Int64() - botsAddresses := make([]common.Address, 0, count) - for i := int64(0); i < count; i++ { - address, err := cmAccount.GetRoleMember(&bind.CallOpts{Context: context.TODO()}, roleHash, big.NewInt(i)) - if err != nil { - ih.logger.Errorf("Failed to call contract function: %v", err) - continue - } - botsAddresses = append(botsAddresses, address) - } - - return botsAddresses, nil -} - -func UserIDFromAddress(address common.Address, host string) id.UserID { - return id.NewUserID(strings.ToLower(address.Hex()), host) -} - -func addressFromUserID(userID id.UserID) common.Address { - return common.HexToAddress(userID.Localpart()) -} diff --git a/internal/messaging/noop_cheque_handler.go b/internal/messaging/noop_cheque_handler.go deleted file mode 100644 index ac1c3d5b..00000000 --- a/internal/messaging/noop_cheque_handler.go +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (C) 2024, Chain4Travel AG. All rights reserved. - * See the file LICENSE for licensing terms. - */ - -package messaging - -import ( - "context" - "math/big" - - "github.com/chain4travel/camino-messenger-bot/pkg/cheques" - "github.com/ethereum/go-ethereum/common" -) - -var _ ChequeHandler = (*NoopChequeHandler)(nil) - -type NoopChequeHandler struct{} - -func (NoopChequeHandler) IssueCheque(_ context.Context, _ common.Address, _ common.Address, _ common.Address, _ *big.Int) (*cheques.SignedCheque, error) { - return &cheques.SignedCheque{}, nil -} - -func (NoopChequeHandler) GetServiceFee(_ context.Context, _ common.Address, _ string) (*big.Int, error) { - return nil, nil -} - -func (NoopChequeHandler) IsBotAllowed(_ context.Context, _ common.Address) (bool, error) { - return true, nil -} - -func (NoopChequeHandler) IsEmptyCheque(_ *cheques.SignedCheque) bool { - return false -} - -func (NoopChequeHandler) CashIn(_ context.Context) error { - return nil -} - -func (NoopChequeHandler) CheckCashInStatus(_ context.Context) error { - return nil -} - -func (NoopChequeHandler) VerifyCheque(_ context.Context, _ *cheques.SignedCheque, _ common.Address, _ *big.Int) error { - return nil -} diff --git a/internal/messaging/noop_indentification.go b/internal/messaging/noop_indentification.go deleted file mode 100644 index e45d3783..00000000 --- a/internal/messaging/noop_indentification.go +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (C) 2024, Chain4Travel AG. All rights reserved. - * See the file LICENSE for licensing terms. - */ - -package messaging - -import ( - "github.com/ethereum/go-ethereum/common" - "maunium.net/go/mautrix/id" -) - -var _ IdentificationHandler = (*NoopIdentification)(nil) - -type NoopIdentification struct{} - -func (NoopIdentification) getFirstBotUserIDFromCMAccountAddress(_ common.Address) (id.UserID, error) { - return "", nil -} diff --git a/internal/messaging/processor.go b/internal/messaging/processor.go index f0ed00d8..59603258 100644 --- a/internal/messaging/processor.go +++ b/internal/messaging/processor.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/big" + "strings" "sync" "time" @@ -12,7 +13,9 @@ import ( "github.com/chain4travel/camino-messenger-bot/internal/messaging/types" "github.com/chain4travel/camino-messenger-bot/internal/metadata" "github.com/chain4travel/camino-messenger-bot/internal/rpc" + "github.com/chain4travel/camino-messenger-bot/pkg/chequehandler" "github.com/chain4travel/camino-messenger-bot/pkg/cheques" + cmaccounts "github.com/chain4travel/camino-messenger-bot/pkg/cm_accounts" "github.com/ethereum/go-ethereum/common" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -23,8 +26,6 @@ import ( "maunium.net/go/mautrix/id" ) -const cashInTxIssueTimeout = 10 * time.Second - var ( _ Processor = (*processor)(nil) @@ -65,9 +66,9 @@ func NewProcessor( networkFeeRecipientCMAccountAddress common.Address, registry ServiceRegistry, responseHandler ResponseHandler, - identificationHandler IdentificationHandler, - chequeHandler ChequeHandler, + chequeHandler chequehandler.ChequeHandler, compressor compression.Compressor[*types.Message, [][]byte], + cmAccounts cmaccounts.Service, ) Processor { return &processor{ messenger: messenger, @@ -77,9 +78,10 @@ func NewProcessor( responseChannels: make(map[string]chan *types.Message), serviceRegistry: registry, responseHandler: responseHandler, - identificationHandler: identificationHandler, chequeHandler: chequeHandler, compressor: compressor, + cmAccounts: cmAccounts, + matrixHost: botUserID.Homeserver(), myBotAddress: addressFromUserID(botUserID), botUserID: botUserID, cmAccountAddress: cmAccountAddress, @@ -93,19 +95,20 @@ type processor struct { logger *zap.SugaredLogger tracer trace.Tracer responseTimeout time.Duration // timeout after which a request is considered failed + matrixHost string botUserID id.UserID myBotAddress common.Address cmAccountAddress common.Address networkFeeRecipientBotAddress common.Address networkFeeRecipientCMAccountAddress common.Address - mu sync.Mutex - responseChannels map[string]chan *types.Message - serviceRegistry ServiceRegistry - responseHandler ResponseHandler - identificationHandler IdentificationHandler - chequeHandler ChequeHandler - compressor compression.Compressor[*types.Message, [][]byte] + mu sync.Mutex + responseChannels map[string]chan *types.Message + serviceRegistry ServiceRegistry + responseHandler ResponseHandler + chequeHandler chequehandler.ChequeHandler + compressor compression.Compressor[*types.Message, [][]byte] + cmAccounts cmaccounts.Service } func (*processor) Checkpoint() string { @@ -177,14 +180,14 @@ func (p *processor) Request(ctx context.Context, msg *types.Message) (*types.Mes p.logger.Infof("Distributor: received a request to propagate to CMAccount %s", msg.Metadata.Recipient) // lookup for CM Account -> bot recipientCMAccAddr := common.HexToAddress(msg.Metadata.Recipient) - recipientBotUserID, err := p.identificationHandler.getFirstBotUserIDFromCMAccountAddress(recipientCMAccAddr) + recipientBotAddr, err := p.getFirstBotFromCMAccount(ctx, recipientCMAccAddr) if err != nil { return nil, err } msg.Metadata.Cheques = []cheques.SignedCheque{} - isBotAllowed, err := p.chequeHandler.IsBotAllowed(ctx, p.myBotAddress) + isBotAllowed, err := p.cmAccounts.IsBotAllowed(ctx, p.cmAccountAddress, p.myBotAddress) if err != nil { return nil, err } @@ -192,7 +195,7 @@ func (p *processor) Request(ctx context.Context, msg *types.Message) (*types.Mes return nil, ErrBotMissingChequeOperatorRole } - serviceFee, err := p.chequeHandler.GetServiceFee(ctx, recipientCMAccAddr, msg.Type.ToServiceName()) + serviceFee, err := p.cmAccounts.GetServiceFee(ctx, recipientCMAccAddr, msg.Type.ToServiceName()) if err != nil { // TODO @evlekht explicitly say if service is not supported and its not just some network error return nil, err @@ -229,7 +232,7 @@ func (p *processor) Request(ctx context.Context, msg *types.Message) (*types.Mes ctx, p.cmAccountAddress, recipientCMAccAddr, - addressFromUserID(recipientBotUserID), + recipientBotAddr, serviceFee, ) if err != nil { @@ -243,11 +246,17 @@ func (p *processor) Request(ctx context.Context, msg *types.Message) (*types.Mes ctx, span := p.tracer.Start(ctx, "processor.Request", trace.WithAttributes(attribute.String("type", string(msg.Type)))) defer span.End() - p.logger.Infof("Distributor: Bot %s is contacting bot %s of the CMaccount %s", msg.Sender, recipientBotUserID, msg.Metadata.Recipient) + p.logger.Infof("Distributor: Bot %s is contacting bot %s of the CMaccount %s", msg.Sender, recipientBotAddr, msg.Metadata.Recipient) - if err := p.messenger.SendAsync(ctx, *msg, compressedContent, recipientBotUserID); err != nil { + if err := p.messenger.SendAsync( + ctx, + *msg, + compressedContent, + UserIDFromAddress(recipientBotAddr, p.matrixHost), + ); err != nil { return nil, err } + ctx, responseSpan := p.tracer.Start(ctx, "processor.AwaitResponse", trace.WithSpanKind(trace.SpanKindConsumer), trace.WithAttributes(attribute.String("type", string(msg.Type)))) defer responseSpan.End() for { @@ -283,7 +292,7 @@ func (p *processor) Respond(msg *types.Message) error { return ErrMissingCheques } - serviceFee, err := p.chequeHandler.GetServiceFee(ctx, common.HexToAddress(msg.Metadata.Recipient), service.Name()) + serviceFee, err := p.cmAccounts.GetServiceFee(ctx, common.HexToAddress(msg.Metadata.Recipient), service.Name()) if err != nil { return err } @@ -378,3 +387,20 @@ func (p *processor) compress(ctx context.Context, msg *types.Message) (context.C } return ctx, compressedContent, nil } + +func (p *processor) getFirstBotFromCMAccount(ctx context.Context, cmAccountAddress common.Address) (common.Address, error) { + bots, err := p.cmAccounts.GetChequeOperators(ctx, cmAccountAddress) + if err != nil { + p.logger.Errorf("failed to get bots from CMAccount: %v", err) + return common.Address{}, err + } + return bots[0], nil +} + +func UserIDFromAddress(address common.Address, host string) id.UserID { + return id.NewUserID(strings.ToLower(address.Hex()), host) +} + +func addressFromUserID(userID id.UserID) common.Address { + return common.HexToAddress(userID.Localpart()) +} diff --git a/internal/messaging/processor_test.go b/internal/messaging/processor_test.go index 4f6657b1..2d0a2dd7 100644 --- a/internal/messaging/processor_test.go +++ b/internal/messaging/processor_test.go @@ -23,7 +23,9 @@ import ( "github.com/chain4travel/camino-messenger-bot/internal/metadata" "github.com/chain4travel/camino-messenger-bot/internal/rpc" "github.com/chain4travel/camino-messenger-bot/internal/rpc/generated" + "github.com/chain4travel/camino-messenger-bot/pkg/chequehandler" "github.com/chain4travel/camino-messenger-bot/pkg/cheques" + cmaccounts "github.com/chain4travel/camino-messenger-bot/pkg/cm_accounts" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/require" @@ -62,18 +64,19 @@ func TestProcessInbound(t *testing.T) { } mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() mockServiceRegistry := NewMockServiceRegistry(mockCtrl) mockService := rpc.NewMockService(mockCtrl) mockMessenger := NewMockMessenger(mockCtrl) + mockCMAccounts := cmaccounts.NewMockService(mockCtrl) + mockChequeHandler := chequehandler.NewMockChequeHandler(mockCtrl) type fields struct { - messenger Messenger - serviceRegistry ServiceRegistry - responseHandler ResponseHandler - identificationHandler IdentificationHandler - chequeHandler ChequeHandler - compressor compression.Compressor[*types.Message, [][]byte] + messenger Messenger + serviceRegistry ServiceRegistry + responseHandler ResponseHandler + chequeHandler chequehandler.ChequeHandler + compressor compression.Compressor[*types.Message, [][]byte] + cmAccounts cmaccounts.Service } type args struct { msg *types.Message @@ -119,18 +122,20 @@ func TestProcessInbound(t *testing.T) { }, "err: process request message failed": { fields: fields{ - serviceRegistry: mockServiceRegistry, - responseHandler: NoopResponseHandler{}, - identificationHandler: NoopIdentification{}, - chequeHandler: NoopChequeHandler{}, - messenger: mockMessenger, - compressor: &noopCompressor{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + chequeHandler: mockChequeHandler, + messenger: mockMessenger, + compressor: &noopCompressor{}, + cmAccounts: mockCMAccounts, }, prepare: func(*processor) { mockService.EXPECT().Name().Return("dummy") - mockService.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil, generated.PingServiceV1Response, nil) - mockServiceRegistry.EXPECT().GetService(gomock.Any()).Times(1).Return(mockService, true) - mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(errSomeError) + mockService.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, generated.PingServiceV1Response, nil) + mockServiceRegistry.EXPECT().GetService(gomock.Any()).Return(mockService, true) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(errSomeError) + mockChequeHandler.EXPECT().VerifyCheque(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockCMAccounts.EXPECT().GetServiceFee(gomock.Any(), gomock.Any(), gomock.Any()).Return(big.NewInt(1), nil) }, args: args{ msg: &types.Message{ @@ -145,18 +150,20 @@ func TestProcessInbound(t *testing.T) { }, "success: process request message": { fields: fields{ - serviceRegistry: mockServiceRegistry, - responseHandler: NoopResponseHandler{}, - identificationHandler: NoopIdentification{}, - chequeHandler: NoopChequeHandler{}, - messenger: mockMessenger, - compressor: &noopCompressor{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + chequeHandler: mockChequeHandler, + messenger: mockMessenger, + compressor: &noopCompressor{}, + cmAccounts: mockCMAccounts, }, prepare: func(*processor) { mockService.EXPECT().Name().Return("dummy") - mockService.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil, generated.PingServiceV1Response, nil) - mockServiceRegistry.EXPECT().GetService(gomock.Any()).Times(1).Return(mockService, true) - mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil) + mockService.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, generated.PingServiceV1Response, nil) + mockServiceRegistry.EXPECT().GetService(gomock.Any()).Return(mockService, true) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockChequeHandler.EXPECT().VerifyCheque(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockCMAccounts.EXPECT().GetServiceFee(gomock.Any(), gomock.Any(), gomock.Any()).Return(big.NewInt(1), nil) }, args: args{ msg: &types.Message{ @@ -170,11 +177,12 @@ func TestProcessInbound(t *testing.T) { }, "success: process response message": { fields: fields{ - serviceRegistry: mockServiceRegistry, - responseHandler: NoopResponseHandler{}, - identificationHandler: NoopIdentification{}, - messenger: mockMessenger, - compressor: &noopCompressor{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + chequeHandler: mockChequeHandler, + messenger: mockMessenger, + compressor: &noopCompressor{}, + cmAccounts: mockCMAccounts, }, prepare: func(p *processor) { p.responseChannels[requestID] = make(chan *types.Message, 1) @@ -200,9 +208,9 @@ func TestProcessInbound(t *testing.T) { common.Address{}, tt.fields.serviceRegistry, tt.fields.responseHandler, - tt.fields.identificationHandler, tt.fields.chequeHandler, tt.fields.compressor, + tt.fields.cmAccounts, ) if tt.prepare != nil { tt.prepare(p.(*processor)) @@ -224,18 +232,19 @@ func TestProcessOutbound(t *testing.T) { } mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() mockServiceRegistry := NewMockServiceRegistry(mockCtrl) mockMessenger := NewMockMessenger(mockCtrl) + mockCMAccounts := cmaccounts.NewMockService(mockCtrl) + mockChequeHandler := chequehandler.NewMockChequeHandler(mockCtrl) type fields struct { - responseTimeout time.Duration - messenger Messenger - serviceRegistry ServiceRegistry - responseHandler ResponseHandler - identificationHandler IdentificationHandler - chequeHandler ChequeHandler - compressor compression.Compressor[*types.Message, [][]byte] + responseTimeout time.Duration + messenger Messenger + serviceRegistry ServiceRegistry + responseHandler ResponseHandler + chequeHandler chequehandler.ChequeHandler + compressor compression.Compressor[*types.Message, [][]byte] + cmAccounts cmaccounts.Service } type args struct { msg *types.Message @@ -250,12 +259,12 @@ func TestProcessOutbound(t *testing.T) { }{ "err: non-request outbound message": { fields: fields{ - serviceRegistry: mockServiceRegistry, - responseHandler: NoopResponseHandler{}, - identificationHandler: NoopIdentification{}, - chequeHandler: NoopChequeHandler{}, - messenger: mockMessenger, - compressor: &noopCompressor{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + chequeHandler: mockChequeHandler, + messenger: mockMessenger, + compressor: &noopCompressor{}, + cmAccounts: mockCMAccounts, }, args: args{ msg: &types.Message{Type: generated.PingServiceV1Response}, @@ -264,12 +273,12 @@ func TestProcessOutbound(t *testing.T) { }, "err: missing recipient": { fields: fields{ - serviceRegistry: mockServiceRegistry, - responseHandler: NoopResponseHandler{}, - identificationHandler: NoopIdentification{}, - chequeHandler: NoopChequeHandler{}, - messenger: mockMessenger, - compressor: &noopCompressor{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + chequeHandler: mockChequeHandler, + messenger: mockMessenger, + compressor: &noopCompressor{}, + cmAccounts: mockCMAccounts, }, args: args{ msg: &types.Message{Type: generated.PingServiceV1Request}, @@ -278,13 +287,13 @@ func TestProcessOutbound(t *testing.T) { }, "err: awaiting-response-timeout exceeded": { fields: fields{ - responseTimeout: 10 * time.Millisecond, // 10ms - serviceRegistry: mockServiceRegistry, - responseHandler: NoopResponseHandler{}, - identificationHandler: NoopIdentification{}, - chequeHandler: NoopChequeHandler{}, - messenger: mockMessenger, - compressor: &noopCompressor{}, + responseTimeout: 10 * time.Millisecond, // 10ms + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + chequeHandler: mockChequeHandler, + messenger: mockMessenger, + compressor: &noopCompressor{}, + cmAccounts: mockCMAccounts, }, args: args{ msg: &types.Message{ @@ -293,19 +302,23 @@ func TestProcessOutbound(t *testing.T) { }, }, prepare: func() { - mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil) + mockCMAccounts.EXPECT().GetChequeOperators(gomock.Any(), gomock.Any()).Return([]common.Address{{}}, nil) + mockCMAccounts.EXPECT().GetServiceFee(gomock.Any(), gomock.Any(), gomock.Any()).Return(big.NewInt(1), nil) + mockCMAccounts.EXPECT().IsBotAllowed(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + mockChequeHandler.EXPECT().IssueCheque(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(&cheques.SignedCheque{}, nil) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, err: ErrExceededResponseTimeout, }, "err: while sending request": { fields: fields{ - responseTimeout: 100 * time.Millisecond, // 100ms - serviceRegistry: mockServiceRegistry, - responseHandler: NoopResponseHandler{}, - identificationHandler: NoopIdentification{}, - chequeHandler: NoopChequeHandler{}, - messenger: mockMessenger, - compressor: &noopCompressor{}, + responseTimeout: 100 * time.Millisecond, // 100ms + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + chequeHandler: mockChequeHandler, + messenger: mockMessenger, + compressor: &noopCompressor{}, + cmAccounts: mockCMAccounts, }, args: args{ msg: &types.Message{ @@ -314,19 +327,26 @@ func TestProcessOutbound(t *testing.T) { }, }, prepare: func() { - mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(errSomeError) + mockCMAccounts.EXPECT().GetChequeOperators(gomock.Any(), gomock.Any()). + Return([]common.Address{{}}, nil) + mockCMAccounts.EXPECT().GetServiceFee(gomock.Any(), gomock.Any(), gomock.Any()). + Return(big.NewInt(1), nil) + mockCMAccounts.EXPECT().IsBotAllowed(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + mockChequeHandler.EXPECT().IssueCheque(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(&cheques.SignedCheque{}, nil) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(errSomeError) }, err: errSomeError, }, "success: response before timeout": { fields: fields{ - responseTimeout: 500 * time.Millisecond, // long enough timeout for response to be received - serviceRegistry: mockServiceRegistry, - responseHandler: NoopResponseHandler{}, - identificationHandler: NoopIdentification{}, - chequeHandler: NoopChequeHandler{}, - messenger: mockMessenger, - compressor: &noopCompressor{}, + responseTimeout: 500 * time.Millisecond, // long enough timeout for response to be received + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + chequeHandler: mockChequeHandler, + messenger: mockMessenger, + compressor: &noopCompressor{}, + cmAccounts: mockCMAccounts, }, args: args{ msg: &types.Message{ @@ -335,7 +355,11 @@ func TestProcessOutbound(t *testing.T) { }, }, prepare: func() { - mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil) + mockCMAccounts.EXPECT().GetChequeOperators(gomock.Any(), gomock.Any()).Return([]common.Address{{}}, nil) + mockCMAccounts.EXPECT().GetServiceFee(gomock.Any(), gomock.Any(), gomock.Any()).Return(big.NewInt(1), nil) + mockCMAccounts.EXPECT().IsBotAllowed(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockChequeHandler.EXPECT().IssueCheque(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(&cheques.SignedCheque{}, nil) }, writeResponseToChannel: func(p *processor) { done := func() bool { @@ -370,9 +394,9 @@ func TestProcessOutbound(t *testing.T) { common.Address{}, tt.fields.serviceRegistry, tt.fields.responseHandler, - tt.fields.identificationHandler, tt.fields.chequeHandler, tt.fields.compressor, + tt.fields.cmAccounts, ) if tt.prepare != nil { tt.prepare() @@ -402,12 +426,19 @@ func (d dummyService) Name() string { func TestStart(t *testing.T) { mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - mockServiceRegistry := NewMockServiceRegistry(mockCtrl) + mockMessenger := NewMockMessenger(mockCtrl) - mockServiceRegistry.EXPECT().GetService(gomock.Any()).AnyTimes().Return(dummyService{}, true) mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(nil) + mockServiceRegistry := NewMockServiceRegistry(mockCtrl) + mockServiceRegistry.EXPECT().GetService(gomock.Any()).AnyTimes().Return(dummyService{}, true) + + mockCMAccounts := cmaccounts.NewMockService(mockCtrl) + mockCMAccounts.EXPECT().GetServiceFee(gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(big.NewInt(1), nil) + + mockChequeHandler := chequehandler.NewMockChequeHandler(mockCtrl) + mockChequeHandler.EXPECT().VerifyCheque(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(nil) + t.Run("start processor and accept messages", func(*testing.T) { ch := make(chan types.Message, 5) // incoming messages @@ -443,9 +474,9 @@ func TestStart(t *testing.T) { common.Address{}, mockServiceRegistry, NoopResponseHandler{}, - NoopIdentification{}, - NoopChequeHandler{}, + mockChequeHandler, &noopCompressor{}, + mockCMAccounts, ) go p.Start(ctx) diff --git a/internal/scheduler/timer.go b/internal/scheduler/timer.go deleted file mode 100644 index 002fe99d..00000000 --- a/internal/scheduler/timer.go +++ /dev/null @@ -1,71 +0,0 @@ -package scheduler - -import ( - "time" -) - -func newTimer() *timer { - t := time.NewTimer(time.Second) - t.Stop() - return &timer{ - Timer: t, - stopCh: make(chan struct{}, 1), - } -} - -type timer struct { - *time.Timer - stopCh chan struct{} -} - -// StartOnce starts the timer once and starts goroutine with [f] call when the timer expires. -// Returns a channel that is closed upon completion. -// -// Should not be called on already running timer. -func (t *timer) StartOnce(d time.Duration, f func()) chan struct{} { - t.Reset(d) - stopSignalCh := make(chan struct{}) - go func() { - defer close(stopSignalCh) - for { - select { - case <-t.stopCh: - return - case <-t.C: - go f() - t.Stop() - } - } - }() - return stopSignalCh -} - -// Start starts the timer and starts goroutine with [f] call when the timer ticks. -// Returns a channel that is closed after timer is stopped. -// -// Should not be called on already running timer. -func (t *timer) Start(d time.Duration, f func()) chan struct{} { - t.Reset(d) - stopSignalCh := make(chan struct{}) - go func() { - defer close(stopSignalCh) - for { - select { - case <-t.stopCh: - return - case <-t.C: - go f() - t.Reset(d) - } - } - }() - return stopSignalCh -} - -// Stop stops the timer. -// -// Stop should not be called on already stopped timer. -func (t *timer) Stop() { - t.stopCh <- struct{}{} - t.Timer.Stop() -} diff --git a/internal/storage/storage.go b/internal/storage/storage.go deleted file mode 100644 index 8c9d2b14..00000000 --- a/internal/storage/storage.go +++ /dev/null @@ -1,179 +0,0 @@ -package storage - -import ( - "context" - "database/sql" - "errors" - - "github.com/chain4travel/camino-messenger-bot/config" - "github.com/golang-migrate/migrate/v4" - "github.com/golang-migrate/migrate/v4/database/sqlite3" - _ "github.com/golang-migrate/migrate/v4/source/file" // required by migrate - "github.com/jmoiron/sqlx" - _ "github.com/mattn/go-sqlite3" // sql driver, required - "go.uber.org/zap" -) - -var ( - _ Storage = (*storage)(nil) - _ Session = (*session)(nil) - - ErrNotFound = errors.New("not found") - ErrAlreadyCommitted = errors.New("already committed") -) - -type Storage interface { - NewSession(ctx context.Context) (Session, error) - Close(ctx context.Context) error -} - -type Session interface { - Commit() error - Abort() - - JobsStorage - ChequeRecordsStorage - IssuedChequeRecordsStorage -} - -func New(ctx context.Context, logger *zap.SugaredLogger, cfg config.DBConfig) (Storage, error) { - db, err := sqlx.Open("sqlite3", cfg.DBPath) - if err != nil { - logger.Error(err) - return nil, err - } - - s := &storage{ - logger: logger, - db: db, - } - - if err := s.migrate(cfg.DBName, cfg.MigrationsPath); err != nil { - return nil, err - } - - if err := s.prepare(ctx); err != nil { - return nil, err - } - - return s, nil -} - -type storage struct { - logger *zap.SugaredLogger - db *sqlx.DB - - issuedChequeRecordsStatements - chequeRecordsStatements - jobsStatements -} - -func (s *storage) migrate(dbName, migrationsPath string) error { - s.logger.Infof("Performing db migrations...") - - driver, err := sqlite3.WithInstance(s.db.DB, &sqlite3.Config{}) - if err != nil { - s.logger.Error(err) - return err - } - - migration, err := migrate.NewWithDatabaseInstance(migrationsPath, dbName, driver) - if err != nil { - s.logger.Error(err) - return err - } - - version, dirty, err := migration.Version() - if err != nil && !errors.Is(err, migrate.ErrNilVersion) { - s.logger.Error(err) - return err - } - if dirty { - return errors.New("database in dirty state after previous migration, requires manual fixing") - } - s.logger.Infof("Migration version: %d", version) - - err = migration.Up() - switch { - case errors.Is(err, migrate.ErrNoChange): - s.logger.Infof("No migrations needed") - case err != nil: - s.logger.Error(err) - return err - default: - newVersion, dirty, err := migration.Version() - if err != nil && !errors.Is(err, migrate.ErrNilVersion) { - s.logger.Error(err) - return err - } - if dirty { - return errors.New("database in dirty state after previous migration, requires manual fixing") - } - s.logger.Infof("New migration version: %d", newVersion) - } - - s.logger.Infof("Finished preforming db migrations") - return nil -} - -func (s *storage) prepare(ctx context.Context) error { - return errors.Join( - s.prepareIssuedChequeRecordsStmts(ctx), - s.prepareChequeRecordsStmts(ctx), - s.prepareJobsStmts(ctx), - ) -} - -func (s *storage) Close(_ context.Context) error { - if err := s.db.Close(); err != nil { - s.logger.Error(err) - return err - } - return nil -} - -func (s *storage) NewSession(ctx context.Context) (Session, error) { - tx, err := s.db.BeginTxx(ctx, &sql.TxOptions{ - Isolation: sql.LevelSerializable, - }) - if err != nil { - s.logger.Error(err) - return nil, err - } - return &session{tx: tx, logger: s.logger, storage: s}, nil -} - -type session struct { - storage *storage - logger *zap.SugaredLogger - tx *sqlx.Tx - committed bool -} - -func (s *session) Commit() error { - if s.committed { - return ErrAlreadyCommitted - } - if err := s.tx.Commit(); err != nil { - s.logger.Error(err) - return upgradeError(err) - } - s.committed = true - return nil -} - -func (s *session) Abort() { - if s.committed { - return - } - if err := s.tx.Rollback(); err != nil { - s.logger.Error(err) - } -} - -func upgradeError(err error) error { - if errors.Is(err, sql.ErrNoRows) { - return ErrNotFound - } - return err -} diff --git a/migrations/1_initial.down.sql b/migrations/cheque_handler/1_initial.down.sql similarity index 75% rename from migrations/1_initial.down.sql rename to migrations/cheque_handler/1_initial.down.sql index 6f79597c..9ee652e5 100644 --- a/migrations/1_initial.down.sql +++ b/migrations/cheque_handler/1_initial.down.sql @@ -1,3 +1,2 @@ DROP TABLE IF EXISTS cheque_records; DROP TABLE IF EXISTS issued_cheque_records; -DROP TABLE IF EXISTS jobs; \ No newline at end of file diff --git a/migrations/1_initial.up.sql b/migrations/cheque_handler/1_initial.up.sql similarity index 82% rename from migrations/1_initial.up.sql rename to migrations/cheque_handler/1_initial.up.sql index a40b1f0e..f5275618 100644 --- a/migrations/1_initial.up.sql +++ b/migrations/cheque_handler/1_initial.up.sql @@ -17,10 +17,3 @@ CREATE TABLE issued_cheque_records ( counter VARBINARY(16) NOT NULL, amount VARBINARY(16) NOT NULL ); - -CREATE TABLE jobs ( - name VARCHAR(100) NOT NULL PRIMARY KEY, - execute_at BIGINT NOT NULL, - period BIGINT NOT NULL -); - diff --git a/migrations/scheduler/1_initial.down.sql b/migrations/scheduler/1_initial.down.sql new file mode 100644 index 00000000..ac5b0966 --- /dev/null +++ b/migrations/scheduler/1_initial.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS jobs; diff --git a/migrations/scheduler/1_initial.up.sql b/migrations/scheduler/1_initial.up.sql new file mode 100644 index 00000000..eac67b04 --- /dev/null +++ b/migrations/scheduler/1_initial.up.sql @@ -0,0 +1,5 @@ +CREATE TABLE jobs ( + name VARCHAR(100) NOT NULL PRIMARY KEY, + execute_at BIGINT NOT NULL, + period BIGINT NOT NULL +); diff --git a/internal/messaging/cheque_handler.go b/pkg/chequehandler/cheque_handler.go similarity index 54% rename from internal/messaging/cheque_handler.go rename to pkg/chequehandler/cheque_handler.go index 332c7074..eaf0a875 100644 --- a/internal/messaging/cheque_handler.go +++ b/pkg/chequehandler/cheque_handler.go @@ -1,4 +1,4 @@ -package messaging +package chequehandler import ( "context" @@ -9,24 +9,55 @@ import ( "sync" "time" - "github.com/chain4travel/camino-messenger-bot/internal/models" - "github.com/chain4travel/camino-messenger-bot/internal/storage" "github.com/chain4travel/camino-messenger-bot/pkg/cheques" - "github.com/chain4travel/camino-messenger-contracts/go/contracts/cmaccount" - "github.com/ethereum/go-ethereum/accounts/abi/bind" + cmaccounts "github.com/chain4travel/camino-messenger-bot/pkg/cm_accounts" "github.com/ethereum/go-ethereum/common" - ethTypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethclient" - lru "github.com/hashicorp/golang-lru/v2" "go.uber.org/zap" ) var ( _ ChequeHandler = (*evmChequeHandler)(nil) bigOne = big.NewInt(1) + + ErrNotFound = errors.New("not found") ) +type Storage interface { + SessionHandler + ChequeRecordsStorage + IssuedChequeRecordsStorage +} + +type ChequeRecordsStorage interface { + GetNotCashedChequeRecords(ctx context.Context, session Session) ([]*ChequeRecord, error) + GetChequeRecordsWithPendingTxs(ctx context.Context, session Session) ([]*ChequeRecord, error) + GetChequeRecord(ctx context.Context, session Session, chequeRecordID common.Hash) (*ChequeRecord, error) + GetChequeRecordByTxID(ctx context.Context, session Session, txID common.Hash) (*ChequeRecord, error) + UpsertChequeRecord(ctx context.Context, session Session, chequeRecord *ChequeRecord) error +} + +type IssuedChequeRecordsStorage interface { + GetIssuedChequeRecord(ctx context.Context, session Session, chequeRecordID common.Hash) (*IssuedChequeRecord, error) + UpsertIssuedChequeRecord(ctx context.Context, session Session, chequeRecord *IssuedChequeRecord) error +} + +type SessionHandler interface { + NewSession(ctx context.Context) (Session, error) + Commit(session Session) error + Abort(session Session) +} + +type Session interface { + Commit() error + Abort() error +} + +type TxReceiptGetter interface { + TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) +} + type ChequeHandler interface { IssueCheque( ctx context.Context, @@ -36,14 +67,6 @@ type ChequeHandler interface { amount *big.Int, ) (*cheques.SignedCheque, error) - GetServiceFee( - ctx context.Context, - toCmAccountAddress common.Address, - serviceFullName string, - ) (*big.Int, error) - - IsBotAllowed(ctx context.Context, fromBot common.Address) (bool, error) - CashIn(ctx context.Context) error CheckCashInStatus(ctx context.Context) error @@ -58,44 +81,34 @@ type ChequeHandler interface { func NewChequeHandler( logger *zap.SugaredLogger, - ethClient *ethclient.Client, + ethClient TxReceiptGetter, botKey *ecdsa.PrivateKey, cmAccountAddress common.Address, chainID *big.Int, - storage storage.Storage, - serviceRegistry ServiceRegistry, + storage Storage, + cmAccounts cmaccounts.Service, minChequeDurationUntilExpiration *big.Int, chequeExpirationTime *big.Int, + cashInTxIssueTimeout time.Duration, ) (ChequeHandler, error) { - cmAccountInstance, err := cmaccount.NewCmaccount(cmAccountAddress, ethClient) - if err != nil { - return nil, fmt.Errorf("failed to instantiate contract binding: %w", err) - } - signer, err := cheques.NewSigner(botKey, chainID) if err != nil { return nil, fmt.Errorf("failed to create signer: %w", err) } - cmAccountsCache, err := lru.New[common.Address, *cmaccount.Cmaccount](cmAccountsCacheSize) - if err != nil { - return nil, err - } - return &evmChequeHandler{ - ethClient: ethClient, + txReceiptGetter: ethClient, cmAccountAddress: cmAccountAddress, - cmAccountInstance: cmAccountInstance, chainID: chainID, botKey: botKey, botAddress: crypto.PubkeyToAddress(botKey.PublicKey), logger: logger, storage: storage, signer: signer, - serviceRegistry: serviceRegistry, - cmAccounts: cmAccountsCache, + cmAccounts: cmAccounts, minChequeDurationUntilExpiration: minChequeDurationUntilExpiration, chequeExpirationTime: chequeExpirationTime, + cashInTxIssueTimeout: cashInTxIssueTimeout, }, nil } @@ -103,17 +116,16 @@ type evmChequeHandler struct { logger *zap.SugaredLogger chainID *big.Int - ethClient *ethclient.Client + txReceiptGetter TxReceiptGetter cmAccountAddress common.Address - cmAccountInstance *cmaccount.Cmaccount botKey *ecdsa.PrivateKey botAddress common.Address signer cheques.Signer - serviceRegistry ServiceRegistry - storage storage.Storage - cmAccounts *lru.Cache[common.Address, *cmaccount.Cmaccount] + storage Storage + cmAccounts cmaccounts.Service minChequeDurationUntilExpiration *big.Int chequeExpirationTime *big.Int + cashInTxIssueTimeout time.Duration } func (ch *evmChequeHandler) IssueCheque( @@ -128,7 +140,7 @@ func (ch *evmChequeHandler) IssueCheque( return nil, fmt.Errorf("failed to create session: %w", err) } - defer session.Abort() + defer ch.storage.Abort(session) now := big.NewInt(time.Now().Unix()) newCheque := &cheques.Cheque{ @@ -141,10 +153,10 @@ func (ch *evmChequeHandler) IssueCheque( ExpiresAt: big.NewInt(0).Add(now, ch.chequeExpirationTime), } - chequeRecordID := models.ChequeRecordID(newCheque) + chequeRecordID := ChequeRecordID(newCheque) - previousChequeModel, err := session.GetIssuedChequeRecord(ctx, chequeRecordID) - if err != nil && !errors.Is(err, storage.ErrNotFound) { + previousChequeModel, err := ch.storage.GetIssuedChequeRecord(ctx, session, chequeRecordID) + if err != nil && !errors.Is(err, ErrNotFound) { ch.logger.Errorf("failed to get previous cheque: %v", err) return nil, fmt.Errorf("failed to get previous cheque: %w", err) } @@ -160,12 +172,11 @@ func (ch *evmChequeHandler) IssueCheque( return nil, fmt.Errorf("failed to sign cheque: %w", err) } - isChequeValid, err := verifyChequeWithContract(ctx, ch.cmAccountInstance, signedCheque) - if err != nil { + if isChequeValid, err := ch.cmAccounts.VerifyCheque(ctx, signedCheque); err != nil { ch.logger.Errorf("failed to verify cheque with smart contract: %v", err) return nil, fmt.Errorf("failed to verify cheque with smart contract: %w", err) } else if !isChequeValid { - lastCounter, lastAmount, err := ch.getLastCashIn(ctx, toBot) + lastCounter, lastAmount, err := ch.cmAccounts.GetLastCashIn(ctx, ch.cmAccountAddress, ch.botAddress, toBot) if err != nil { ch.logger.Errorf("failed to get last cash in: %v", err) return nil, fmt.Errorf("failed to get last cash in: %w", err) @@ -179,8 +190,7 @@ func (ch *evmChequeHandler) IssueCheque( return nil, fmt.Errorf("failed to sign cheque: %w", err) } - isChequeValid, err := verifyChequeWithContract(ctx, ch.cmAccountInstance, signedCheque) - if err != nil { + if isChequeValid, err := ch.cmAccounts.VerifyCheque(ctx, signedCheque); err != nil { ch.logger.Errorf("failed to verify cheque with smart contract after getting last cash-in: %v", err) return nil, fmt.Errorf("failed to verify cheque with smart contract: %w", err) } else if !isChequeValid { @@ -189,12 +199,12 @@ func (ch *evmChequeHandler) IssueCheque( } } - if err := session.UpsertIssuedChequeRecord(ctx, models.IssuedChequeRecordCheque(chequeRecordID, signedCheque)); err != nil { + if err := ch.storage.UpsertIssuedChequeRecord(ctx, session, IssuedChequeRecordCheque(chequeRecordID, signedCheque)); err != nil { ch.logger.Error(err) return nil, fmt.Errorf("failed to upsert issued cheque record: %w", err) } - if err := session.Commit(); err != nil { + if err := ch.storage.Commit(session); err != nil { ch.logger.Error(err) return nil, fmt.Errorf("failed to commit session: %w", err) } @@ -202,35 +212,6 @@ func (ch *evmChequeHandler) IssueCheque( return signedCheque, nil } -func (ch *evmChequeHandler) GetServiceFee( - ctx context.Context, - toCmAccountAddress common.Address, - servicefullName string, -) (*big.Int, error) { - supplierCmAccount, err := ch.getCMAccount(toCmAccountAddress) - if err != nil { - return nil, fmt.Errorf("failed to get supplier cmAccount: %w", err) - } - - serviceFee, err := supplierCmAccount.GetServiceFee( - &bind.CallOpts{Context: ctx}, - servicefullName, - ) - if err != nil { - return nil, fmt.Errorf("failed to get service fee: %w", err) - } - return serviceFee, nil -} - -func (ch *evmChequeHandler) IsBotAllowed(ctx context.Context, fromBot common.Address) (bool, error) { - isAllowed, err := ch.cmAccountInstance.IsBotAllowed(&bind.CallOpts{Context: ctx}, fromBot) - if err != nil { - return false, fmt.Errorf("failed to check if bot has required permissions: %w", err) - } - - return isAllowed, nil -} - func (ch *evmChequeHandler) VerifyCheque( ctx context.Context, cheque *cheques.SignedCheque, @@ -242,7 +223,7 @@ func (ch *evmChequeHandler) VerifyCheque( ch.logger.Errorf("failed to create storage session: %v", err) return err } - defer session.Abort() + defer ch.storage.Abort(session) chequeIssuerPubKey, err := ch.signer.RecoverPublicKey(cheque) if err != nil { @@ -254,9 +235,9 @@ func (ch *evmChequeHandler) VerifyCheque( return fmt.Errorf("cheque issuer does not match sender") } - chequeRecordID := models.ChequeRecordID(&cheque.Cheque) - chequeRecord, err := session.GetChequeRecord(ctx, chequeRecordID) - if err != nil && !errors.Is(err, storage.ErrNotFound) { + chequeRecordID := ChequeRecordID(&cheque.Cheque) + chequeRecord, err := ch.storage.GetChequeRecord(ctx, session, chequeRecordID) + if err != nil && !errors.Is(err, ErrNotFound) { ch.logger.Errorf("failed to get chequeRecord: %v", err) return err } @@ -281,7 +262,7 @@ func (ch *evmChequeHandler) VerifyCheque( return fmt.Errorf("cheque amount must at least cover serviceFee") } - if valid, err := ch.verifyChequeWithContract(ctx, cheque); err != nil { + if valid, err := ch.cmAccounts.VerifyCheque(ctx, cheque); err != nil { ch.logger.Errorf("Failed to verify cheque with blockchain: %v", err) return err } else if !valid { @@ -289,25 +270,15 @@ func (ch *evmChequeHandler) VerifyCheque( return fmt.Errorf("cheque is invalid (blockchain validation)") } - chequeRecord = models.ChequeRecordFromCheque(chequeRecordID, cheque) - if err := session.UpsertChequeRecord(ctx, chequeRecord); err != nil { + chequeRecord = ChequeRecordFromCheque(chequeRecordID, cheque) + if err := ch.storage.UpsertChequeRecord(ctx, session, chequeRecord); err != nil { ch.logger.Errorf("Failed to store cheque: %v", err) return err } - return session.Commit() -} - -func (ch *evmChequeHandler) verifyChequeWithContract(ctx context.Context, cheque *cheques.SignedCheque) (bool, error) { - cmAccount, err := ch.getCMAccount(cheque.FromCMAccount) - if err != nil { - ch.logger.Errorf("failed to get cmAccount contract instance: %v", err) - return false, err - } - return verifyChequeWithContract(ctx, cmAccount, cheque) + return ch.storage.Commit(session) } -// TODO @evlekht whole cash in is almost 100% copy-paste from asb, think of moving to common place func (ch *evmChequeHandler) CashIn(ctx context.Context) error { ch.logger.Debug("Cashing in...") defer ch.logger.Debug("Finished cashing in") @@ -317,9 +288,9 @@ func (ch *evmChequeHandler) CashIn(ctx context.Context) error { ch.logger.Error(err) return err } - defer session.Abort() + defer ch.storage.Abort(session) - chequeRecords, err := session.GetNotCashedChequeRecords(ctx) + chequeRecords, err := ch.storage.GetNotCashedChequeRecords(ctx, session) if err != nil { ch.logger.Errorf("failed to get not cashed cheques: %v", err) return err @@ -333,16 +304,20 @@ func (ch *evmChequeHandler) CashIn(ctx context.Context) error { go func() { defer wg.Done() - timedCtx, cancel := context.WithTimeout(ctx, cashInTxIssueTimeout) + timedCtx, cancel := context.WithTimeout(ctx, ch.cashInTxIssueTimeout) defer cancel() - txID, err := ch.cashInCheque(timedCtx, chequeRecord) + txID, err := ch.cmAccounts.CashInCheque( + timedCtx, + &chequeRecord.SignedCheque, + ch.botKey, + ) if err != nil { return } chequeRecord.TxID = txID - chequeRecord.Status = models.ChequeTxStatusPending + chequeRecord.Status = ChequeTxStatusPending // TODO @evlekht if tx will be issued, but then storage will fail to persist it, // TODO tx is still issued and app service will fail to cash in this cheque next time @@ -352,8 +327,8 @@ func (ch *evmChequeHandler) CashIn(ctx context.Context) error { // TODO @evlekht add txCreatedAt field to db and use it for mining timeout ? - if err := session.UpsertChequeRecord(ctx, chequeRecord); err != nil { - chequeRecord.Status = models.ChequeTxStatusUnknown + if err := ch.storage.UpsertChequeRecord(ctx, session, chequeRecord); err != nil { + chequeRecord.Status = ChequeTxStatusUnknown ch.logger.Errorf("failed to update cheque %s: %v", chequeRecord, err) return } @@ -362,13 +337,13 @@ func (ch *evmChequeHandler) CashIn(ctx context.Context) error { wg.Wait() - if err := session.Commit(); err != nil { + if err := ch.storage.Commit(session); err != nil { ch.logger.Errorf("failed to commit session: %v", err) return err } for _, chequeRecord := range chequeRecords { - if chequeRecord.Status != models.ChequeTxStatusPending { + if chequeRecord.Status != ChequeTxStatusPending { continue } @@ -380,48 +355,15 @@ func (ch *evmChequeHandler) CashIn(ctx context.Context) error { return nil } -func (ch *evmChequeHandler) cashInCheque(ctx context.Context, chequeRecord *models.ChequeRecord) (common.Hash, error) { - cmAccount, err := ch.getCMAccount(chequeRecord.FromCMAccount) - if err != nil { - ch.logger.Errorf("failed to get cmAccount contract instance: %v", err) - return common.Hash{}, err - } - - transactor, err := bind.NewKeyedTransactorWithChainID(ch.botKey, ch.chainID) - if err != nil { - ch.logger.Error(err) - return common.Hash{}, err - } - transactor.Context = ctx - - tx, err := cmAccount.CashInCheque( - transactor, - chequeRecord.FromCMAccount, - chequeRecord.ToCMAccount, - chequeRecord.ToBot, - chequeRecord.Counter, - chequeRecord.Amount, - chequeRecord.CreatedAt, - chequeRecord.ExpiresAt, - chequeRecord.Signature, - ) - if err != nil { - ch.logger.Errorf("failed to cash in cheque %s: %v", chequeRecord, err) - return common.Hash{}, err - } - - return tx.Hash(), nil -} - func (ch *evmChequeHandler) CheckCashInStatus(ctx context.Context) error { session, err := ch.storage.NewSession(ctx) if err != nil { ch.logger.Error(err) return err } - defer session.Abort() + defer ch.storage.Abort(session) - chequeRecords, err := session.GetChequeRecordsWithPendingTxs(ctx) + chequeRecords, err := ch.storage.GetChequeRecordsWithPendingTxs(ctx, session) if err != nil { ch.logger.Errorf("failed to get not cashed cheques: %v", err) return err @@ -438,7 +380,7 @@ func (ch *evmChequeHandler) CheckCashInStatus(ctx context.Context) error { func (ch *evmChequeHandler) checkCashInStatus(ctx context.Context, txID common.Hash) error { // TODO @evlekht timeout? what to do if timeouted? - res, err := waitMined(ctx, ch.ethClient, txID) + res, err := ch.waitMined(ctx, txID) if err != nil { ch.logger.Errorf("failed to get cash in transaction receipt %s: %v", txID, err) return err @@ -449,84 +391,34 @@ func (ch *evmChequeHandler) checkCashInStatus(ctx context.Context, txID common.H ch.logger.Error(err) return err } - defer session.Abort() + defer ch.storage.Abort(session) - chequeRecord, err := session.GetChequeRecordByTxID(ctx, txID) + chequeRecord, err := ch.storage.GetChequeRecordByTxID(ctx, session, txID) if err != nil { ch.logger.Errorf("failed to get chequeRecord by txID %s: %v", txID, err) return err } - txStatus := models.ChequeTxStatusFromTxStatus(res.Status) + txStatus := ChequeTxStatusFromTxStatus(res.Status) if chequeRecord.Status == txStatus { return nil } chequeRecord.Status = txStatus - if err := session.UpsertChequeRecord(ctx, chequeRecord); err != nil { + if err := ch.storage.UpsertChequeRecord(ctx, session, chequeRecord); err != nil { ch.logger.Errorf("failed to update chequeRecord %s: %v", chequeRecord, err) return err } - return session.Commit() -} - -func (ch *evmChequeHandler) getCMAccount(address common.Address) (*cmaccount.Cmaccount, error) { - cmAccount, ok := ch.cmAccounts.Get(address) - if ok { - return cmAccount, nil - } - - cmAccount, err := cmaccount.NewCmaccount(address, ch.ethClient) - if err != nil { - ch.logger.Errorf("failed to create cmAccount contract instance: %v", err) - return nil, err - } - _ = ch.cmAccounts.Add(address, cmAccount) - - return cmAccount, nil -} - -func (ch *evmChequeHandler) getLastCashIn(ctx context.Context, toBot common.Address) (counter *big.Int, amount *big.Int, err error) { - lastCashIn, err := ch.cmAccountInstance.GetLastCashIn( - &bind.CallOpts{Context: ctx}, - ch.botAddress, - toBot, - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to get last cash in: %w", err) - } - return lastCashIn.LastCounter, lastCashIn.LastAmount, nil -} - -func verifyChequeWithContract( - ctx context.Context, - cmAcc *cmaccount.Cmaccount, - signedCheque *cheques.SignedCheque, -) (bool, error) { - _, err := cmAcc.VerifyCheque( - &bind.CallOpts{Context: ctx}, - signedCheque.FromCMAccount, - signedCheque.ToCMAccount, - signedCheque.ToBot, - signedCheque.Counter, - signedCheque.Amount, - signedCheque.CreatedAt, - signedCheque.ExpiresAt, - signedCheque.Signature, - ) - if err != nil && err.Error() == "execution reverted" { - return false, nil - } - return err == nil, err + return ch.storage.Commit(session) } -func waitMined(ctx context.Context, b bind.DeployBackend, txID common.Hash) (*ethTypes.Receipt, error) { +func (ch *evmChequeHandler) waitMined(ctx context.Context, txID common.Hash) (*types.Receipt, error) { ticker := time.NewTicker(time.Second) defer ticker.Stop() for { - receipt, err := b.TransactionReceipt(ctx, txID) + receipt, err := ch.txReceiptGetter.TransactionReceipt(ctx, txID) if err == nil { return receipt, nil } diff --git a/internal/models/cheque_record.go b/pkg/chequehandler/cheque_record.go similarity index 98% rename from internal/models/cheque_record.go rename to pkg/chequehandler/cheque_record.go index d3cc5230..70153bcc 100644 --- a/internal/models/cheque_record.go +++ b/pkg/chequehandler/cheque_record.go @@ -1,4 +1,4 @@ -package models +package chequehandler import ( "fmt" diff --git a/internal/models/issued_cheque_record.go b/pkg/chequehandler/issued_cheque_record.go similarity index 95% rename from internal/models/issued_cheque_record.go rename to pkg/chequehandler/issued_cheque_record.go index 57c42b25..d3b67301 100644 --- a/internal/models/issued_cheque_record.go +++ b/pkg/chequehandler/issued_cheque_record.go @@ -1,4 +1,4 @@ -package models +package chequehandler import ( "math/big" diff --git a/pkg/chequehandler/mock_cheque_handler.go b/pkg/chequehandler/mock_cheque_handler.go new file mode 100644 index 00000000..ba4c6d3c --- /dev/null +++ b/pkg/chequehandler/mock_cheque_handler.go @@ -0,0 +1,100 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/chain4travel/camino-messenger-bot/pkg/chequehandler (interfaces: ChequeHandler) +// +// Generated by this command: +// +// mockgen -package=chequehandler -destination=pkg/chequehandler/mock_cheque_handler.go github.com/chain4travel/camino-messenger-bot/pkg/chequehandler ChequeHandler +// + +// Package chequehandler is a generated GoMock package. +package chequehandler + +import ( + context "context" + big "math/big" + reflect "reflect" + + cheques "github.com/chain4travel/camino-messenger-bot/pkg/cheques" + common "github.com/ethereum/go-ethereum/common" + gomock "go.uber.org/mock/gomock" +) + +// MockChequeHandler is a mock of ChequeHandler interface. +type MockChequeHandler struct { + ctrl *gomock.Controller + recorder *MockChequeHandlerMockRecorder +} + +// MockChequeHandlerMockRecorder is the mock recorder for MockChequeHandler. +type MockChequeHandlerMockRecorder struct { + mock *MockChequeHandler +} + +// NewMockChequeHandler creates a new mock instance. +func NewMockChequeHandler(ctrl *gomock.Controller) *MockChequeHandler { + mock := &MockChequeHandler{ctrl: ctrl} + mock.recorder = &MockChequeHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChequeHandler) EXPECT() *MockChequeHandlerMockRecorder { + return m.recorder +} + +// CashIn mocks base method. +func (m *MockChequeHandler) CashIn(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CashIn", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CashIn indicates an expected call of CashIn. +func (mr *MockChequeHandlerMockRecorder) CashIn(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CashIn", reflect.TypeOf((*MockChequeHandler)(nil).CashIn), arg0) +} + +// CheckCashInStatus mocks base method. +func (m *MockChequeHandler) CheckCashInStatus(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckCashInStatus", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CheckCashInStatus indicates an expected call of CheckCashInStatus. +func (mr *MockChequeHandlerMockRecorder) CheckCashInStatus(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckCashInStatus", reflect.TypeOf((*MockChequeHandler)(nil).CheckCashInStatus), arg0) +} + +// IssueCheque mocks base method. +func (m *MockChequeHandler) IssueCheque(arg0 context.Context, arg1, arg2, arg3 common.Address, arg4 *big.Int) (*cheques.SignedCheque, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IssueCheque", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(*cheques.SignedCheque) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IssueCheque indicates an expected call of IssueCheque. +func (mr *MockChequeHandlerMockRecorder) IssueCheque(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssueCheque", reflect.TypeOf((*MockChequeHandler)(nil).IssueCheque), arg0, arg1, arg2, arg3, arg4) +} + +// VerifyCheque mocks base method. +func (m *MockChequeHandler) VerifyCheque(arg0 context.Context, arg1 *cheques.SignedCheque, arg2 common.Address, arg3 *big.Int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "VerifyCheque", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// VerifyCheque indicates an expected call of VerifyCheque. +func (mr *MockChequeHandlerMockRecorder) VerifyCheque(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyCheque", reflect.TypeOf((*MockChequeHandler)(nil).VerifyCheque), arg0, arg1, arg2, arg3) +} diff --git a/pkg/chequehandler/mock_storage.go b/pkg/chequehandler/mock_storage.go new file mode 100644 index 00000000..1b92354c --- /dev/null +++ b/pkg/chequehandler/mock_storage.go @@ -0,0 +1,185 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/chain4travel/camino-messenger-bot/pkg/chequehandler (interfaces: Storage) +// +// Generated by this command: +// +// mockgen -package=chequehandler -destination=pkg/chequehandler/mock_storage.go github.com/chain4travel/camino-messenger-bot/pkg/chequehandler Storage +// + +// Package chequehandler is a generated GoMock package. +package chequehandler + +import ( + context "context" + reflect "reflect" + + common "github.com/ethereum/go-ethereum/common" + gomock "go.uber.org/mock/gomock" +) + +// MockStorage is a mock of Storage interface. +type MockStorage struct { + ctrl *gomock.Controller + recorder *MockStorageMockRecorder +} + +// MockStorageMockRecorder is the mock recorder for MockStorage. +type MockStorageMockRecorder struct { + mock *MockStorage +} + +// NewMockStorage creates a new mock instance. +func NewMockStorage(ctrl *gomock.Controller) *MockStorage { + mock := &MockStorage{ctrl: ctrl} + mock.recorder = &MockStorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStorage) EXPECT() *MockStorageMockRecorder { + return m.recorder +} + +// Abort mocks base method. +func (m *MockStorage) Abort(arg0 Session) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Abort", arg0) +} + +// Abort indicates an expected call of Abort. +func (mr *MockStorageMockRecorder) Abort(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abort", reflect.TypeOf((*MockStorage)(nil).Abort), arg0) +} + +// Commit mocks base method. +func (m *MockStorage) Commit(arg0 Session) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockStorageMockRecorder) Commit(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockStorage)(nil).Commit), arg0) +} + +// GetChequeRecord mocks base method. +func (m *MockStorage) GetChequeRecord(arg0 context.Context, arg1 Session, arg2 common.Hash) (*ChequeRecord, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChequeRecord", arg0, arg1, arg2) + ret0, _ := ret[0].(*ChequeRecord) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChequeRecord indicates an expected call of GetChequeRecord. +func (mr *MockStorageMockRecorder) GetChequeRecord(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChequeRecord", reflect.TypeOf((*MockStorage)(nil).GetChequeRecord), arg0, arg1, arg2) +} + +// GetChequeRecordByTxID mocks base method. +func (m *MockStorage) GetChequeRecordByTxID(arg0 context.Context, arg1 Session, arg2 common.Hash) (*ChequeRecord, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChequeRecordByTxID", arg0, arg1, arg2) + ret0, _ := ret[0].(*ChequeRecord) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChequeRecordByTxID indicates an expected call of GetChequeRecordByTxID. +func (mr *MockStorageMockRecorder) GetChequeRecordByTxID(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChequeRecordByTxID", reflect.TypeOf((*MockStorage)(nil).GetChequeRecordByTxID), arg0, arg1, arg2) +} + +// GetChequeRecordsWithPendingTxs mocks base method. +func (m *MockStorage) GetChequeRecordsWithPendingTxs(arg0 context.Context, arg1 Session) ([]*ChequeRecord, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChequeRecordsWithPendingTxs", arg0, arg1) + ret0, _ := ret[0].([]*ChequeRecord) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChequeRecordsWithPendingTxs indicates an expected call of GetChequeRecordsWithPendingTxs. +func (mr *MockStorageMockRecorder) GetChequeRecordsWithPendingTxs(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChequeRecordsWithPendingTxs", reflect.TypeOf((*MockStorage)(nil).GetChequeRecordsWithPendingTxs), arg0, arg1) +} + +// GetIssuedChequeRecord mocks base method. +func (m *MockStorage) GetIssuedChequeRecord(arg0 context.Context, arg1 Session, arg2 common.Hash) (*IssuedChequeRecord, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIssuedChequeRecord", arg0, arg1, arg2) + ret0, _ := ret[0].(*IssuedChequeRecord) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetIssuedChequeRecord indicates an expected call of GetIssuedChequeRecord. +func (mr *MockStorageMockRecorder) GetIssuedChequeRecord(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIssuedChequeRecord", reflect.TypeOf((*MockStorage)(nil).GetIssuedChequeRecord), arg0, arg1, arg2) +} + +// GetNotCashedChequeRecords mocks base method. +func (m *MockStorage) GetNotCashedChequeRecords(arg0 context.Context, arg1 Session) ([]*ChequeRecord, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNotCashedChequeRecords", arg0, arg1) + ret0, _ := ret[0].([]*ChequeRecord) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNotCashedChequeRecords indicates an expected call of GetNotCashedChequeRecords. +func (mr *MockStorageMockRecorder) GetNotCashedChequeRecords(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNotCashedChequeRecords", reflect.TypeOf((*MockStorage)(nil).GetNotCashedChequeRecords), arg0, arg1) +} + +// NewSession mocks base method. +func (m *MockStorage) NewSession(arg0 context.Context) (Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSession", arg0) + ret0, _ := ret[0].(Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewSession indicates an expected call of NewSession. +func (mr *MockStorageMockRecorder) NewSession(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSession", reflect.TypeOf((*MockStorage)(nil).NewSession), arg0) +} + +// UpsertChequeRecord mocks base method. +func (m *MockStorage) UpsertChequeRecord(arg0 context.Context, arg1 Session, arg2 *ChequeRecord) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertChequeRecord", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertChequeRecord indicates an expected call of UpsertChequeRecord. +func (mr *MockStorageMockRecorder) UpsertChequeRecord(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChequeRecord", reflect.TypeOf((*MockStorage)(nil).UpsertChequeRecord), arg0, arg1, arg2) +} + +// UpsertIssuedChequeRecord mocks base method. +func (m *MockStorage) UpsertIssuedChequeRecord(arg0 context.Context, arg1 Session, arg2 *IssuedChequeRecord) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertIssuedChequeRecord", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertIssuedChequeRecord indicates an expected call of UpsertIssuedChequeRecord. +func (mr *MockStorageMockRecorder) UpsertIssuedChequeRecord(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertIssuedChequeRecord", reflect.TypeOf((*MockStorage)(nil).UpsertIssuedChequeRecord), arg0, arg1, arg2) +} diff --git a/internal/storage/cheque_records.go b/pkg/chequehandler/storage/sqlite/cheque_records.go similarity index 53% rename from internal/storage/cheque_records.go rename to pkg/chequehandler/storage/sqlite/cheque_records.go index 13b4ce93..e3cfdede 100644 --- a/internal/storage/cheque_records.go +++ b/pkg/chequehandler/storage/sqlite/cheque_records.go @@ -1,4 +1,4 @@ -package storage +package sqlite import ( "context" @@ -7,7 +7,7 @@ import ( "fmt" "math/big" - "github.com/chain4travel/camino-messenger-bot/internal/models" + "github.com/chain4travel/camino-messenger-bot/pkg/chequehandler" "github.com/chain4travel/camino-messenger-bot/pkg/cheques" "github.com/ethereum/go-ethereum/common" "github.com/jmoiron/sqlx" @@ -16,49 +16,47 @@ import ( const chequeRecordsTableName = "cheque_records" var ( - _ ChequeRecordsStorage = (*session)(nil) + _ chequehandler.ChequeRecordsStorage = (*storage)(nil) zeroHash = common.Hash{} ) -type ChequeRecordsStorage interface { - GetNotCashedChequeRecords(ctx context.Context) ([]*models.ChequeRecord, error) - GetChequeRecordsWithPendingTxs(ctx context.Context) ([]*models.ChequeRecord, error) - GetChequeRecord(ctx context.Context, chequeRecordID common.Hash) (*models.ChequeRecord, error) - GetChequeRecordByTxID(ctx context.Context, txID common.Hash) (*models.ChequeRecord, error) - UpsertChequeRecord(ctx context.Context, chequeRecord *models.ChequeRecord) error -} - type chequeRecord struct { - ChequeRecordID common.Hash `db:"cheque_record_id"` - FromCMAccount common.Address `db:"from_cm_account"` - ToCMAccount common.Address `db:"to_cm_account"` - ToBot common.Address `db:"to_bot"` - Counter []byte `db:"counter"` - Amount []byte `db:"amount"` - CreatedAt []byte `db:"created_at"` - ExpiresAt []byte `db:"expires_at"` - Signature []byte `db:"signature"` - TxID *common.Hash `db:"tx_id"` - Status *models.ChequeTxStatus `db:"status"` + ChequeRecordID common.Hash `db:"cheque_record_id"` + FromCMAccount common.Address `db:"from_cm_account"` + ToCMAccount common.Address `db:"to_cm_account"` + ToBot common.Address `db:"to_bot"` + Counter []byte `db:"counter"` + Amount []byte `db:"amount"` + CreatedAt []byte `db:"created_at"` + ExpiresAt []byte `db:"expires_at"` + Signature []byte `db:"signature"` + TxID *common.Hash `db:"tx_id"` + Status *chequehandler.ChequeTxStatus `db:"status"` } -func (s *session) GetNotCashedChequeRecords(ctx context.Context) ([]*models.ChequeRecord, error) { - chequeRecords := []*models.ChequeRecord{} - rows, err := s.tx.StmtxContext(ctx, s.storage.getNotCashedChequeRecords).QueryxContext(ctx) +func (s *storage) GetNotCashedChequeRecords(ctx context.Context, session chequehandler.Session) ([]*chequehandler.ChequeRecord, error) { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return nil, err + } + + chequeRecords := []*chequehandler.ChequeRecord{} + rows, err := tx.StmtxContext(ctx, s.getNotCashedChequeRecords).QueryxContext(ctx) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return nil, upgradeError(err) } for rows.Next() { chequeRecord := &chequeRecord{} if err := rows.StructScan(chequeRecord); err != nil { - s.logger.Errorf("failed to get not cashed chequeRecord from db: %v", err) + s.base.Logger.Errorf("failed to get not cashed chequeRecord from db: %v", err) continue } model, err := modelFromChequeRecord(chequeRecord) if err != nil { - s.logger.Errorf("failed to parse not cashed chequeRecord: %v", err) + s.base.Logger.Errorf("failed to parse not cashed chequeRecord: %v", err) continue } chequeRecords = append(chequeRecords, model) @@ -66,22 +64,28 @@ func (s *session) GetNotCashedChequeRecords(ctx context.Context) ([]*models.Cheq return chequeRecords, nil } -func (s *session) GetChequeRecordsWithPendingTxs(ctx context.Context) ([]*models.ChequeRecord, error) { - chequeRecords := []*models.ChequeRecord{} - rows, err := s.tx.StmtxContext(ctx, s.storage.getChequeRecordsWithPendingTxs).QueryxContext(ctx) +func (s *storage) GetChequeRecordsWithPendingTxs(ctx context.Context, session chequehandler.Session) ([]*chequehandler.ChequeRecord, error) { + tx, err := getSQLXTx(session) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) + return nil, err + } + + chequeRecords := []*chequehandler.ChequeRecord{} + rows, err := tx.StmtxContext(ctx, s.getChequeRecordsWithPendingTxs).QueryxContext(ctx) + if err != nil { + s.base.Logger.Error(err) return nil, upgradeError(err) } for rows.Next() { chequeRecord := &chequeRecord{} if err := rows.StructScan(chequeRecord); err != nil { - s.logger.Errorf("failed to get chequeRecord with pending tx from db: %v", err) + s.base.Logger.Errorf("failed to get chequeRecord with pending tx from db: %v", err) continue } model, err := modelFromChequeRecord(chequeRecord) if err != nil { - s.logger.Errorf("failed to parse chequeRecord with pending tx: %v", err) + s.base.Logger.Errorf("failed to parse chequeRecord with pending tx: %v", err) continue } chequeRecords = append(chequeRecords, model) @@ -89,37 +93,55 @@ func (s *session) GetChequeRecordsWithPendingTxs(ctx context.Context) ([]*models return chequeRecords, nil } -func (s *session) GetChequeRecord(ctx context.Context, chequeRecordID common.Hash) (*models.ChequeRecord, error) { +func (s *storage) GetChequeRecord(ctx context.Context, session chequehandler.Session, chequeRecordID common.Hash) (*chequehandler.ChequeRecord, error) { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return nil, err + } + chequeRecord := &chequeRecord{} - if err := s.tx.StmtxContext(ctx, s.storage.getChequeRecordByID).GetContext(ctx, chequeRecord, chequeRecordID); err != nil { + if err := tx.StmtxContext(ctx, s.getChequeRecordByID).GetContext(ctx, chequeRecord, chequeRecordID); err != nil { if !errors.Is(err, sql.ErrNoRows) { - s.logger.Error(err) + s.base.Logger.Error(err) } return nil, upgradeError(err) } return modelFromChequeRecord(chequeRecord) } -func (s *session) GetChequeRecordByTxID(ctx context.Context, txID common.Hash) (*models.ChequeRecord, error) { +func (s *storage) GetChequeRecordByTxID(ctx context.Context, session chequehandler.Session, txID common.Hash) (*chequehandler.ChequeRecord, error) { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return nil, err + } + chequeRecord := &chequeRecord{} - if err := s.tx.StmtxContext(ctx, s.storage.getChequeRecordByTxID).GetContext(ctx, chequeRecord, txID); err != nil { + if err := tx.StmtxContext(ctx, s.getChequeRecordByTxID).GetContext(ctx, chequeRecord, txID); err != nil { if !errors.Is(err, sql.ErrNoRows) { - s.logger.Error(err) + s.base.Logger.Error(err) } return nil, upgradeError(err) } return modelFromChequeRecord(chequeRecord) } -func (s *session) UpsertChequeRecord(ctx context.Context, chequeRecord *models.ChequeRecord) error { - result, err := s.tx.NamedStmtContext(ctx, s.storage.upsertChequeRecord). +func (s *storage) UpsertChequeRecord(ctx context.Context, session chequehandler.Session, chequeRecord *chequehandler.ChequeRecord) error { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return err + } + + result, err := tx.NamedStmtContext(ctx, s.upsertChequeRecord). ExecContext(ctx, chequeRecordFromModel(chequeRecord)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return upgradeError(err) } if rowsAffected, err := result.RowsAffected(); err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return upgradeError(err) } else if rowsAffected != 1 { return fmt.Errorf("failed to add chequeRecord: expected to affect 1 row, but affected %d", rowsAffected) @@ -134,47 +156,47 @@ type chequeRecordsStatements struct { } func (s *storage) prepareChequeRecordsStmts(ctx context.Context) error { - getNotCashedChequeRecords, err := s.db.PreparexContext(ctx, fmt.Sprintf(` + getNotCashedChequeRecords, err := s.base.DB.PreparexContext(ctx, fmt.Sprintf(` SELECT * FROM %s WHERE status = %d OR status IS NULL - `, chequeRecordsTableName, models.ChequeTxStatusRejected)) + `, chequeRecordsTableName, chequehandler.ChequeTxStatusRejected)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.getNotCashedChequeRecords = getNotCashedChequeRecords - getChequeRecordsWithPendingTxs, err := s.db.PreparexContext(ctx, fmt.Sprintf(` + getChequeRecordsWithPendingTxs, err := s.base.DB.PreparexContext(ctx, fmt.Sprintf(` SELECT * FROM %s WHERE status = %d - `, chequeRecordsTableName, models.ChequeTxStatusPending)) + `, chequeRecordsTableName, chequehandler.ChequeTxStatusPending)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.getChequeRecordsWithPendingTxs = getChequeRecordsWithPendingTxs - getChequeRecordByID, err := s.db.PreparexContext(ctx, fmt.Sprintf(` + getChequeRecordByID, err := s.base.DB.PreparexContext(ctx, fmt.Sprintf(` SELECT * FROM %s WHERE cheque_record_id = ? `, chequeRecordsTableName)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.getChequeRecordByID = getChequeRecordByID - getChequeByTxID, err := s.db.PreparexContext(ctx, fmt.Sprintf(` + getChequeByTxID, err := s.base.DB.PreparexContext(ctx, fmt.Sprintf(` SELECT * FROM %s WHERE tx_id = ? `, chequeRecordsTableName)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.getChequeRecordByTxID = getChequeByTxID - upsertChequeRecord, err := s.db.PrepareNamedContext(ctx, fmt.Sprintf(` + upsertChequeRecord, err := s.base.DB.PrepareNamedContext(ctx, fmt.Sprintf(` INSERT INTO %s ( cheque_record_id, from_cm_account, @@ -211,7 +233,7 @@ func (s *storage) prepareChequeRecordsStmts(ctx context.Context) error { status = excluded.status `, chequeRecordsTableName)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.upsertChequeRecord = upsertChequeRecord @@ -219,18 +241,18 @@ func (s *storage) prepareChequeRecordsStmts(ctx context.Context) error { return nil } -func modelFromChequeRecord(chequeRecord *chequeRecord) (*models.ChequeRecord, error) { +func modelFromChequeRecord(chequeRecord *chequeRecord) (*chequehandler.ChequeRecord, error) { txID := common.Hash{} if chequeRecord.TxID != nil { txID = *chequeRecord.TxID } - status := models.ChequeTxStatusUnknown + status := chequehandler.ChequeTxStatusUnknown if chequeRecord.Status != nil { status = *chequeRecord.Status } - return &models.ChequeRecord{ + return &chequehandler.ChequeRecord{ SignedCheque: cheques.SignedCheque{ Cheque: cheques.Cheque{ FromCMAccount: chequeRecord.FromCMAccount, @@ -249,14 +271,14 @@ func modelFromChequeRecord(chequeRecord *chequeRecord) (*models.ChequeRecord, er }, nil } -func chequeRecordFromModel(model *models.ChequeRecord) *chequeRecord { +func chequeRecordFromModel(model *chequehandler.ChequeRecord) *chequeRecord { var txID *common.Hash if model.TxID != zeroHash { txID = &model.TxID } - var status *models.ChequeTxStatus - if model.Status != models.ChequeTxStatusUnknown { + var status *chequehandler.ChequeTxStatus + if model.Status != chequehandler.ChequeTxStatusUnknown { status = &model.Status } diff --git a/internal/storage/issued_cheque_records.go b/pkg/chequehandler/storage/sqlite/issued_cheque_records.go similarity index 61% rename from internal/storage/issued_cheque_records.go rename to pkg/chequehandler/storage/sqlite/issued_cheque_records.go index 998a3c4f..d7f6ee71 100644 --- a/internal/storage/issued_cheque_records.go +++ b/pkg/chequehandler/storage/sqlite/issued_cheque_records.go @@ -1,4 +1,4 @@ -package storage +package sqlite import ( "context" @@ -7,19 +7,14 @@ import ( "fmt" "math/big" - "github.com/chain4travel/camino-messenger-bot/internal/models" + "github.com/chain4travel/camino-messenger-bot/pkg/chequehandler" "github.com/ethereum/go-ethereum/common" "github.com/jmoiron/sqlx" ) const issuedChequeRecordsTableName = "issued_cheque_records" -var _ ChequeRecordsStorage = (*session)(nil) - -type IssuedChequeRecordsStorage interface { - GetIssuedChequeRecord(ctx context.Context, chequeRecordID common.Hash) (*models.IssuedChequeRecord, error) - UpsertIssuedChequeRecord(ctx context.Context, chequeRecord *models.IssuedChequeRecord) error -} +var _ chequehandler.ChequeRecordsStorage = (*storage)(nil) type issuedChequeRecord struct { ChequeRecordID common.Hash `db:"cheque_record_id"` @@ -27,26 +22,38 @@ type issuedChequeRecord struct { Amount []byte `db:"amount"` } -func (s *session) GetIssuedChequeRecord(ctx context.Context, chequeRecordID common.Hash) (*models.IssuedChequeRecord, error) { +func (s *storage) GetIssuedChequeRecord(ctx context.Context, session chequehandler.Session, chequeRecordID common.Hash) (*chequehandler.IssuedChequeRecord, error) { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return nil, err + } + chequeRecord := &issuedChequeRecord{} - if err := s.tx.StmtxContext(ctx, s.storage.getIssuedChequeRecord).GetContext(ctx, chequeRecord, chequeRecordID); err != nil { + if err := tx.StmtxContext(ctx, s.getIssuedChequeRecord).GetContext(ctx, chequeRecord, chequeRecordID); err != nil { if !errors.Is(err, sql.ErrNoRows) { - s.logger.Error(err) + s.base.Logger.Error(err) } return nil, upgradeError(err) } return modelFromIssuedChequeRecord(chequeRecord), nil } -func (s *session) UpsertIssuedChequeRecord(ctx context.Context, chequeRecord *models.IssuedChequeRecord) error { - result, err := s.tx.NamedStmtContext(ctx, s.storage.upsertIssuedChequeRecord). +func (s *storage) UpsertIssuedChequeRecord(ctx context.Context, session chequehandler.Session, chequeRecord *chequehandler.IssuedChequeRecord) error { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return err + } + + result, err := tx.NamedStmtContext(ctx, s.upsertIssuedChequeRecord). ExecContext(ctx, issuedChequeRecordFromModel(chequeRecord)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return upgradeError(err) } if rowsAffected, err := result.RowsAffected(); err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return upgradeError(err) } else if rowsAffected != 1 { return fmt.Errorf("failed to add chequeRecord: expected to affect 1 row, but affected %d", rowsAffected) @@ -60,17 +67,17 @@ type issuedChequeRecordsStatements struct { } func (s *storage) prepareIssuedChequeRecordsStmts(ctx context.Context) error { - getIssuedChequeRecord, err := s.db.PreparexContext(ctx, fmt.Sprintf(` + getIssuedChequeRecord, err := s.base.DB.PreparexContext(ctx, fmt.Sprintf(` SELECT * FROM %s WHERE cheque_record_id = ? `, issuedChequeRecordsTableName)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.getIssuedChequeRecord = getIssuedChequeRecord - upsertIssuedChequeRecord, err := s.db.PrepareNamedContext(ctx, fmt.Sprintf(` + upsertIssuedChequeRecord, err := s.base.DB.PrepareNamedContext(ctx, fmt.Sprintf(` INSERT INTO %s ( cheque_record_id, counter, @@ -86,7 +93,7 @@ func (s *storage) prepareIssuedChequeRecordsStmts(ctx context.Context) error { amount = excluded.amount `, issuedChequeRecordsTableName)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.upsertIssuedChequeRecord = upsertIssuedChequeRecord @@ -94,15 +101,15 @@ func (s *storage) prepareIssuedChequeRecordsStmts(ctx context.Context) error { return nil } -func modelFromIssuedChequeRecord(chequeRecord *issuedChequeRecord) *models.IssuedChequeRecord { - return &models.IssuedChequeRecord{ +func modelFromIssuedChequeRecord(chequeRecord *issuedChequeRecord) *chequehandler.IssuedChequeRecord { + return &chequehandler.IssuedChequeRecord{ ChequeRecordID: chequeRecord.ChequeRecordID, Counter: big.NewInt(0).SetBytes(chequeRecord.Counter), Amount: big.NewInt(0).SetBytes(chequeRecord.Amount), } } -func issuedChequeRecordFromModel(model *models.IssuedChequeRecord) *issuedChequeRecord { +func issuedChequeRecordFromModel(model *chequehandler.IssuedChequeRecord) *issuedChequeRecord { return &issuedChequeRecord{ ChequeRecordID: model.ChequeRecordID, Counter: model.Counter.Bytes(), diff --git a/pkg/chequehandler/storage/sqlite/storage.go b/pkg/chequehandler/storage/sqlite/storage.go new file mode 100644 index 00000000..cc8db4a7 --- /dev/null +++ b/pkg/chequehandler/storage/sqlite/storage.go @@ -0,0 +1,89 @@ +package sqlite + +import ( + "context" + "database/sql" + "errors" + + "github.com/chain4travel/camino-messenger-bot/pkg/chequehandler" + "github.com/chain4travel/camino-messenger-bot/pkg/database" + "github.com/chain4travel/camino-messenger-bot/pkg/database/sqlite" + _ "github.com/golang-migrate/migrate/v4/source/file" // required by migrate + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" // sql driver, required + "go.uber.org/zap" +) + +const dbName = "cheque_handler" + +var ( + _ Storage = (*storage)(nil) + _ chequehandler.Session = (*sqlite.SQLxTxSession)(nil) + _ chequehandler.SessionHandler = (*storage)(nil) +) + +type Storage interface { + Close() error + + chequehandler.Storage +} + +func New(ctx context.Context, logger *zap.SugaredLogger, cfg sqlite.DBConfig) (Storage, error) { + baseDB, err := sqlite.New(logger, cfg, dbName) + if err != nil { + return nil, err + } + + s := &storage{base: baseDB} + + if err := s.prepare(ctx); err != nil { + return nil, err + } + + return s, nil +} + +type storage struct { + base *sqlite.DB + + issuedChequeRecordsStatements + chequeRecordsStatements +} + +func (s *storage) Close() error { + return s.base.Close() +} + +func (s *storage) prepare(ctx context.Context) error { + return errors.Join( + s.prepareIssuedChequeRecordsStmts(ctx), + s.prepareChequeRecordsStmts(ctx), + ) +} + +func (s *storage) NewSession(ctx context.Context) (chequehandler.Session, error) { + return s.base.NewSession(ctx) +} + +func (s *storage) Commit(session chequehandler.Session) error { + return s.base.Commit(session) +} + +func (s *storage) Abort(session chequehandler.Session) { + s.base.Abort(session) +} + +func getSQLXTx(session chequehandler.Session) (*sqlx.Tx, error) { + s, ok := session.(sqlite.SQLxTxer) + if !ok { + return nil, sqlite.ErrUnexpectedSessionType + } + return s.SQLxTx(), nil +} + +func upgradeError(err error) error { + if errors.Is(err, sql.ErrNoRows) { + return database.ErrNotFound + } + return err +} diff --git a/pkg/cm_accounts/cm_accounts.go b/pkg/cm_accounts/cm_accounts.go new file mode 100644 index 00000000..bb343627 --- /dev/null +++ b/pkg/cm_accounts/cm_accounts.go @@ -0,0 +1,250 @@ +package cmaccounts + +import ( + "context" + "crypto/ecdsa" + "fmt" + "math/big" + + "github.com/chain4travel/camino-messenger-bot/pkg/cheques" + "github.com/chain4travel/camino-messenger-contracts/go/contracts/cmaccount" + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethclient" + lru "github.com/hashicorp/golang-lru/v2" + "go.uber.org/zap" +) + +var ( + _ Service = &service{} + + chequeOperatorRole = crypto.Keccak256Hash([]byte("CHEQUE_OPERATOR_ROLE")) +) + +type Service interface { + GetChequeOperators(ctx context.Context, cmAccountAddress common.Address) ([]common.Address, error) + + VerifyCheque(ctx context.Context, cheque *cheques.SignedCheque) (bool, error) + + CashInCheque( + ctx context.Context, + cheque *cheques.SignedCheque, + botKey *ecdsa.PrivateKey, + ) (common.Hash, error) + + GetServiceFee( + ctx context.Context, + cmAccountAddress common.Address, + serviceFullName string, + ) (*big.Int, error) + + IsBotAllowed( + ctx context.Context, + cmAccountAddress common.Address, + botAddress common.Address, + ) (bool, error) + + GetLastCashIn( + ctx context.Context, + cmAccountAddress common.Address, + fromBot common.Address, + toBot common.Address, + ) (counter *big.Int, amount *big.Int, err error) +} + +func NewService( + logger *zap.SugaredLogger, + cacheSize int, + ethClient *ethclient.Client, +) (Service, error) { + chainID, err := ethClient.ChainID(context.Background()) + if err != nil { + logger.Errorf("Failed to get chain ID: %v", err) + return nil, err + } + + cache, err := lru.New[common.Address, *cmaccount.Cmaccount](cacheSize) + if err != nil { + return nil, err + } + + return &service{ + ethClient: ethClient, + cache: cache, + logger: logger, + chainID: chainID, + }, nil +} + +type service struct { + ethClient *ethclient.Client + cache *lru.Cache[common.Address, *cmaccount.Cmaccount] + logger *zap.SugaredLogger + chainID *big.Int +} + +func (s *service) GetChequeOperators(ctx context.Context, cmAccountAddress common.Address) ([]common.Address, error) { + cmAccount, err := s.cmAccount(cmAccountAddress) + if err != nil { + s.logger.Errorf("Failed to get cm account: %v", err) + return nil, err + } + + countBig, err := cmAccount.GetRoleMemberCount(&bind.CallOpts{Context: ctx}, chequeOperatorRole) + if err != nil { + s.logger.Errorf("Failed to call contract function: %v", err) + return nil, err + } + + count := countBig.Int64() + botsAddresses := make([]common.Address, 0, count) + for i := int64(0); i < count; i++ { + address, err := cmAccount.GetRoleMember(&bind.CallOpts{Context: ctx}, chequeOperatorRole, big.NewInt(i)) + if err != nil { + s.logger.Errorf("Failed to call contract function: %v", err) + continue + } + botsAddresses = append(botsAddresses, address) + } + + return botsAddresses, nil +} + +func (s *service) CashInCheque( + ctx context.Context, + cheque *cheques.SignedCheque, + botKey *ecdsa.PrivateKey, +) (common.Hash, error) { + cmAccount, err := s.cmAccount(cheque.FromCMAccount) + if err != nil { + s.logger.Errorf("failed to get cmAccount contract instance: %v", err) + return common.Hash{}, err + } + + transactor, err := bind.NewKeyedTransactorWithChainID(botKey, s.chainID) + if err != nil { + s.logger.Error(err) + return common.Hash{}, err + } + transactor.Context = ctx + + tx, err := cmAccount.CashInCheque( + transactor, + cheque.FromCMAccount, + cheque.ToCMAccount, + cheque.ToBot, + cheque.Counter, + cheque.Amount, + cheque.CreatedAt, + cheque.ExpiresAt, + cheque.Signature, + ) + if err != nil { + s.logger.Errorf("failed to cash in cheque %s: %v", cheque, err) + return common.Hash{}, err + } + + return tx.Hash(), nil +} + +func (s *service) VerifyCheque(ctx context.Context, cheque *cheques.SignedCheque) (bool, error) { + cmAccount, err := s.cmAccount(cheque.FromCMAccount) + if err != nil { + s.logger.Errorf("failed to get cmAccount contract instance: %v", err) + return false, err + } + + _, err = cmAccount.VerifyCheque( + &bind.CallOpts{Context: ctx}, + cheque.FromCMAccount, + cheque.ToCMAccount, + cheque.ToBot, + cheque.Counter, + cheque.Amount, + cheque.CreatedAt, + cheque.ExpiresAt, + cheque.Signature, + ) + if err != nil && err.Error() == "execution reverted" { + return false, nil + } + return err == nil, err +} + +func (s *service) GetServiceFee( + ctx context.Context, + cmAccountAddress common.Address, + serviceFullName string, +) (*big.Int, error) { + cmAccount, err := s.cmAccount(cmAccountAddress) + if err != nil { + return nil, fmt.Errorf("failed to get supplier cmAccount: %w", err) + } + + serviceFee, err := cmAccount.GetServiceFee( + &bind.CallOpts{Context: ctx}, + serviceFullName, + ) + if err != nil { + return nil, fmt.Errorf("failed to get service fee: %w", err) + } + return serviceFee, nil +} + +func (s *service) IsBotAllowed( + ctx context.Context, + cmAccountAddress common.Address, + botAddress common.Address, +) (bool, error) { + cmAccount, err := s.cmAccount(cmAccountAddress) + if err != nil { + return false, fmt.Errorf("failed to get cmAccount contract instance: %w", err) + } + + isBotAllowed, err := cmAccount.IsBotAllowed( + &bind.CallOpts{Context: ctx}, + botAddress, + ) + if err != nil { + return false, fmt.Errorf("failed to check if bot is allowed: %w", err) + } + return isBotAllowed, nil +} + +func (s *service) GetLastCashIn( + ctx context.Context, + cmAccountAddress common.Address, + fromBot common.Address, + toBot common.Address, +) (counter *big.Int, amount *big.Int, err error) { + cmAccount, err := s.cmAccount(cmAccountAddress) + if err != nil { + return nil, nil, fmt.Errorf("failed to get cmAccount contract instance: %w", err) + } + + lastCashIn, err := cmAccount.GetLastCashIn( + &bind.CallOpts{Context: ctx}, + fromBot, + toBot, + ) + if err != nil { + return nil, nil, fmt.Errorf("failed to get last cash in: %w", err) + } + return lastCashIn.LastCounter, lastCashIn.LastAmount, nil +} + +func (s *service) cmAccount(cmAccountAddr common.Address) (*cmaccount.Cmaccount, error) { + cmAccount, ok := s.cache.Get(cmAccountAddr) + if ok { + return cmAccount, nil + } + + cmaccount, err := cmaccount.NewCmaccount(cmAccountAddr, s.ethClient) + if err != nil { + return nil, err + } + s.cache.Add(cmAccountAddr, cmaccount) + + return cmaccount, nil +} diff --git a/pkg/cm_accounts/mock_cm_accounts.go b/pkg/cm_accounts/mock_cm_accounts.go new file mode 100644 index 00000000..4afd806e --- /dev/null +++ b/pkg/cm_accounts/mock_cm_accounts.go @@ -0,0 +1,135 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/chain4travel/camino-messenger-bot/pkg/cm_accounts (interfaces: Service) +// +// Generated by this command: +// +// mockgen -package=cmaccounts -destination=pkg/cm_accounts/mock_cm_accounts.go github.com/chain4travel/camino-messenger-bot/pkg/cm_accounts Service +// + +// Package cmaccounts is a generated GoMock package. +package cmaccounts + +import ( + context "context" + ecdsa "crypto/ecdsa" + big "math/big" + reflect "reflect" + + cheques "github.com/chain4travel/camino-messenger-bot/pkg/cheques" + common "github.com/ethereum/go-ethereum/common" + gomock "go.uber.org/mock/gomock" +) + +// MockService is a mock of Service interface. +type MockService struct { + ctrl *gomock.Controller + recorder *MockServiceMockRecorder +} + +// MockServiceMockRecorder is the mock recorder for MockService. +type MockServiceMockRecorder struct { + mock *MockService +} + +// NewMockService creates a new mock instance. +func NewMockService(ctrl *gomock.Controller) *MockService { + mock := &MockService{ctrl: ctrl} + mock.recorder = &MockServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockService) EXPECT() *MockServiceMockRecorder { + return m.recorder +} + +// CashInCheque mocks base method. +func (m *MockService) CashInCheque(arg0 context.Context, arg1 *cheques.SignedCheque, arg2 *ecdsa.PrivateKey) (common.Hash, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CashInCheque", arg0, arg1, arg2) + ret0, _ := ret[0].(common.Hash) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CashInCheque indicates an expected call of CashInCheque. +func (mr *MockServiceMockRecorder) CashInCheque(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CashInCheque", reflect.TypeOf((*MockService)(nil).CashInCheque), arg0, arg1, arg2) +} + +// GetChequeOperators mocks base method. +func (m *MockService) GetChequeOperators(arg0 context.Context, arg1 common.Address) ([]common.Address, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChequeOperators", arg0, arg1) + ret0, _ := ret[0].([]common.Address) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChequeOperators indicates an expected call of GetChequeOperators. +func (mr *MockServiceMockRecorder) GetChequeOperators(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChequeOperators", reflect.TypeOf((*MockService)(nil).GetChequeOperators), arg0, arg1) +} + +// GetLastCashIn mocks base method. +func (m *MockService) GetLastCashIn(arg0 context.Context, arg1, arg2, arg3 common.Address) (*big.Int, *big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLastCashIn", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(*big.Int) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetLastCashIn indicates an expected call of GetLastCashIn. +func (mr *MockServiceMockRecorder) GetLastCashIn(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLastCashIn", reflect.TypeOf((*MockService)(nil).GetLastCashIn), arg0, arg1, arg2, arg3) +} + +// GetServiceFee mocks base method. +func (m *MockService) GetServiceFee(arg0 context.Context, arg1 common.Address, arg2 string) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceFee", arg0, arg1, arg2) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceFee indicates an expected call of GetServiceFee. +func (mr *MockServiceMockRecorder) GetServiceFee(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceFee", reflect.TypeOf((*MockService)(nil).GetServiceFee), arg0, arg1, arg2) +} + +// IsBotAllowed mocks base method. +func (m *MockService) IsBotAllowed(arg0 context.Context, arg1, arg2 common.Address) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsBotAllowed", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsBotAllowed indicates an expected call of IsBotAllowed. +func (mr *MockServiceMockRecorder) IsBotAllowed(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBotAllowed", reflect.TypeOf((*MockService)(nil).IsBotAllowed), arg0, arg1, arg2) +} + +// VerifyCheque mocks base method. +func (m *MockService) VerifyCheque(arg0 context.Context, arg1 *cheques.SignedCheque) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "VerifyCheque", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// VerifyCheque indicates an expected call of VerifyCheque. +func (mr *MockServiceMockRecorder) VerifyCheque(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyCheque", reflect.TypeOf((*MockService)(nil).VerifyCheque), arg0, arg1) +} diff --git a/pkg/database/common.go b/pkg/database/common.go new file mode 100644 index 00000000..7fc00cfd --- /dev/null +++ b/pkg/database/common.go @@ -0,0 +1,5 @@ +package database + +import "errors" + +var ErrNotFound = errors.New("not found") diff --git a/pkg/database/sqlite/session.go b/pkg/database/sqlite/session.go new file mode 100644 index 00000000..36c868c0 --- /dev/null +++ b/pkg/database/sqlite/session.go @@ -0,0 +1,78 @@ +package sqlite + +import ( + "context" + "database/sql" + "errors" + + "github.com/jmoiron/sqlx" +) + +var ( + _ Session = (*SQLxTxSession)(nil) + _ SQLxTxer = (*SQLxTxSession)(nil) + + ErrAlreadyCommitted = errors.New("already committed") + ErrUnexpectedSessionType = errors.New("unexpected session type") +) + +type Session interface { + Commit() error + Abort() error +} + +type SQLxTxer interface { + SQLxTx() *sqlx.Tx +} + +func (s *DB) NewSession(ctx context.Context) (*SQLxTxSession, error) { + tx, err := s.DB.BeginTxx(ctx, &sql.TxOptions{ + Isolation: sql.LevelSerializable, + }) + if err != nil { + s.Logger.Error(err) + return nil, err + } + return &SQLxTxSession{Tx: tx}, nil +} + +func (s *DB) Commit(session Session) error { + if err := session.Commit(); err != nil { + s.Logger.Error(err) + return err + } + return nil +} + +func (s *DB) Abort(session Session) { + if err := session.Abort(); err != nil { + s.Logger.Error(err) + } +} + +type SQLxTxSession struct { + *sqlx.Tx + committed bool +} + +func (s *SQLxTxSession) Commit() error { + if s.committed { + return ErrAlreadyCommitted + } + if err := s.Tx.Commit(); err != nil { + return err + } + s.committed = true + return nil +} + +func (s *SQLxTxSession) Abort() error { + if s.committed { + return nil + } + return s.Tx.Rollback() +} + +func (s *SQLxTxSession) SQLxTx() *sqlx.Tx { + return s.Tx +} diff --git a/pkg/database/sqlite/storage.go b/pkg/database/sqlite/storage.go new file mode 100644 index 00000000..ad0ce3ab --- /dev/null +++ b/pkg/database/sqlite/storage.go @@ -0,0 +1,97 @@ +package sqlite + +import ( + "errors" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + _ "github.com/golang-migrate/migrate/v4/source/file" // required by migrate + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" // sql driver, required + "go.uber.org/zap" +) + +type DBConfig struct { + DBPath string + MigrationsPath string +} + +func New(logger *zap.SugaredLogger, cfg DBConfig, dbName string) (*DB, error) { + db, err := sqlx.Open("sqlite3", cfg.DBPath) + if err != nil { + logger.Error(err) + return nil, err + } + + s := &DB{ + Logger: logger, + DB: db, + } + + if err := s.migrate(dbName, cfg.MigrationsPath); err != nil { + return nil, err + } + + return s, nil +} + +type DB struct { + Logger *zap.SugaredLogger + DB *sqlx.DB +} + +func (s *DB) Close() error { + if err := s.DB.Close(); err != nil { + s.Logger.Error(err) + return err + } + return nil +} + +func (s *DB) migrate(dbName, migrationsPath string) error { + s.Logger.Infof("Performing db migrations...") + + driver, err := sqlite3.WithInstance(s.DB.DB, &sqlite3.Config{}) + if err != nil { + s.Logger.Error(err) + return err + } + + migration, err := migrate.NewWithDatabaseInstance(migrationsPath, dbName, driver) + if err != nil { + s.Logger.Error(err) + return err + } + + version, dirty, err := migration.Version() + if err != nil && !errors.Is(err, migrate.ErrNilVersion) { + s.Logger.Error(err) + return err + } + if dirty { + return errors.New("database in dirty state after previous migration, requires manual fixing") + } + s.Logger.Infof("Migration version: %d", version) + + err = migration.Up() + switch { + case errors.Is(err, migrate.ErrNoChange): + s.Logger.Infof("No migrations needed") + case err != nil: + s.Logger.Error(err) + return err + default: + newVersion, dirty, err := migration.Version() + if err != nil && !errors.Is(err, migrate.ErrNilVersion) { + s.Logger.Error(err) + return err + } + if dirty { + return errors.New("database in dirty state after previous migration, requires manual fixing") + } + s.Logger.Infof("New migration version: %d", newVersion) + } + + s.Logger.Infof("Finished preforming db migrations") + return nil +} diff --git a/internal/models/job.go b/pkg/scheduler/job.go similarity index 84% rename from internal/models/job.go rename to pkg/scheduler/job.go index 87cb9cb9..165e92ea 100644 --- a/internal/models/job.go +++ b/pkg/scheduler/job.go @@ -1,4 +1,4 @@ -package models +package scheduler import "time" diff --git a/pkg/scheduler/mock_storage.go b/pkg/scheduler/mock_storage.go new file mode 100644 index 00000000..8cbbc400 --- /dev/null +++ b/pkg/scheduler/mock_storage.go @@ -0,0 +1,125 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/chain4travel/camino-messenger-bot/pkg/scheduler (interfaces: Storage) +// +// Generated by this command: +// +// mockgen -package=scheduler -destination=pkg/scheduler/mock_storage.go github.com/chain4travel/camino-messenger-bot/pkg/scheduler Storage +// + +// Package scheduler is a generated GoMock package. +package scheduler + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStorage is a mock of Storage interface. +type MockStorage struct { + ctrl *gomock.Controller + recorder *MockStorageMockRecorder +} + +// MockStorageMockRecorder is the mock recorder for MockStorage. +type MockStorageMockRecorder struct { + mock *MockStorage +} + +// NewMockStorage creates a new mock instance. +func NewMockStorage(ctrl *gomock.Controller) *MockStorage { + mock := &MockStorage{ctrl: ctrl} + mock.recorder = &MockStorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStorage) EXPECT() *MockStorageMockRecorder { + return m.recorder +} + +// Abort mocks base method. +func (m *MockStorage) Abort(arg0 Session) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Abort", arg0) +} + +// Abort indicates an expected call of Abort. +func (mr *MockStorageMockRecorder) Abort(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abort", reflect.TypeOf((*MockStorage)(nil).Abort), arg0) +} + +// Commit mocks base method. +func (m *MockStorage) Commit(arg0 Session) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockStorageMockRecorder) Commit(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockStorage)(nil).Commit), arg0) +} + +// GetAllJobs mocks base method. +func (m *MockStorage) GetAllJobs(arg0 context.Context, arg1 Session) ([]*Job, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllJobs", arg0, arg1) + ret0, _ := ret[0].([]*Job) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllJobs indicates an expected call of GetAllJobs. +func (mr *MockStorageMockRecorder) GetAllJobs(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllJobs", reflect.TypeOf((*MockStorage)(nil).GetAllJobs), arg0, arg1) +} + +// GetJobByName mocks base method. +func (m *MockStorage) GetJobByName(arg0 context.Context, arg1 Session, arg2 string) (*Job, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetJobByName", arg0, arg1, arg2) + ret0, _ := ret[0].(*Job) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetJobByName indicates an expected call of GetJobByName. +func (mr *MockStorageMockRecorder) GetJobByName(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJobByName", reflect.TypeOf((*MockStorage)(nil).GetJobByName), arg0, arg1, arg2) +} + +// NewSession mocks base method. +func (m *MockStorage) NewSession(arg0 context.Context) (Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSession", arg0) + ret0, _ := ret[0].(Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewSession indicates an expected call of NewSession. +func (mr *MockStorageMockRecorder) NewSession(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSession", reflect.TypeOf((*MockStorage)(nil).NewSession), arg0) +} + +// UpsertJob mocks base method. +func (m *MockStorage) UpsertJob(arg0 context.Context, arg1 Session, arg2 *Job) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertJob", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertJob indicates an expected call of UpsertJob. +func (mr *MockStorageMockRecorder) UpsertJob(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertJob", reflect.TypeOf((*MockStorage)(nil).UpsertJob), arg0, arg1, arg2) +} diff --git a/internal/scheduler/scheduler.go b/pkg/scheduler/scheduler.go similarity index 55% rename from internal/scheduler/scheduler.go rename to pkg/scheduler/scheduler.go index 61d75e73..410e2f68 100644 --- a/internal/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -7,37 +7,64 @@ import ( "sync" "time" - "github.com/chain4travel/camino-messenger-bot/internal/models" - "github.com/chain4travel/camino-messenger-bot/internal/storage" + "github.com/jonboulle/clockwork" "go.uber.org/zap" ) -// TODO @evlekht its duplicate from asb, think of moving to common place -var _ Scheduler = (*scheduler)(nil) +var ( + _ Scheduler = (*scheduler)(nil) + ErrNotFound = errors.New("not found") +) + +type Storage interface { + SessionHandler + + GetAllJobs(ctx context.Context, session Session) ([]*Job, error) + UpsertJob(ctx context.Context, session Session, job *Job) error + GetJobByName(ctx context.Context, session Session, jobName string) (*Job, error) +} + +type SessionHandler interface { + NewSession(ctx context.Context) (Session, error) + Commit(session Session) error + Abort(session Session) +} + +type Session interface { + Commit() error + Abort() error +} type Scheduler interface { Start(ctx context.Context) error - Stop(ctx context.Context) error + Stop() error Schedule(ctx context.Context, period time.Duration, jobName string) error RegisterJobHandler(jobName string, jobHandler func()) } -func New(_ context.Context, logger *zap.SugaredLogger, storage storage.Storage) Scheduler { +type Stopper interface { + Stop() +} + +func New(logger *zap.SugaredLogger, storage Storage, clock clockwork.Clock) Scheduler { return &scheduler{ storage: storage, logger: logger, registry: make(map[string]func()), - timers: make(map[string]*timer), + timers: make(map[string]Stopper), + clock: clock, } } type scheduler struct { logger *zap.SugaredLogger - storage storage.Storage + storage Storage registry map[string]func() - timers map[string]*timer + timers map[string]Stopper + stopTimers func() registryLock sync.RWMutex timersLock sync.RWMutex + clock clockwork.Clock } // Start starts the scheduler. Jobs that are already due are executed immediately. @@ -47,14 +74,17 @@ func (s *scheduler) Start(ctx context.Context) error { s.logger.Errorf("failed to create storage session: %v", err) return err } - defer session.Abort() + defer s.storage.Abort(session) - jobs, err := session.GetAllJobs(ctx) + jobs, err := s.storage.GetAllJobs(ctx, session) if err != nil { s.logger.Errorf("failed to get all jobs: %v", err) return err } + timersCtx, cancel := context.WithCancel(ctx) + s.stopTimers = cancel + for _, job := range jobs { jobHandler, err := s.getJobHandler(job.Name) if err != nil { @@ -65,12 +95,14 @@ func (s *scheduler) Start(ctx context.Context) error { jobName := job.Name period := job.Period - now := time.Now() - timeUntilFirstExecution := time.Duration(0) + now := s.clock.Now() + durationUntilFirstExecution := time.Duration(0) if job.ExecuteAt.After(now) { - timeUntilFirstExecution = job.ExecuteAt.Sub(now) + durationUntilFirstExecution = job.ExecuteAt.Sub(now) } + onceDone := make(chan struct{}) + handler := func() { // TODO @evlekht panic handling? if err := s.updateJobExecutionTime(ctx, jobName); err != nil { @@ -80,26 +112,39 @@ func (s *scheduler) Start(ctx context.Context) error { jobHandler() } - timer := newTimer() - doneCh := timer.StartOnce(timeUntilFirstExecution, handler) + timer := s.clock.AfterFunc(durationUntilFirstExecution, func() { + handler() + close(onceDone) + }) + s.setJobTimer(job.Name, &timerStopper{timer}) + go func() { - <-doneCh - _ = timer.Start(period, handler) + <-onceDone + ticker := s.clock.NewTicker(period) + defer ticker.Stop() + s.setJobTimer(job.Name, ticker) + for { + select { + case <-ticker.Chan(): + handler() + case <-timersCtx.Done(): + return + } + } }() - - s.setJobTimer(job.Name, timer) } return nil } -func (s *scheduler) Stop(_ context.Context) error { - s.timersLock.RLock() - for _, timer := range s.timers { - timer.Stop() +func (s *scheduler) Stop() error { + s.stopTimers() + s.timersLock.Lock() + for jobName := range s.timers { + delete(s.timers, jobName) } - s.timersLock.RUnlock() - // TODO @evlekht await all ongoing job handlers to finish + s.timersLock.Unlock() + // TODO @evlekht await all ongoing job handlers to finish ? return nil } @@ -111,15 +156,15 @@ func (s *scheduler) Schedule(ctx context.Context, period time.Duration, jobName s.logger.Errorf("failed to create storage session: %v", err) return err } - defer session.Abort() + defer s.storage.Abort(session) - job, err := session.GetJobByName(ctx, jobName) - if err != nil && !errors.Is(err, storage.ErrNotFound) { + job, err := s.storage.GetJobByName(ctx, session, jobName) + if err != nil && !errors.Is(err, ErrNotFound) { s.logger.Errorf("failed to get job: %v", err) return err } - executeAt := time.Now().Add(period) + executeAt := s.clock.Now().Add(period) if job != nil { job.Period = period @@ -127,19 +172,19 @@ func (s *scheduler) Schedule(ctx context.Context, period time.Duration, jobName job.ExecuteAt = executeAt } } else { - job = &models.Job{ + job = &Job{ Name: jobName, ExecuteAt: executeAt, Period: period, } } - if err := session.UpsertJob(ctx, job); err != nil { + if err := s.storage.UpsertJob(ctx, session, job); err != nil { s.logger.Errorf("failed to store scheduled job: %v", err) return err } - return session.Commit() + return s.storage.Commit(session) } func (s *scheduler) RegisterJobHandler(jobName string, jobHandler func()) { @@ -154,22 +199,22 @@ func (s *scheduler) updateJobExecutionTime(ctx context.Context, jobName string) s.logger.Errorf("failed to create storage session: %v", err) return err } - defer session.Abort() + defer s.storage.Abort(session) - job, err := session.GetJobByName(ctx, jobName) + job, err := s.storage.GetJobByName(ctx, session, jobName) if err != nil { s.logger.Errorf("failed to get job: %v", err) return err } - job.ExecuteAt = time.Now().Add(job.Period) + job.ExecuteAt = s.clock.Now().Add(job.Period) - if err := session.UpsertJob(ctx, job); err != nil { + if err := s.storage.UpsertJob(ctx, session, job); err != nil { s.logger.Errorf("failed to store scheduled job: %v", err) return err } - if err := session.Commit(); err != nil { + if err := s.storage.Commit(session); err != nil { s.logger.Errorf("failed to commit session: %v", err) return err } @@ -187,8 +232,23 @@ func (s *scheduler) getJobHandler(jobName string) (func(), error) { return jobHandler, nil } -func (s *scheduler) setJobTimer(jobName string, t *timer) { +func (s *scheduler) setJobTimer(jobName string, t Stopper) { s.timersLock.Lock() s.timers[jobName] = t s.timersLock.Unlock() } + +func (s *scheduler) getJobTimer(jobName string) (Stopper, bool) { + s.timersLock.RLock() + timer, ok := s.timers[jobName] + s.timersLock.RUnlock() + return timer, ok +} + +type timerStopper struct { + timer clockwork.Timer +} + +func (t *timerStopper) Stop() { + _ = t.timer.Stop() +} diff --git a/pkg/scheduler/scheduler_test.go b/pkg/scheduler/scheduler_test.go new file mode 100644 index 00000000..11dd21a3 --- /dev/null +++ b/pkg/scheduler/scheduler_test.go @@ -0,0 +1,374 @@ +package scheduler + +import ( + "context" + reflect "reflect" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" + "go.uber.org/zap" +) + +func TestScheduler_Start(t *testing.T) { + // *** base setup + + require := require.New(t) + ctx := context.Background() + clock := clockwork.NewFakeClockAt(time.Unix(0, 100)) + ctrl := gomock.NewController(t) + storage := NewMockStorage(ctrl) + epsilon := time.Millisecond + timeout := 100 * time.Millisecond + + earlyJobExecuted := make(chan string) + nowJobExecuted := make(chan string) + lateJobExecuted := make(chan string) + + earlyJob := Job{ + Name: "early_job", + ExecuteAt: clock.Now().Add(-1), + Period: 1000, + } + nowJob := Job{ + Name: "now_job", + ExecuteAt: clock.Now(), + Period: 1003, + } + lateJob := Job{ + Name: "late_job", + ExecuteAt: clock.Now().Add(1), + Period: 1007, + } + jobs := []*Job{&earlyJob, &nowJob, &lateJob} + jobsExecChansMap := map[string]chan string{ + earlyJob.Name: earlyJobExecuted, + nowJob.Name: nowJobExecuted, + lateJob.Name: lateJobExecuted, + } + jobsExecChans := []chan string{earlyJobExecuted, nowJobExecuted, lateJobExecuted} + + // this is needed for correct time-advancement sequence + + require.Less(earlyJob.ExecuteAt, clock.Now()) + require.Equal(nowJob.ExecuteAt, clock.Now()) + + require.Less(earlyJob.Period, nowJob.Period) + require.Less(nowJob.Period, lateJob.Period) + require.Less(lateJob.Period, timeout-epsilon) + + // *** mock & executionSequence setup + + numberOfFullCycles := 4 // number of how many times each job will be executed + require.Greater(numberOfFullCycles, 1) // at least more than initial startOnce execution + + type executionStep struct { + time time.Time + jobs []Job + initialTimer bool + } + executionSequence := []executionStep{} + + // main goroutine + storageSession := &dummySession{} + storage.EXPECT().NewSession(ctx).Return(storageSession, nil) + storage.EXPECT().GetAllJobs(ctx, storageSession).Return(jobs, nil) + storage.EXPECT().Abort(storageSession) + + // startOnce and periodic start goroutines + + // its clock.Now().Add(-1), but we need real execution time for next mock setup steps + // it will be corrected after + earlyJob.ExecuteAt = clock.Now() // real execution time + + for i := 0; i < numberOfFullCycles; i++ { + for _, originalJob := range jobs { + currentJob := Job{ + Name: originalJob.Name, + ExecuteAt: originalJob.ExecuteAt.Add(originalJob.Period * time.Duration(i)), + Period: originalJob.Period, + } + + newJob := &Job{ + Name: originalJob.Name, + ExecuteAt: currentJob.ExecuteAt.Add(originalJob.Period), + Period: originalJob.Period, + } + + if len(executionSequence) == 0 || executionSequence[len(executionSequence)-1].time != currentJob.ExecuteAt { + executionSequence = append(executionSequence, executionStep{ + time: currentJob.ExecuteAt, + jobs: []Job{currentJob}, + initialTimer: i == 0, + }) + } else { + executionSequence[len(executionSequence)-1].jobs = append(executionSequence[len(executionSequence)-1].jobs, currentJob) + } + + storageSession := &dummySession{} + storage.EXPECT().NewSession(ctx).Return(storageSession, nil) + storage.EXPECT().GetJobByName(ctx, storageSession, currentJob.Name).Return(¤tJob, nil) + storage.EXPECT().UpsertJob(ctx, storageSession, newJob).Return(nil) + storage.EXPECT().Commit(storageSession).Return(nil) + storage.EXPECT().Abort(storageSession) + } + } + + // correct earlyJob.ExecuteAt + earlyJob.ExecuteAt = clock.Now().Add(-1) + + // *** scheduler + + sch := New(zap.NewNop().Sugar(), storage, clock).(*scheduler) + sch.RegisterJobHandler(earlyJob.Name, func() { + earlyJobExecuted <- earlyJob.Name + " executed" + }) + sch.RegisterJobHandler(nowJob.Name, func() { + nowJobExecuted <- nowJob.Name + " executed" + }) + sch.RegisterJobHandler(lateJob.Name, func() { + lateJobExecuted <- lateJob.Name + " executed" + }) + + // *** test + + require.NoError(sch.Start(ctx)) + require.Len(sch.timers, len(jobs)) + + // test that jobs are executed in correct order and time + + for _, step := range executionSequence { + jobNames := make([]string, len(step.jobs)) + for jobIndex, job := range step.jobs { + jobNames[jobIndex] = job.Name + } + + // advancing time to the next expected execution time + clock.Advance(step.time.Sub(clock.Now())) // first execution step will advance time by 0 + require.Equal(step.time, clock.Now()) + + // check that all expected jobs are executed + jobsExecuteChans := make([]chan string, len(step.jobs)) + for jobIndex, job := range step.jobs { + jobsExecuteChans[jobIndex] = jobsExecChansMap[job.Name] + } + + _, ok := waitForAllChannels(jobsExecuteChans, timeout) + require.True(ok, "some jobs weren't executed within timeout") + + // if its first step for this timers, means that + // those timers will be stopped after and replaced with tickers + // we need to make sure that tickers are started before advancing time on next step + if step.initialTimer { + conditions := make([]func() bool, len(step.jobs)) + for jobIndex, job := range step.jobs { + conditions[jobIndex] = func() bool { + jobTimer, ok := sch.getJobTimer(job.Name) + require.True(ok) + _, ok = jobTimer.(clockwork.Ticker) + return ok + } + } + allTimersRearmed := waitForAllConditions(conditions, epsilon, timeout) + require.True(allTimersRearmed, "some timers weren't rearmed within timeout") + } + + require.Equal(step.time, clock.Now()) + } + + require.NoError(sch.Stop()) + + // checking, that all timers were stopped + + maxPeriod := time.Duration(0) + for _, job := range jobs { + if job.Period > maxPeriod { + maxPeriod = job.Period + } + } + + clock.Advance(maxPeriod) + + caseIndex, _, _ := waitForOneChannel(jobsExecChans, timeout) + require.Equal(-1, caseIndex, "some jobs were executed after scheduler and job timers were stopped") +} + +func TestScheduler_RegisterJobHandler(t *testing.T) { + require := require.New(t) + logger := zap.NewNop().Sugar() + clock := clockwork.NewFakeClock() + ctrl := gomock.NewController(t) + storage := NewMockStorage(ctrl) + jobExecuted := "" + jobName1 := "job1" + jobName2 := "job2" + jobHandler1 := func() { jobExecuted = jobName1 } + jobHandler2 := func() { jobExecuted = jobName2 } + + checkJobHandlerRegistered := func(sch *scheduler, jobName string) { + t.Helper() + require.Empty(jobExecuted) + sch.registry[jobName]() + require.Equal(jobName, jobExecuted) + jobExecuted = "" + } + + sch := New(logger, storage, clock).(*scheduler) + + require.Empty(sch.registry) + + // we cannot compare full scheduler structure, because it contains funcs map. Funcs cannot be compared with require.Equal + // this can be changed in the future, if testify will support something like args for ignoring certain fields + // so, we'll only check registered job handlers map by calling handlers + + sch.RegisterJobHandler(jobName1, jobHandler1) + require.Len(sch.registry, 1) + checkJobHandlerRegistered(sch, jobName1) + + sch.RegisterJobHandler(jobName2, jobHandler2) + require.Len(sch.registry, 2) + checkJobHandlerRegistered(sch, jobName1) + checkJobHandlerRegistered(sch, jobName2) +} + +func TestScheduler_Schedule(t *testing.T) { + type testCase struct { + storage func(context.Context, *gomock.Controller, clockwork.Clock, *testCase) Storage + existingJob *Job + jobName string + period time.Duration + expectedErr error + } + + tests := map[string]testCase{ + "OK: New job": { + storage: func(ctx context.Context, ctrl *gomock.Controller, clock clockwork.Clock, tt *testCase) Storage { + storage := NewMockStorage(ctrl) + storageSession := &dummySession{} + storage.EXPECT().NewSession(ctx).Return(storageSession, nil) + storage.EXPECT().GetJobByName(ctx, storageSession, tt.jobName).Return(nil, ErrNotFound) + storage.EXPECT().UpsertJob(ctx, storageSession, &Job{ + Name: tt.jobName, + ExecuteAt: clock.Now().Add(tt.period), + Period: tt.period, + }).Return(nil) + storage.EXPECT().Commit(storageSession).Return(nil) + storage.EXPECT().Abort(storageSession) + return storage + }, + jobName: "new_job", + period: 10 * time.Second, + }, + "OK: Existing job": { + storage: func(ctx context.Context, ctrl *gomock.Controller, _ clockwork.Clock, tt *testCase) Storage { + storage := NewMockStorage(ctrl) + storageSession := &dummySession{} + storage.EXPECT().NewSession(ctx).Return(storageSession, nil) + storage.EXPECT().GetJobByName(ctx, storageSession, tt.jobName).Return(tt.existingJob, nil) + storage.EXPECT().UpsertJob(ctx, storageSession, &Job{ + Name: tt.jobName, + ExecuteAt: tt.existingJob.ExecuteAt, + Period: tt.period, + }).Return(nil) + storage.EXPECT().Commit(storageSession).Return(nil) + storage.EXPECT().Abort(storageSession) + return storage + }, + existingJob: &Job{ + Name: "existing_job", + ExecuteAt: time.Now(), + Period: 10 * time.Second, + }, + jobName: "existing_job", + period: 15 * time.Second, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + + sch := New( + zap.NewNop().Sugar(), + tt.storage(ctx, gomock.NewController(t), clock, &tt), + clock, + ).(*scheduler) + + err := sch.Schedule(ctx, tt.period, tt.jobName) + require.ErrorIs(t, err, tt.expectedErr) + }) + } +} + +type dummySession struct{} + +func (d *dummySession) Commit() error { + return nil +} + +func (d *dummySession) Abort() error { + return nil +} + +func waitForOneChannel[T any](chans []chan T, timeout time.Duration) (chanIndex int, receivedValue T, wasClosed bool) { + selectCases := make([]reflect.SelectCase, len(chans)+1) + for i, ch := range chans { + selectCases[i] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ch), + } + } + selectCases[len(chans)] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(time.After(timeout)), + } + + caseIndex, value, wasClosed := reflect.Select(selectCases) + if caseIndex == len(chans) { + var zeroValue T + return -1, zeroValue, wasClosed + } + return caseIndex, value.Interface().(T), wasClosed +} + +func waitForAllChannels[T any](chans []chan T, timeout time.Duration) ([]T, bool) { + values := make([]T, len(chans)) + timeoutTimer := time.NewTimer(timeout) + + for i := range chans { + select { + case <-timeoutTimer.C: + return values, false + case value := <-chans[i]: + values[i] = value + } + } + return values, true +} + +// assumes that each condition will take negligible time +func waitForAllConditions(conditions []func() bool, checkPeriod, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + ticker := time.NewTicker(checkPeriod) + defer ticker.Stop() + + for { + allTrue := true + for _, condition := range conditions { + if !condition() { + allTrue = false + break + } + } + if allTrue { + return true + } + if time.Now().After(deadline) { + return false + } + <-ticker.C + } +} diff --git a/internal/storage/jobs.go b/pkg/scheduler/storage/sqlite/jobs.go similarity index 53% rename from internal/storage/jobs.go rename to pkg/scheduler/storage/sqlite/jobs.go index 021e7251..cb625cda 100644 --- a/internal/storage/jobs.go +++ b/pkg/scheduler/storage/sqlite/jobs.go @@ -1,4 +1,4 @@ -package storage +package sqlite import ( "context" @@ -7,19 +7,13 @@ import ( "fmt" "time" - "github.com/chain4travel/camino-messenger-bot/internal/models" + "github.com/chain4travel/camino-messenger-bot/pkg/scheduler" "github.com/jmoiron/sqlx" ) const jobsTableName = "jobs" -var _ JobsStorage = (*session)(nil) - -type JobsStorage interface { - GetAllJobs(ctx context.Context) ([]*models.Job, error) - UpsertJob(ctx context.Context, job *models.Job) error - GetJobByName(ctx context.Context, jobName string) (*models.Job, error) -} +var _ scheduler.Storage = (*storage)(nil) type job struct { Name string `db:"name"` @@ -27,26 +21,38 @@ type job struct { Period int64 `db:"period"` } -func (s *session) GetJobByName(ctx context.Context, jobName string) (*models.Job, error) { +func (s *storage) GetJobByName(ctx context.Context, session scheduler.Session, jobName string) (*scheduler.Job, error) { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return nil, err + } + job := &job{} - if err := s.tx.StmtxContext(ctx, s.storage.getJobByName).GetContext(ctx, job, jobName); err != nil { + if err := tx.StmtxContext(ctx, s.getJobByName).GetContext(ctx, job, jobName); err != nil { if !errors.Is(err, sql.ErrNoRows) { - s.logger.Error(err) + s.base.Logger.Error(err) } return nil, upgradeError(err) } return modelFromJob(job), nil } -func (s *session) UpsertJob(ctx context.Context, job *models.Job) error { - result, err := s.tx.NamedStmtContext(ctx, s.storage.upsertJob). +func (s *storage) UpsertJob(ctx context.Context, session scheduler.Session, job *scheduler.Job) error { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return err + } + + result, err := tx.NamedStmtContext(ctx, s.upsertJob). ExecContext(ctx, jobFromModel(job)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return upgradeError(err) } if rowsAffected, err := result.RowsAffected(); err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return upgradeError(err) } else if rowsAffected != 1 { return fmt.Errorf("failed to add cheque: expected to affect 1 row, but affected %d", rowsAffected) @@ -54,17 +60,23 @@ func (s *session) UpsertJob(ctx context.Context, job *models.Job) error { return nil } -func (s *session) GetAllJobs(ctx context.Context) ([]*models.Job, error) { - jobs := []*models.Job{} - rows, err := s.tx.StmtxContext(ctx, s.storage.getAllJobs).QueryxContext(ctx) +func (s *storage) GetAllJobs(ctx context.Context, session scheduler.Session) ([]*scheduler.Job, error) { + tx, err := getSQLXTx(session) + if err != nil { + s.base.Logger.Error(err) + return nil, err + } + + jobs := []*scheduler.Job{} + rows, err := tx.StmtxContext(ctx, s.getAllJobs).QueryxContext(ctx) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return nil, upgradeError(err) } for rows.Next() { job := &job{} if err := rows.StructScan(job); err != nil { - s.logger.Errorf("failed to get not cashed cheque from db: %v", err) + s.base.Logger.Errorf("failed to get not cashed cheque from db: %v", err) continue } jobs = append(jobs, modelFromJob(job)) @@ -78,17 +90,17 @@ type jobsStatements struct { } func (s *storage) prepareJobsStmts(ctx context.Context) error { - getJobByName, err := s.db.PreparexContext(ctx, fmt.Sprintf(` + getJobByName, err := s.base.DB.PreparexContext(ctx, fmt.Sprintf(` SELECT * FROM %s WHERE name = ? `, jobsTableName)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.getJobByName = getJobByName - upsertJob, err := s.db.PrepareNamedContext(ctx, fmt.Sprintf(` + upsertJob, err := s.base.DB.PrepareNamedContext(ctx, fmt.Sprintf(` INSERT INTO %s ( name, execute_at, @@ -102,16 +114,16 @@ func (s *storage) prepareJobsStmts(ctx context.Context) error { DO UPDATE SET period = excluded.period `, jobsTableName)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.upsertJob = upsertJob - getAllJobs, err := s.db.PreparexContext(ctx, fmt.Sprintf(` + getAllJobs, err := s.base.DB.PreparexContext(ctx, fmt.Sprintf(` SELECT * FROM %s `, jobsTableName)) if err != nil { - s.logger.Error(err) + s.base.Logger.Error(err) return err } s.getAllJobs = getAllJobs @@ -119,15 +131,15 @@ func (s *storage) prepareJobsStmts(ctx context.Context) error { return nil } -func modelFromJob(job *job) *models.Job { - return &models.Job{ +func modelFromJob(job *job) *scheduler.Job { + return &scheduler.Job{ Name: job.Name, ExecuteAt: time.Unix(job.ExecuteAt, 0), Period: time.Duration(job.Period) * time.Second, } } -func jobFromModel(model *models.Job) *job { +func jobFromModel(model *scheduler.Job) *job { return &job{ Name: model.Name, ExecuteAt: model.ExecuteAt.Unix(), diff --git a/pkg/scheduler/storage/sqlite/storage.go b/pkg/scheduler/storage/sqlite/storage.go new file mode 100644 index 00000000..51d2c85d --- /dev/null +++ b/pkg/scheduler/storage/sqlite/storage.go @@ -0,0 +1,85 @@ +package sqlite + +import ( + "context" + "database/sql" + "errors" + + "github.com/chain4travel/camino-messenger-bot/pkg/database" + "github.com/chain4travel/camino-messenger-bot/pkg/database/sqlite" + "github.com/chain4travel/camino-messenger-bot/pkg/scheduler" + _ "github.com/golang-migrate/migrate/v4/source/file" // required by migrate + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" // sql driver, required + "go.uber.org/zap" +) + +const dbName = "scheduler" + +var ( + _ Storage = (*storage)(nil) + _ scheduler.Session = (*sqlite.SQLxTxSession)(nil) + _ scheduler.SessionHandler = (*storage)(nil) +) + +type Storage interface { + Close() error + + scheduler.Storage +} + +func New(ctx context.Context, logger *zap.SugaredLogger, cfg sqlite.DBConfig) (Storage, error) { + baseDB, err := sqlite.New(logger, cfg, dbName) + if err != nil { + return nil, err + } + + s := &storage{base: baseDB} + + if err := s.prepare(ctx); err != nil { + return nil, err + } + + return s, nil +} + +type storage struct { + base *sqlite.DB + + jobsStatements +} + +func (s *storage) Close() error { + return s.base.Close() +} + +func (s *storage) prepare(ctx context.Context) error { + return s.prepareJobsStmts(ctx) +} + +func (s *storage) NewSession(ctx context.Context) (scheduler.Session, error) { + return s.base.NewSession(ctx) +} + +func (s *storage) Commit(session scheduler.Session) error { + return s.base.Commit(session) +} + +func (s *storage) Abort(session scheduler.Session) { + s.base.Abort(session) +} + +func getSQLXTx(session scheduler.Session) (*sqlx.Tx, error) { + s, ok := session.(sqlite.SQLxTxer) + if !ok { + return nil, sqlite.ErrUnexpectedSessionType + } + return s.SQLxTx(), nil +} + +func upgradeError(err error) error { + if errors.Is(err, sql.ErrNoRows) { + return database.ErrNotFound + } + return err +} diff --git a/scripts/mock.gen.sh b/scripts/mock.gen.sh index 8491b1c0..6216a52a 100755 --- a/scripts/mock.gen.sh +++ b/scripts/mock.gen.sh @@ -13,17 +13,22 @@ then go install -v go.uber.org/mock/mockgen@v0.4.0 fi -# tuples of (source interface import path, comma-separated interface names, output file path) +# tuples of (source interface import path, comma-separated interface names, output file path) with optional 4th argument for package name input="scripts/mocks.mockgen.txt" while IFS= read -r line do - IFS='=' read src_import_path interface_name output_path <<< "${line}" - package_name=$(basename "$(dirname $output_path)") + IFS='=' read src_import_path interface_name output_path package_name <<< "${line}" + # If package_name is not provided, use the basename of the directory containing the output file + if [[ -z "$package_name" ]]; then + package_name=$(basename "$(dirname "$output_path")") + fi + [[ $src_import_path == \#* ]] && continue echo "Generating ${output_path}..." mockgen -package=${package_name} -destination=${output_path} ${src_import_path} ${interface_name} done < "$input" + # tuples of (source import path, comma-separated interface names to exclude, output file path) input="scripts/mocks.mockgen.source.txt" while IFS= read -r line diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt index 82ba3a72..3d1de3cc 100644 --- a/scripts/mocks.mockgen.txt +++ b/scripts/mocks.mockgen.txt @@ -2,3 +2,7 @@ github.com/chain4travel/camino-messenger-bot/internal/compression=Decompressor=i github.com/chain4travel/camino-messenger-bot/internal/matrix=Client=internal/matrix/mock_client.go github.com/chain4travel/camino-messenger-bot/internal/messaging=Messenger=internal/messaging/mock_messenger.go github.com/chain4travel/camino-messenger-bot/internal/messaging=ServiceRegistry=internal/messaging/mock_service_registry.go +github.com/chain4travel/camino-messenger-bot/pkg/cm_accounts=Service=pkg/cm_accounts/mock_cm_accounts.go=cmaccounts +github.com/chain4travel/camino-messenger-bot/pkg/scheduler=Storage=pkg/scheduler/mock_storage.go +github.com/chain4travel/camino-messenger-bot/pkg/chequehandler=ChequeHandler=pkg/chequehandler/mock_cheque_handler.go=chequehandler +github.com/chain4travel/camino-messenger-bot/pkg/chequehandler=Storage=pkg/chequehandler/mock_storage.go=chequehandler