diff --git a/server/controllers/websocket/mux.go b/server/controllers/websocket/mux.go index ccfbdf99f9..8288df3212 100644 --- a/server/controllers/websocket/mux.go +++ b/server/controllers/websocket/mux.go @@ -1,6 +1,7 @@ package websocket import ( + "fmt" "net/http" "github.com/gorilla/websocket" @@ -18,6 +19,7 @@ type PartitionKeyGenerator interface { type PartitionRegistry interface { Register(key string, buffer chan string) Deregister(key string, buffer chan string) + IsKeyExists(key string) bool } // Multiplexor is responsible for handling the data transfer between the storage layer @@ -51,6 +53,11 @@ func (m *Multiplexor) Handle(w http.ResponseWriter, r *http.Request) error { return errors.Wrapf(err, "generating partition key") } + // check if the job ID exists before registering receiver + if !m.registry.IsKeyExists(key) { + return fmt.Errorf("invalid key: %s", key) + } + // Buffer size set to 1000 to ensure messages get queued. // TODO: make buffer size configurable buffer := make(chan string, 1000) diff --git a/server/handlers/mocks/mock_project_command_output_handler.go b/server/handlers/mocks/mock_project_command_output_handler.go index fa8fb149b6..926f1ca4b0 100644 --- a/server/handlers/mocks/mock_project_command_output_handler.go +++ b/server/handlers/mocks/mock_project_command_output_handler.go @@ -50,6 +50,21 @@ func (mock *MockProjectCommandOutputHandler) Handle() { pegomock.GetGenericMockFrom(mock).Invoke("Handle", params, []reflect.Type{}) } +func (mock *MockProjectCommandOutputHandler) IsKeyExists(_param0 string) bool { + if mock == nil { + panic("mock must not be nil. Use myMock := NewMockProjectCommandOutputHandler().") + } + params := []pegomock.Param{_param0} + result := pegomock.GetGenericMockFrom(mock).Invoke("IsKeyExists", params, []reflect.Type{reflect.TypeOf((*bool)(nil)).Elem()}) + var ret0 bool + if len(result) != 0 { + if result[0] != nil { + ret0 = result[0].(bool) + } + } + return ret0 +} + func (mock *MockProjectCommandOutputHandler) Register(_param0 string, _param1 chan string) { if mock == nil { panic("mock must not be nil. Use myMock := NewMockProjectCommandOutputHandler().") @@ -193,6 +208,33 @@ func (c *MockProjectCommandOutputHandler_Handle_OngoingVerification) GetCaptured func (c *MockProjectCommandOutputHandler_Handle_OngoingVerification) GetAllCapturedArguments() { } +func (verifier *VerifierMockProjectCommandOutputHandler) IsKeyExists(_param0 string) *MockProjectCommandOutputHandler_IsKeyExists_OngoingVerification { + params := []pegomock.Param{_param0} + methodInvocations := pegomock.GetGenericMockFrom(verifier.mock).Verify(verifier.inOrderContext, verifier.invocationCountMatcher, "IsKeyExists", params, verifier.timeout) + return &MockProjectCommandOutputHandler_IsKeyExists_OngoingVerification{mock: verifier.mock, methodInvocations: methodInvocations} +} + +type MockProjectCommandOutputHandler_IsKeyExists_OngoingVerification struct { + mock *MockProjectCommandOutputHandler + methodInvocations []pegomock.MethodInvocation +} + +func (c *MockProjectCommandOutputHandler_IsKeyExists_OngoingVerification) GetCapturedArguments() string { + _param0 := c.GetAllCapturedArguments() + return _param0[len(_param0)-1] +} + +func (c *MockProjectCommandOutputHandler_IsKeyExists_OngoingVerification) GetAllCapturedArguments() (_param0 []string) { + params := pegomock.GetGenericMockFrom(c.mock).GetInvocationParams(c.methodInvocations) + if len(params) > 0 { + _param0 = make([]string, len(c.methodInvocations)) + for u, param := range params[0] { + _param0[u] = param.(string) + } + } + return +} + func (verifier *VerifierMockProjectCommandOutputHandler) Register(_param0 string, _param1 chan string) *MockProjectCommandOutputHandler_Register_OngoingVerification { params := []pegomock.Param{_param0, _param1} methodInvocations := pegomock.GetGenericMockFrom(verifier.mock).Verify(verifier.inOrderContext, verifier.invocationCountMatcher, "Register", params, verifier.timeout) diff --git a/server/handlers/project_command_output_handler.go b/server/handlers/project_command_output_handler.go index 44a8e9ca7e..02cb851607 100644 --- a/server/handlers/project_command_output_handler.go +++ b/server/handlers/project_command_output_handler.go @@ -81,6 +81,8 @@ type ProjectCommandOutputHandler interface { // Deregister removes a channel from successive updates and closes it. Deregister(jobID string, receiver chan string) + IsKeyExists(key string) bool + // Listens for msg from channel Handle() @@ -114,6 +116,13 @@ func NewAsyncProjectCommandOutputHandler( } } +func (p *AsyncProjectCommandOutputHandler) IsKeyExists(key string) bool { + p.receiverBuffersLock.RLock() + defer p.receiverBuffersLock.RUnlock() + _, ok := p.receiverBuffers[key] + return ok +} + func (p *AsyncProjectCommandOutputHandler) Send(ctx models.ProjectCommandContext, msg string, operationComplete bool) { p.projectCmdOutput <- &ProjectCmdOutputLine{ JobID: ctx.JobID, @@ -219,12 +228,7 @@ func (p *AsyncProjectCommandOutputHandler) writeLogLine(jobID string, line strin select { case ch <- line: default: - // Client ws conn could be closed in two ways: - // 1. Client closes the conn gracefully -> the closeHandler() is executed which - // closes the channel and cleans up resources. - // 2. Client does not close the conn and the closeHandler() is not executed -> the - // receiverChan will be blocking for N number of messages (equal to buffer size) - // before we delete the channel and clean up the resources. + // Delete buffered channel if it's blocking. delete(p.receiverBuffers[jobID], ch) } } @@ -274,9 +278,6 @@ func (p *AsyncProjectCommandOutputHandler) CleanUp(pullContext PullContext) { delete(p.projectOutputBuffers, jobID) p.projectOutputBuffersLock.Unlock() - // Only delete the pull record from receiver buffers. - // WS channel will be closed when the user closes the browser tab - // in closeHanlder(). p.receiverBuffersLock.Lock() delete(p.receiverBuffers, jobID) p.receiverBuffersLock.Unlock() @@ -305,3 +306,7 @@ func (p *NoopProjectOutputHandler) SetJobURLWithStatus(ctx models.ProjectCommand func (p *NoopProjectOutputHandler) CleanUp(pullContext PullContext) { } + +func (p *NoopProjectOutputHandler) IsKeyExists(key string) bool { + return false +}