diff --git a/pkg/networkservice/common/begin/event_factory.go b/pkg/networkservice/common/begin/event_factory.go index 02c25fe62..6a8b00bc7 100644 --- a/pkg/networkservice/common/begin/event_factory.go +++ b/pkg/networkservice/common/begin/event_factory.go @@ -25,7 +25,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/peer" - "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/extend" "github.com/networkservicemesh/sdk/pkg/tools/postpone" @@ -160,16 +159,16 @@ type eventFactoryServer struct { ctxFunc func() (context.Context, context.CancelFunc) request *networkservice.NetworkServiceRequest returnedConnection *networkservice.Connection - closeTimeout time.Duration + contextTimeout time.Duration afterCloseFunc func() server networkservice.NetworkServiceServer } -func newEventFactoryServer(ctx context.Context, closeTimeout time.Duration, afterClose func()) *eventFactoryServer { +func newEventFactoryServer(ctx context.Context, contextTimeout time.Duration, afterClose func()) *eventFactoryServer { f := &eventFactoryServer{ server: next.Server(ctx), initialCtxFunc: postpone.Context(ctx), - closeTimeout: closeTimeout, + contextTimeout: contextTimeout, } f.updateContext(ctx) @@ -207,7 +206,12 @@ func (f *eventFactoryServer) Request(opts ...Option) <-chan error { default: ctx, cancel := f.ctxFunc() defer cancel() - conn, err := f.server.Request(ctx, f.request) + + extendedCtx, cancel := context.WithTimeout(context.Background(), f.contextTimeout) + defer cancel() + + extendedCtx = extend.WithValuesFromContext(extendedCtx, ctx) + conn, err := f.server.Request(extendedCtx, f.request) if err == nil && f.request != nil { f.request.Connection = conn } @@ -236,12 +240,11 @@ func (f *eventFactoryServer) Close(opts ...Option) <-chan error { ctx, cancel := f.ctxFunc() defer cancel() - c := clock.FromContext(ctx) - closeCtx, cancel := c.WithTimeout(context.Background(), f.closeTimeout) + extendedCtx, cancel := 
context.WithTimeout(context.Background(), f.contextTimeout) defer cancel() - closeCtx = extend.WithValuesFromContext(closeCtx, ctx) - _, err := f.server.Close(closeCtx, f.request.GetConnection()) + extendedCtx = extend.WithValuesFromContext(extendedCtx, ctx) + _, err := f.server.Close(extendedCtx, f.request.GetConnection()) f.afterCloseFunc() ch <- err } diff --git a/pkg/networkservice/common/begin/event_factory_server_test.go b/pkg/networkservice/common/begin/event_factory_server_test.go index 4df4e4d19..3a105998b 100644 --- a/pkg/networkservice/common/begin/event_factory_server_test.go +++ b/pkg/networkservice/common/begin/event_factory_server_test.go @@ -33,7 +33,6 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" "github.com/networkservicemesh/sdk/pkg/tools/clock" - "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) // This test reproduces the situation when refresh changes the eventFactory context @@ -138,19 +137,12 @@ func TestContextTimeout_Server(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // Add clockMock to the context - clockMock := clockmock.New(ctx) - ctx = clock.WithClock(ctx, clockMock) - - ctx, cancel = clockMock.WithDeadline(ctx, clockMock.Now().Add(time.Second*3)) - defer cancel() - - closeTimeout := time.Minute + contextTimeout := time.Second * 2 eventFactoryServ := &eventFactoryServer{} server := chain.NewNetworkServiceServer( - begin.NewServer(begin.WithCloseTimeout(closeTimeout)), + begin.NewServer(begin.WithContextTimeout(contextTimeout)), eventFactoryServ, - &delayedNSEServer{t: t, closeTimeout: closeTimeout, clock: clockMock}, + &delayedNSEServer{t: t, contextTimeout: contextTimeout}, ) // Do Request @@ -229,9 +221,8 @@ func (f *failedNSEServer) Close(ctx context.Context, conn *networkservice.Connec type delayedNSEServer struct { t *testing.T - clock *clockmock.Mock initialTimeout time.Duration - 
closeTimeout time.Duration + contextTimeout time.Duration } func (d *delayedNSEServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { @@ -247,10 +238,10 @@ func (d *delayedNSEServer) Request(ctx context.Context, request *networkservice. d.initialTimeout = timeout } // All requests timeout must be equal the first - require.Equal(d.t, d.initialTimeout, timeout) + require.Less(d.t, (d.initialTimeout - timeout).Abs(), time.Second) // Add delay - d.clock.Add(timeout / 2) + time.Sleep(timeout / 2) return next.Server(ctx).Request(ctx, request) } @@ -258,9 +249,9 @@ func (d *delayedNSEServer) Close(ctx context.Context, conn *networkservice.Conne require.Greater(d.t, d.initialTimeout, time.Duration(0)) deadline, _ := ctx.Deadline() - clockTime := clock.FromContext(ctx) + timeout := time.Until(deadline) - require.Equal(d.t, d.closeTimeout, clockTime.Until(deadline)) + require.Less(d.t, (d.contextTimeout - timeout).Abs(), time.Second) return next.Server(ctx).Close(ctx, conn) } diff --git a/pkg/networkservice/common/begin/options.go b/pkg/networkservice/common/begin/options.go index d4ee6fe52..509acc7aa 100644 --- a/pkg/networkservice/common/begin/options.go +++ b/pkg/networkservice/common/begin/options.go @@ -22,9 +22,9 @@ import ( ) type option struct { - cancelCtx context.Context - reselect bool - closeTimeout time.Duration + cancelCtx context.Context + reselect bool + contextTimeout time.Duration } // Option - event option @@ -44,9 +44,9 @@ func WithReselect() Option { } } -// WithCloseTimeout - set a custom timeout for a context in begin.Close -func WithCloseTimeout(timeout time.Duration) Option { +// WithContextTimeout - set a custom timeout for the extended context used in begin.Request and begin.Close +func WithContextTimeout(timeout time.Duration) Option { return func(o *option) { - o.closeTimeout = timeout + o.contextTimeout = timeout } } diff --git a/pkg/networkservice/common/begin/server.go 
b/pkg/networkservice/common/begin/server.go index cb6c711e7..a057a9299 100644 --- a/pkg/networkservice/common/begin/server.go +++ b/pkg/networkservice/common/begin/server.go @@ -33,15 +33,15 @@ import ( type beginServer struct { genericsync.Map[string, *eventFactoryServer] - closeTimeout time.Duration + contextTimeout time.Duration } // NewServer - creates a new begin chain element func NewServer(opts ...Option) networkservice.NetworkServiceServer { o := &option{ - cancelCtx: context.Background(), - reselect: false, - closeTimeout: time.Minute, + cancelCtx: context.Background(), + reselect: false, + contextTimeout: time.Minute, } for _, opt := range opts { @@ -49,7 +49,7 @@ func NewServer(opts ...Option) networkservice.NetworkServiceServer { } return &beginServer{ - closeTimeout: o.closeTimeout, + contextTimeout: o.contextTimeout, } } @@ -68,7 +68,7 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo eventFactoryServer, _ := b.LoadOrStore(request.GetConnection().GetId(), newEventFactoryServer( ctx, - b.closeTimeout, + b.contextTimeout, func() { b.Delete(request.GetRequestConnection().GetId()) }, @@ -88,8 +88,12 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo eventFactoryServer.request != nil && eventFactoryServer.request.Connection != nil { log.FromContext(ctx).Info("Closing connection due to RESELECT_REQUESTED state") + closeCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout) + defer cancel() + eventFactoryCtx, eventFactoryCtxCancel := eventFactoryServer.ctxFunc() - _, closeErr := next.Server(eventFactoryCtx).Close(eventFactoryCtx, eventFactoryServer.request.Connection) + closeCtx = extend.WithValuesFromContext(closeCtx, eventFactoryCtx) + _, closeErr := next.Server(closeCtx).Close(closeCtx, eventFactoryServer.request.Connection) if closeErr != nil { log.FromContext(ctx).Errorf("Can't close old connection: %v", closeErr) } @@ -97,8 +101,11 @@ func (b *beginServer) 
Request(ctx context.Context, request *networkservice.Netwo eventFactoryCtxCancel() } - withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer) - conn, err = next.Server(withEventFactoryCtx).Request(withEventFactoryCtx, request) + extendedCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout) + extendedCtx = extend.WithValuesFromContext(extendedCtx, withEventFactory(ctx, eventFactoryServer)) + defer cancel() + + conn, err = next.Server(extendedCtx).Request(extendedCtx, request) if err != nil { if eventFactoryServer.state != established { eventFactoryServer.state = closed @@ -143,14 +150,13 @@ func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection if currentServerClient != eventFactoryServer { return } - closeCtx, cancel := context.WithTimeout(context.Background(), b.closeTimeout) + extendedCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout) + extendedCtx = extend.WithValuesFromContext(extendedCtx, withEventFactory(ctx, eventFactoryServer)) defer cancel() // Always close with the last valid EventFactory we got conn = eventFactoryServer.request.Connection - withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer) - closeCtx = extend.WithValuesFromContext(closeCtx, withEventFactoryCtx) - _, err = next.Server(closeCtx).Close(closeCtx, conn) + _, err = next.Server(extendedCtx).Close(extendedCtx, conn) eventFactoryServer.afterCloseFunc() }): return &emptypb.Empty{}, err diff --git a/pkg/networkservice/common/begin/server_test.go b/pkg/networkservice/common/begin/server_test.go index 682116251..70c3142f0 100644 --- a/pkg/networkservice/common/begin/server_test.go +++ b/pkg/networkservice/common/begin/server_test.go @@ -26,6 +26,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "google.golang.org/protobuf/types/known/emptypb" "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" 
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next" @@ -41,14 +42,24 @@ type waitServer struct { } func (s *waitServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { - time.Sleep(waitTime) - s.requestDone.Store(1) + afterCh := time.After(time.Second) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-afterCh: + s.requestDone.Add(1) + } return next.Server(ctx).Request(ctx, request) } func (s *waitServer) Close(ctx context.Context, connection *networkservice.Connection) (*empty.Empty, error) { - time.Sleep(waitTime) - s.closeDone.Store(1) + afterCh := time.After(time.Second) + select { + case <-ctx.Done(): + return &emptypb.Empty{}, nil + case <-afterCh: + s.closeDone.Add(1) + } return next.Server(ctx).Close(ctx, connection) } @@ -82,3 +93,39 @@ func TestBeginWorksWithSmallTimeout(t *testing.T) { return waitSrv.closeDone.Load() == 1 }, waitTime*2, time.Millisecond*500) } + +func TestBeginHasExtendedTimeoutOnReselect(t *testing.T) { + t.Cleanup(func() { + goleak.VerifyNone(t) + }) + requestCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + waitSrv := &waitServer{} + server := next.NewNetworkServiceServer( + begin.NewServer(), + waitSrv, + ) + + // Make a first request to create an event factory. Begin should make Request only + request := testRequest("id") + _, err := server.Request(requestCtx, request) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + require.Equal(t, int32(0), waitSrv.requestDone.Load()) + require.Eventually(t, func() bool { + return waitSrv.requestDone.Load() == 1 + }, waitTime*2, time.Millisecond*500) + + // Make a second request with RESELECT_REQUESTED. 
Begin should make Close with extended context first and then Request + requestCtx, cancel = context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + newRequest := request.Clone() + newRequest.Connection.State = networkservice.State_RESELECT_REQUESTED + + _, err = server.Request(requestCtx, newRequest) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + require.Equal(t, int32(0), waitSrv.closeDone.Load()) + require.Eventually(t, func() bool { + return waitSrv.closeDone.Load() == 1 && waitSrv.requestDone.Load() == 2 + }, waitTime*4, time.Millisecond*500) +} diff --git a/pkg/networkservice/common/dial/client.go b/pkg/networkservice/common/dial/client.go index 8fd6739d6..31be1deac 100644 --- a/pkg/networkservice/common/dial/client.go +++ b/pkg/networkservice/common/dial/client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 Cisco and/or its affiliates. +// Copyright (c) 2021-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -72,8 +72,12 @@ func (d *dialClient) Request(ctx context.Context, request *networkservice.Networ return next.Client(ctx).Request(ctx, request, opts...) } + di.mu.Lock() + dialClientURL := di.clientURL + di.mu.Unlock() + // If our existing dialer has a different URL close down the chain - if di.clientURL != nil && di.clientURL.String() != clientURL.String() { + if dialClientURL != nil && dialClientURL.String() != clientURL.String() { closeCtx, closeCancel := closeContextFunc() defer closeCancel() err := di.Dial(closeCtx, di.clientURL) diff --git a/pkg/networkservice/common/dial/dialer.go b/pkg/networkservice/common/dial/dialer.go index bf988279e..2b285f9d9 100644 --- a/pkg/networkservice/common/dial/dialer.go +++ b/pkg/networkservice/common/dial/dialer.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2024 Cisco and/or its affiliates. 
// // SPDX-License-Identifier: Apache-2.0 // @@ -20,6 +20,7 @@ import ( "context" "net/url" "runtime" + "sync" "time" "github.com/pkg/errors" @@ -37,6 +38,7 @@ type dialer struct { *grpc.ClientConn dialOptions []grpc.DialOption dialTimeout time.Duration + mu sync.Mutex } func newDialer(ctx context.Context, dialTimeout time.Duration, dialOptions ...grpc.DialOption) *dialer { @@ -56,8 +58,10 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error { di.cleanupCancel() } + di.mu.Lock() // Set the clientURL di.clientURL = clientURL + di.mu.Unlock() // Setup dialTimeout if needed dialCtx := ctx