Skip to content

Commit

Permalink
make deal state channel id nilable (#490)
Browse files Browse the repository at this point in the history
* fix: make client deal state channel id nilable

* fix: make provider deal state channel id nilable
  • Loading branch information
dirkmc authored Feb 24, 2021
1 parent f24d924 commit d1405ef
Show file tree
Hide file tree
Showing 18 changed files with 1,643 additions and 87 deletions.
3 changes: 1 addition & 2 deletions retrievalmarket/impl/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/filecoin-project/go-address"
datatransfer "github.com/filecoin-project/go-data-transfer"
versioning "github.com/filecoin-project/go-ds-versioning/pkg"
versionedfsm "github.com/filecoin-project/go-ds-versioning/pkg/fsm"
"github.com/filecoin-project/go-multistore"
"github.com/filecoin-project/go-state-types/abi"
Expand Down Expand Up @@ -99,7 +98,7 @@ func NewClient(
StateEntryFuncs: clientstates.ClientStateEntryFuncs,
FinalityStates: clientstates.ClientFinalityStates,
Notifier: c.notifySubscribers,
}, retrievalMigrations, versioning.VersionKey("1"))
}, retrievalMigrations, "2")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion retrievalmarket/impl/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ func TestMigrations(t *testing.T) {
},
},
StoreID: storeIDs[i],
ChannelID: channelIDs[i],
ChannelID: &channelIDs[i],
LastPaymentRequested: lastPaymentRequesteds[i],
AllBlocksReceived: allBlocksReceiveds[i],
TotalFunds: totalFundss[i],
Expand Down
2 changes: 1 addition & 1 deletion retrievalmarket/impl/clientstates/client_fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var ClientEvents = fsm.Events{
From(rm.DealStatusNew).To(rm.DealStatusWaitForAcceptance).
From(rm.DealStatusRetryLegacy).To(rm.DealStatusWaitForAcceptanceLegacy).
Action(func(deal *rm.ClientDealState, channelID datatransfer.ChannelID) error {
deal.ChannelID = channelID
deal.ChannelID = &channelID
deal.Message = ""
return nil
}),
Expand Down
13 changes: 8 additions & 5 deletions retrievalmarket/impl/clientstates/client_states.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func SendFunds(ctx fsm.Context, environment ClientDealEnvironment, deal rm.Clien
}

// send payment voucher (or fail)
err = environment.SendDataTransferVoucher(ctx.Context(), deal.ChannelID, &rm.DealPayment{
err = environment.SendDataTransferVoucher(ctx.Context(), *deal.ChannelID, &rm.DealPayment{
ID: deal.DealProposal.ID,
PaymentChannel: deal.PaymentInfo.PayCh,
PaymentVoucher: voucher,
Expand Down Expand Up @@ -164,10 +164,13 @@ func CheckFunds(ctx fsm.Context, environment ClientDealEnvironment, deal rm.Clie

// CancelDeal clears a deal that went wrong for an unknown reason
func CancelDeal(ctx fsm.Context, environment ClientDealEnvironment, deal rm.ClientDealState) error {
// Read next response (or fail)
err := environment.CloseDataTransfer(ctx.Context(), deal.ChannelID)
if err != nil {
return ctx.Trigger(rm.ClientEventDataTransferError, err)
// If the data transfer has started, cancel it
if deal.ChannelID != nil {
// Read next response (or fail)
err := environment.CloseDataTransfer(ctx.Context(), *deal.ChannelID)
if err != nil {
return ctx.Trigger(rm.ClientEventDataTransferError, err)
}
}

return ctx.Trigger(rm.ClientEventCancelComplete)
Expand Down
10 changes: 10 additions & 0 deletions retrievalmarket/impl/clientstates/client_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,11 @@ func TestSendFunds(t *testing.T) {
node := testnodes.NewTestRetrievalClientNode(nodeParams)
environment := &fakeEnvironment{node, nil, sendDataTransferVoucherError, nil}
fsmCtx := fsmtest.NewTestContext(ctx, eventMachine)
dealState.ChannelID = &datatransfer.ChannelID{
Initiator: "initiator",
Responder: dealState.Sender,
ID: 1,
}
err := clientstates.SendFunds(fsmCtx, environment, *dealState)
require.NoError(t, err)
fsmCtx.ReplayEvents(t, dealState)
Expand Down Expand Up @@ -527,6 +532,11 @@ func TestCancelDeal(t *testing.T) {
node := testnodes.NewTestRetrievalClientNode(testnodes.TestRetrievalClientNodeParams{})
environment := &fakeEnvironment{node, nil, nil, closeError}
fsmCtx := fsmtest.NewTestContext(ctx, eventMachine)
dealState.ChannelID = &datatransfer.ChannelID{
Initiator: "initiator",
Responder: dealState.Sender,
ID: 1,
}
err := clientstates.CancelDeal(fsmCtx, environment, *dealState)
require.NoError(t, err)
fsmCtx.ReplayEvents(t, dealState)
Expand Down
2 changes: 1 addition & 1 deletion retrievalmarket/impl/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ func TestProviderMigrations(t *testing.T) {
},
},
StoreID: storeIDs[i],
ChannelID: channelIDs[i],
ChannelID: &channelIDs[i],
PieceInfo: &piecestore.PieceInfo{
PieceCID: *pieceCIDs[i],
Deals: []piecestore.DealInfo{
Expand Down
2 changes: 1 addition & 1 deletion retrievalmarket/impl/providerstates/provider_fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var ProviderEvents = fsm.Events{
From(rm.DealStatusFundsNeededUnseal).ToNoChange().
From(rm.DealStatusNew).To(rm.DealStatusUnsealing).
Action(func(deal *rm.ProviderDealState, channelID datatransfer.ChannelID) error {
deal.ChannelID = channelID
deal.ChannelID = &channelID
return nil
}),

Expand Down
16 changes: 10 additions & 6 deletions retrievalmarket/impl/providerstates/provider_states.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ func UnpauseDeal(ctx fsm.Context, environment ProviderDealEnvironment, deal rm.P
if err != nil {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
}
err = environment.ResumeDataTransfer(ctx.Context(), deal.ChannelID)
if err != nil {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
if deal.ChannelID != nil {
err = environment.ResumeDataTransfer(ctx.Context(), *deal.ChannelID)
if err != nil {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
}
}
return nil
}
Expand All @@ -87,9 +89,11 @@ func CancelDeal(ctx fsm.Context, environment ProviderDealEnvironment, deal rm.Pr
if err != nil {
return ctx.Trigger(rm.ProviderEventMultiStoreError, err)
}
err = environment.CloseDataTransfer(ctx.Context(), deal.ChannelID)
if err != nil && !errors.Is(err, statemachine.ErrTerminated) {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
if deal.ChannelID != nil {
err = environment.CloseDataTransfer(ctx.Context(), *deal.ChannelID)
if err != nil && !errors.Is(err, statemachine.ErrTerminated) {
return ctx.Trigger(rm.ProviderEventDataTransferError, err)
}
}
return ctx.Trigger(rm.ProviderEventCancelComplete)
}
Expand Down
11 changes: 11 additions & 0 deletions retrievalmarket/impl/providerstates/provider_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/require"

datatransfer "github.com/filecoin-project/go-data-transfer"
"github.com/filecoin-project/go-state-types/abi"
"github.com/filecoin-project/go-state-types/big"
"github.com/filecoin-project/go-statemachine/fsm"
Expand Down Expand Up @@ -112,6 +113,11 @@ func TestUnpauseDeal(t *testing.T) {
environment := rmtesting.NewTestProviderDealEnvironment(node)
setupEnv(environment)
fsmCtx := fsmtest.NewTestContext(ctx, eventMachine)
dealState.ChannelID = &datatransfer.ChannelID{
Initiator: "initiator",
Responder: dealState.Receiver,
ID: 1,
}
err := providerstates.UnpauseDeal(fsmCtx, environment, *dealState)
require.NoError(t, err)
node.VerifyExpectations(t)
Expand Down Expand Up @@ -155,6 +161,11 @@ func TestCancelDeal(t *testing.T) {
environment := rmtesting.NewTestProviderDealEnvironment(node)
setupEnv(environment)
fsmCtx := fsmtest.NewTestContext(ctx, eventMachine)
dealState.ChannelID = &datatransfer.ChannelID{
Initiator: "initiator",
Responder: dealState.Receiver,
ID: 1,
}
err := providerstates.CancelDeal(fsmCtx, environment, *dealState)
require.NoError(t, err)
node.VerifyExpectations(t)
Expand Down
22 changes: 19 additions & 3 deletions retrievalmarket/impl/requestvalidation/revalidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"sync"

logging "github.com/ipfs/go-log/v2"

datatransfer "github.com/filecoin-project/go-data-transfer"
"github.com/filecoin-project/go-state-types/abi"
"github.com/filecoin-project/go-state-types/big"
Expand All @@ -14,6 +16,8 @@ import (
"github.com/filecoin-project/go-fil-markets/retrievalmarket/migrations"
)

var log = logging.Logger("retrieval-revalidator")

// RevalidatorEnvironment are the dependencies needed to
// build the logic of revalidation -- essentially, access to the node at statemachines
type RevalidatorEnvironment interface {
Expand Down Expand Up @@ -52,9 +56,15 @@ func NewProviderRevalidator(env RevalidatorEnvironment) *ProviderRevalidator {
// a given channel ID with a retrieval deal, so that checks run for data sent
// on the channel
func (pr *ProviderRevalidator) TrackChannel(deal rm.ProviderDealState) {
// Sanity check
if deal.ChannelID == nil {
log.Errorf("cannot track deal %s: channel ID is nil", deal.ID)
return
}

pr.trackedChannelsLk.Lock()
defer pr.trackedChannelsLk.Unlock()
pr.trackedChannels[deal.ChannelID] = &channelData{
pr.trackedChannels[*deal.ChannelID] = &channelData{
dealID: deal.Identifier(),
}
pr.writeDealState(deal)
Expand All @@ -63,9 +73,15 @@ func (pr *ProviderRevalidator) TrackChannel(deal rm.ProviderDealState) {
// UntrackChannel indicates a retrieval deal is finish and no longer is tracked
// by this provider
func (pr *ProviderRevalidator) UntrackChannel(deal rm.ProviderDealState) {
// Sanity check
if deal.ChannelID == nil {
log.Errorf("cannot untrack deal %s: channel ID is nil", deal.ID)
return
}

pr.trackedChannelsLk.Lock()
defer pr.trackedChannelsLk.Unlock()
delete(pr.trackedChannels, deal.ChannelID)
delete(pr.trackedChannels, *deal.ChannelID)
}

func (pr *ProviderRevalidator) loadDealState(channel *channelData) error {
Expand All @@ -82,7 +98,7 @@ func (pr *ProviderRevalidator) loadDealState(channel *channelData) error {
}

func (pr *ProviderRevalidator) writeDealState(deal rm.ProviderDealState) {
channel := pr.trackedChannels[deal.ChannelID]
channel := pr.trackedChannels[*deal.ChannelID]
channel.totalSent = deal.TotalSent
if !deal.PricePerByte.IsZero() {
channel.totalPaidFor = big.Div(big.Max(big.Sub(deal.FundsReceived, deal.UnsealPrice), big.Zero()), deal.PricePerByte).Uint64()
Expand Down
35 changes: 18 additions & 17 deletions retrievalmarket/impl/requestvalidation/revalidator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestOnPullDataSent(t *testing.T) {
},
"record block": {
deal: deal,
channelID: deal.ChannelID,
channelID: *deal.ChannelID,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventBlockSent,
expectedArgs: []interface{}{deal.TotalSent + uint64(500)},
Expand All @@ -64,7 +64,7 @@ func TestOnPullDataSent(t *testing.T) {
},
"record block zero price per byte": {
deal: dealZeroPricePerByte,
channelID: dealZeroPricePerByte.ChannelID,
channelID: *dealZeroPricePerByte.ChannelID,
expectedID: dealZeroPricePerByte.Identifier(),
expectedEvent: rm.ProviderEventBlockSent,
expectedArgs: []interface{}{dealZeroPricePerByte.TotalSent + uint64(500)},
Expand All @@ -73,7 +73,7 @@ func TestOnPullDataSent(t *testing.T) {
},
"request payment": {
deal: deal,
channelID: deal.ChannelID,
channelID: *deal.ChannelID,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentRequested,
expectedArgs: []interface{}{deal.TotalSent + defaultCurrentInterval},
Expand All @@ -88,7 +88,7 @@ func TestOnPullDataSent(t *testing.T) {
},
"request payment, legacy": {
deal: legacyDeal,
channelID: legacyDeal.ChannelID,
channelID: *legacyDeal.ChannelID,
expectedID: legacyDeal.Identifier(),
expectedEvent: rm.ProviderEventPaymentRequested,
expectedArgs: []interface{}{legacyDeal.TotalSent + defaultCurrentInterval},
Expand Down Expand Up @@ -140,7 +140,7 @@ func TestOnComplete(t *testing.T) {
dealZeroPricePerByte.PricePerByte = big.Zero()
legacyDeal := deal
legacyDeal.LegacyProtocol = true
channelID := deal.ChannelID
channelID := *deal.ChannelID
testCases := map[string]struct {
expectedEvents []eventSent
deal rm.ProviderDealState
Expand Down Expand Up @@ -296,6 +296,7 @@ func TestRevalidate(t *testing.T) {

deal := *makeDealState(rm.DealStatusFundsNeeded)
deal.TotalSent = defaultTotalSent + defaultCurrentInterval
channelID := *deal.ChannelID
smallerPayment := abi.NewTokenAmount(400000)
payment := &retrievalmarket.DealPayment{
ID: deal.ID,
Expand Down Expand Up @@ -329,7 +330,7 @@ func TestRevalidate(t *testing.T) {
},
"not a payment voucher": {
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
noSend: true,
expectedError: errors.New("wrong voucher type"),
},
Expand All @@ -338,7 +339,7 @@ func TestRevalidate(t *testing.T) {
tn.ChainHeadError = errors.New("something went wrong")
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedError: errors.New("something went wrong"),
expectedID: deal.Identifier(),
Expand All @@ -355,7 +356,7 @@ func TestRevalidate(t *testing.T) {
tn.ChainHeadError = errors.New("something went wrong")
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: legacyPayment,
expectedError: errors.New("something went wrong"),
expectedID: deal.Identifier(),
Expand All @@ -372,7 +373,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, abi.NewTokenAmount(0), errors.New("your money's no good here"))
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedError: errors.New("your money's no good here"),
expectedID: deal.Identifier(),
Expand All @@ -389,7 +390,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, abi.NewTokenAmount(0), errors.New("your money's no good here"))
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: legacyPayment,
expectedError: errors.New("your money's no good here"),
expectedID: deal.Identifier(),
Expand All @@ -406,7 +407,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, smallerPayment, nil)
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedError: datatransfer.ErrPause,
expectedID: deal.Identifier(),
Expand All @@ -423,7 +424,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, smallerPayment, nil)
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: legacyPayment,
expectedError: datatransfer.ErrPause,
expectedID: deal.Identifier(),
Expand All @@ -440,7 +441,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, defaultPaymentPerInterval, nil)
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentReceived,
Expand All @@ -452,7 +453,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, defaultPaymentPerInterval, nil)
},
deal: lastPaymentDeal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentReceived,
Expand All @@ -467,7 +468,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, defaultPaymentPerInterval, nil)
},
deal: lastPaymentDeal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: legacyPayment,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentReceived,
Expand All @@ -482,7 +483,7 @@ func TestRevalidate(t *testing.T) {
_ = tn.ExpectVoucher(payCh, voucher, nil, defaultPaymentPerInterval, big.Zero(), nil)
},
deal: deal,
channelID: deal.ChannelID,
channelID: channelID,
voucher: payment,
expectedID: deal.Identifier(),
expectedEvent: rm.ProviderEventPaymentReceived,
Expand Down Expand Up @@ -565,7 +566,7 @@ func makeDealState(status retrievalmarket.DealStatus) *retrievalmarket.ProviderD
TotalSent: defaultTotalSent,
CurrentInterval: defaultCurrentInterval,
FundsReceived: defaultFundsReceived,
ChannelID: channelID,
ChannelID: &channelID,
Receiver: channelID.Initiator,
DealProposal: retrievalmarket.DealProposal{
ID: dealID,
Expand Down
Loading

0 comments on commit d1405ef

Please sign in to comment.