Skip to content

Commit

Permalink
go/worker/txnscheduler: Check txns before queuing them
Browse files Browse the repository at this point in the history
  • Loading branch information
abukosek committed Jan 13, 2020
1 parent 46c3e9d commit 2521985
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 64 deletions.
1 change: 1 addition & 0 deletions go/oasis-node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ var (
{workerCommon.CfgClientPort, workerClientPort},
{storageWorker.CfgWorkerEnabled, true},
{txnscheduler.CfgWorkerEnabled, true},
{txnscheduler.CfgCheckTxEnabled, false},
{mergeWorker.CfgWorkerEnabled, true},
{supplementarysanity.CfgEnabled, true},
{supplementarysanity.CfgInterval, 1},
Expand Down
9 changes: 9 additions & 0 deletions go/worker/common/runtime_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
keymanagerApi "github.com/oasislabs/oasis-core/go/keymanager/api"
keymanagerClient "github.com/oasislabs/oasis-core/go/keymanager/client"
registry "github.com/oasislabs/oasis-core/go/registry/api"
roothash "github.com/oasislabs/oasis-core/go/roothash/api/block"
"github.com/oasislabs/oasis-core/go/runtime/localstorage"
storage "github.com/oasislabs/oasis-core/go/storage/api"
"github.com/oasislabs/oasis-core/go/worker/common/committee"
Expand Down Expand Up @@ -268,6 +269,14 @@ func (n *RuntimeHostNode) GetWorkerHostLocked() host.Host {
return n.workerHost
}

// GetCurrentBlock returns the current roothash block from the underlying common node.
func (n *RuntimeHostNode) GetCurrentBlock() roothash.Block {
n.commonNode.CrossNode.Lock()
defer n.commonNode.CrossNode.Unlock()

return *n.commonNode.CurrentBlock
}

// NewRuntimeHostNode creates a new runtime host node.
func NewRuntimeHostNode(commonNode *committee.Node, workerHostFactory host.Factory) *RuntimeHostNode {
return &RuntimeHostNode{
Expand Down
3 changes: 3 additions & 0 deletions go/worker/txnscheduler/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ var (
// ErrNotReady is the error returned when the transaction scheduler is not
// yet ready to process transactions.
ErrNotReady = errors.New(ModuleName, 3, "txnscheduler: not ready")

// ErrCheckTxFailed is the error returned when CheckTx fails.
ErrCheckTxFailed = errors.New(ModuleName, 4, "txnscheduler: CheckTx failed")
)

// TransactionScheduler is the transaction scheduler API interface.
Expand Down
22 changes: 14 additions & 8 deletions go/worker/txnscheduler/committee/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ var (
)

// Node is a committee node.
type Node struct {
type Node struct { // nolint: maligned
*commonWorker.RuntimeHostNode

checkTxEnabled bool

commonNode *committee.Node
computeNode *computeCommittee.Node

Expand Down Expand Up @@ -389,14 +391,16 @@ func (n *Node) worker() {

n.logger.Info("starting committee node")

// Initialize worker host for the new runtime.
if err := n.InitializeRuntimeWorkerHost(n.ctx); err != nil {
n.logger.Error("failed to initialize worker host",
"err", err,
)
return
if n.checkTxEnabled {
// Initialize worker host for the new runtime.
if err := n.InitializeRuntimeWorkerHost(n.ctx); err != nil {
n.logger.Error("failed to initialize worker host",
"err", err,
)
return
}
defer n.StopRuntimeWorkerHost()
}
defer n.StopRuntimeWorkerHost()

// Initialize transaction scheduler's algorithm.
runtime, err := n.commonNode.Runtime.RegistryDescriptor(n.ctx)
Expand Down Expand Up @@ -451,6 +455,7 @@ func NewNode(
commonNode *committee.Node,
computeNode *computeCommittee.Node,
workerHostFactory host.Factory,
checkTxEnabled bool,
) (*Node, error) {
metricsOnce.Do(func() {
prometheus.MustRegister(nodeCollectors...)
Expand All @@ -460,6 +465,7 @@ func NewNode(

n := &Node{
RuntimeHostNode: commonWorker.NewRuntimeHostNode(commonNode, workerHostFactory),
checkTxEnabled: checkTxEnabled,
commonNode: commonNode,
computeNode: computeNode,
ctx: ctx,
Expand Down
10 changes: 9 additions & 1 deletion go/worker/txnscheduler/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
const (
// CfgWorkerEnabled enables the tx scheduler worker.
CfgWorkerEnabled = "worker.txn_scheduler.enabled"
// CfgCheckTxEnabled enables checking each transaction before scheduling it.
CfgCheckTxEnabled = "worker.txn_scheduler.check_tx.enabled"
)

// Flags has the configuration flags.
Expand All @@ -23,17 +25,23 @@ func Enabled() bool {
return viper.GetBool(CfgWorkerEnabled)
}

// CheckTxEnabled reads our CheckTx enabled flag from viper.
func CheckTxEnabled() bool {
return viper.GetBool(CfgCheckTxEnabled)
}

// New creates a new worker.
func New(
commonWorker *workerCommon.Worker,
compute *compute.Worker,
registration *registration.Worker,
) (*Worker, error) {
return newWorker(Enabled(), commonWorker, compute, registration)
return newWorker(Enabled(), commonWorker, compute, registration, CheckTxEnabled())
}

func init() {
Flags.Bool(CfgWorkerEnabled, false, "Enable transaction scheduler process")
Flags.Bool(CfgCheckTxEnabled, true, "Enable checking transactions before scheduling them")

_ = viper.BindPFlags(Flags)

Expand Down
51 changes: 51 additions & 0 deletions go/worker/txnscheduler/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package txnscheduler

import (
"context"
"fmt"

"github.com/oasislabs/oasis-core/go/common/cbor"
"github.com/oasislabs/oasis-core/go/runtime/transaction"
"github.com/oasislabs/oasis-core/go/worker/common/host/protocol"
"github.com/oasislabs/oasis-core/go/worker/txnscheduler/api"
)

Expand All @@ -15,6 +19,53 @@ func (w *Worker) SubmitTx(ctx context.Context, rq *api.SubmitTxRequest) (*api.Su
return nil, api.ErrUnknownRuntime
}

if w.checkTxEnabled {
// Check transaction before queuing it.
checkRq := &protocol.Body{
WorkerCheckTxBatchRequest: &protocol.WorkerCheckTxBatchRequest{
Inputs: transaction.RawBatch{rq.Data},
Block: runtime.GetCurrentBlock(),
},
}
workerHost := runtime.GetWorkerHost()
if workerHost == nil {
w.logger.Error("worker host not initialized")
return nil, api.ErrNotReady
}
resp, err := workerHost.Call(ctx, checkRq)
if err != nil {
w.logger.Error("worker host CheckTx call error",
"err", err,
)
return nil, err
}
if resp == nil {
w.logger.Error("worker host CheckTx reponse is nil")
return nil, api.ErrCheckTxFailed
}
if resp.WorkerCheckTxBatchResponse.Results == nil {
w.logger.Error("worker host CheckTx response contains no results")
return nil, api.ErrCheckTxFailed
}
if len(resp.WorkerCheckTxBatchResponse.Results) != 1 {
w.logger.Error("worker host CheckTx response doesn't contain exactly one result",
"num_results", len(resp.WorkerCheckTxBatchResponse.Results),
)
return nil, api.ErrCheckTxFailed
}

// Interpret CheckTx result.
resultRaw := resp.WorkerCheckTxBatchResponse.Results[0]
var result transaction.TxnOutput
cbor.MustUnmarshal(resultRaw, &result)
if result.Error != nil {
w.logger.Error("worker CheckTx failed with error",
"err", result.Error,
)
return nil, fmt.Errorf("%w: %s", api.ErrCheckTxFailed, *result.Error)
}
}

if err := runtime.QueueCall(ctx, rq.Data); err != nil {
return nil, err
}
Expand Down
80 changes: 25 additions & 55 deletions go/worker/txnscheduler/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import (
type Worker struct {
*workerCommon.RuntimeHostWorker

enabled bool
enabled bool
checkTxEnabled bool

commonWorker *workerCommon.Worker
registration *registration.Worker
Expand Down Expand Up @@ -152,7 +153,7 @@ func (w *Worker) registerRuntime(commonNode *committeeCommon.Node) error {
}

// Create committee node for the given runtime.
node, err := committee.NewNode(commonNode, computeNode, workerHostFactory)
node, err := committee.NewNode(commonNode, computeNode, workerHostFactory, w.checkTxEnabled)
if err != nil {
return err
}
Expand All @@ -172,29 +173,32 @@ func newWorker(
commonWorker *workerCommon.Worker,
compute *compute.Worker,
registration *registration.Worker,
checkTxEnabled bool,
) (*Worker, error) {
ctx, cancelCtx := context.WithCancel(context.Background())

w := &Worker{
enabled: enabled,
commonWorker: commonWorker,
registration: registration,
compute: compute,
runtimes: make(map[common.Namespace]*committee.Node),
ctx: ctx,
cancelCtx: cancelCtx,
quitCh: make(chan struct{}),
initCh: make(chan struct{}),
logger: logging.GetLogger("worker/txnscheduler"),
enabled: enabled,
checkTxEnabled: checkTxEnabled,
commonWorker: commonWorker,
registration: registration,
compute: compute,
runtimes: make(map[common.Namespace]*committee.Node),
ctx: ctx,
cancelCtx: cancelCtx,
quitCh: make(chan struct{}),
initCh: make(chan struct{}),
logger: logging.GetLogger("worker/txnscheduler"),
}

if enabled {
if !w.commonWorker.Enabled() {
panic("common worker should have been enabled for transaction scheduler")
}

// Create the runtime host worker.
var err error

// Create the runtime host worker.
w.RuntimeHostWorker, err = workerCommon.NewRuntimeHostWorker(commonWorker)
if err != nil {
return nil, err
Expand All @@ -212,48 +216,14 @@ func newWorker(

// Register transaction scheduler worker role.
if err = w.registration.RegisterRole(node.RoleTransactionScheduler, func(n *node.Node) error {
// Wait until all the runtimes are initialized.
for _, rt := range w.runtimes {
select {
case <-rt.Initialized():
case <-w.ctx.Done():
return w.ctx.Err()
}
}

for _, rt := range n.Runtimes {
var grr error

workerRT := w.runtimes[rt.ID]
if workerRT == nil {
continue
}

workerHost := workerRT.GetWorkerHost()
if workerHost == nil {
w.logger.Debug("runtime has shut down",
"runtime", rt.ID,
)
continue
}
if rt.Capabilities.TEE, grr = workerHost.WaitForCapabilityTEE(w.ctx); grr != nil {
w.logger.Error("failed to obtain CapabilityTEE",
"err", grr,
"runtime", rt.ID,
)
continue
}

runtimeVersion, grr := workerHost.WaitForRuntimeVersion(w.ctx)
if grr == nil && runtimeVersion != nil {
rt.Version = *runtimeVersion
} else {
w.logger.Error("failed to obtain RuntimeVersion",
"err", grr,
"runtime", rt.ID,
"runtime_version", runtimeVersion,
)
continue
if w.checkTxEnabled {
// Wait until all the runtimes are initialized.
for _, rt := range w.runtimes {
select {
case <-rt.Initialized():
case <-w.ctx.Done():
return w.ctx.Err()
}
}
}

Expand Down

0 comments on commit 2521985

Please sign in to comment.