Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[branch/v9] Backport #11112 #11188

Merged
merged 2 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions lib/srv/termmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@ const maxHistory = 1000
// - history scrollback for new clients
// - stream breaking
type TermManager struct {
// These two fields need to be first in the struct so that they are 64-bit aligned which is a requirement
// for atomic operations on certain architectures.
countWritten uint64
countRead uint64

mu sync.Mutex
writers map[string]io.Writer
readerState map[string]*int32
readerState map[string]bool
OnWriteError func(idString string, err error)
countWritten uint64
countRead uint64
// buffer is used to buffer writes when turned off
buffer []byte
on bool
Expand All @@ -48,7 +51,7 @@ type TermManager struct {
// we only support one concurrent reader so this isn't mutex protected
remaining []byte
readStateUpdate *sync.Cond
closed *int32
closed bool
lastWasBroadcast bool
terminateNotifier chan struct{}
}
Expand All @@ -57,8 +60,8 @@ type TermManager struct {
func NewTermManager() *TermManager {
return &TermManager{
writers: make(map[string]io.Writer),
readerState: make(map[string]*int32),
closed: new(int32),
readerState: make(map[string]bool),
closed: false,
readStateUpdate: sync.NewCond(&sync.Mutex{}),
incoming: make(chan []byte, 100),
terminateNotifier: make(chan struct{}),
Expand Down Expand Up @@ -214,8 +217,7 @@ func (g *TermManager) DeleteWriter(name string) {
}

func (g *TermManager) AddReader(name string, r io.Reader) {
readerState := new(int32)
g.readerState[name] = readerState
g.readerState[name] = false

go func() {
for {
Expand All @@ -231,32 +233,32 @@ func (g *TermManager) AddReader(name string, r io.Reader) {
// This is the ASCII control code for CTRL+C.
if b == 0x03 {
g.mu.Lock()
if !g.on {
if !g.on && !g.closed {
select {
case g.terminateNotifier <- struct{}{}:
default:
}
}
g.mu.Unlock()
return
break
}
}

g.incoming <- buf[:n]
if atomic.LoadInt32(g.closed) == 1 || atomic.LoadInt32(readerState) == 1 {
g.mu.Lock()
if g.closed || g.readerState[name] {
g.mu.Unlock()
return
}
g.mu.Unlock()
}
}()
}

func (g *TermManager) DeleteReader(name string) {
g.mu.Lock()
defer g.mu.Unlock()

if g.readerState[name] != nil {
atomic.StoreInt32(g.readerState[name], 1)
}
g.readerState[name] = true
}

func (g *TermManager) CountWritten() uint64 {
Expand All @@ -268,7 +270,11 @@ func (g *TermManager) CountRead() uint64 {
}

func (g *TermManager) Close() {
if atomic.CompareAndSwapInt32(g.closed, 0, 1) {
g.mu.Lock()
defer g.mu.Unlock()

if !g.closed {
g.closed = true
close(g.terminateNotifier)
}
}
Expand Down
50 changes: 50 additions & 0 deletions lib/srv/termmanager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
Copyright 2022 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package srv

import (
"io"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestCTRLCPassthrough(t *testing.T) {
m := NewTermManager()
m.On()
r, w := io.Pipe()
m.AddReader("foo", r)
go w.Write([]byte("\x03"))
buf := make([]byte, 1)
_, err := m.Read(buf)
require.NoError(t, err)
require.Equal(t, []byte("\x03"), buf)
}

func TestCTRLCCapture(t *testing.T) {
m := NewTermManager()
r, w := io.Pipe()
m.AddReader("foo", r)
go w.Write([]byte("\x03"))

select {
case <-m.TerminateNotifier():
case <-time.After(time.Second * 10):
t.Fatal("terminateNotifier should've seen an event")
}
}