Skip to content

Commit

Permalink
Add ability to handle context cancellations for TCP protocol (#1389)
Browse files Browse the repository at this point in the history
* Add ability to handle context cancellations.

* Fixed ability to handle context cancellations.

* Removed obsolete comments.

* Added missing change.

* Fixed data race on connection close().

* Synchronisation fix.

* Fixed data race on connection close().

* Sync conn.close() calls eparately. Add test.

* Add one more test.

* Final clean-up.

* Close net connection first, to release blocked reader, then all pending ops on that connection will be properly released.

* Make new tests pretty.

* Clean-up.
  • Loading branch information
tinybit authored Sep 23, 2024
1 parent 360020f commit 014acb5
Show file tree
Hide file tree
Showing 4 changed files with 391 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ _testmain.go

coverage.txt
.idea/**
.vscode/**
dev/*
.run/**

Expand Down
76 changes: 70 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"log"
"net"
"os"
"sync"
"syscall"
"time"

Expand All @@ -42,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
conn net.Conn
debugf = func(format string, v ...any) {}
)

switch {
case opt.DialContext != nil:
conn, err = opt.DialContext(ctx, addr)
Expand All @@ -53,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
conn, err = net.DialTimeout("tcp", addr, opt.DialTimeout)
}
}

if err != nil {
return nil, err
}

if opt.Debug {
if opt.Debugf != nil {
debugf = func(format string, v ...any) {
Expand All @@ -68,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse][conn=%d][%s]", num, conn.RemoteAddr()), 0).Printf
}
}

compression := CompressionNone
if opt.Compression != nil {
switch opt.Compression.Method {
Expand Down Expand Up @@ -96,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
maxCompressionBuffer: opt.MaxCompressionBuffer,
}
)

if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil {
return nil, err
}

if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
if err := connect.sendAddendum(); err != nil {
return nil, err
Expand All @@ -109,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) {
debugf("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v", connect.server.Version, resources.ClientMeta.SupportedVersions())
}

return connect, nil
}

Expand All @@ -131,6 +139,8 @@ type connect struct {
readTimeout time.Duration
blockBufferSize uint8
maxCompressionBuffer int
readerMutex sync.Mutex
closeMutex sync.Mutex
}

func (c *connect) settings(querySettings Settings) []proto.Setting {
Expand All @@ -153,15 +163,16 @@ func (c *connect) settings(querySettings Settings) []proto.Setting {
for k, v := range c.opt.Settings {
settings = append(settings, settingToProtoSetting(k, v))
}

for k, v := range querySettings {
settings = append(settings, settingToProtoSetting(k, v))
}

return settings
}

func (c *connect) isBad() bool {
switch {
case c.closed:
if c.isClosed() {
return true
}

Expand All @@ -172,19 +183,43 @@ func (c *connect) isBad() bool {
if err := c.connCheck(); err != nil {
return true
}

return false
}

func (c *connect) isClosed() bool {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()

return c.closed
}

func (c *connect) setClosed() {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()

c.closed = true
}

func (c *connect) close() error {
c.closeMutex.Lock()
if c.closed {
c.closeMutex.Unlock()
return nil
}
c.closed = true
c.buffer = nil
c.reader = nil
c.closeMutex.Unlock()

if err := c.conn.Close(); err != nil {
return err
}

c.buffer = nil

c.readerMutex.Lock()
c.reader = nil
c.readerMutex.Unlock()

return nil
}

Expand All @@ -193,6 +228,7 @@ func (c *connect) progress() (*Progress, error) {
if err := progress.Decode(c.reader, c.revision); err != nil {
return nil, err
}

c.debugf("[progress] %s", &progress)
return &progress, nil
}
Expand All @@ -202,6 +238,7 @@ func (c *connect) exception() error {
if err := e.Decode(c.reader); err != nil {
return err
}

c.debugf("[exception] %s", e.Error())
return &e
}
Expand All @@ -218,6 +255,12 @@ func (c *connect) compressBuffer(start int) error {
}

func (c *connect) sendData(block *proto.Block, name string) error {
if c.isClosed() {
err := errors.New("attempted sending on closed connection")
c.debugf("[send data] err: %v", err)
return err
}

c.debugf("[send data] compression=%q", c.compression)
c.buffer.PutByte(proto.ClientData)
c.buffer.PutString(name)
Expand All @@ -227,6 +270,7 @@ func (c *connect) sendData(block *proto.Block, name string) error {
if err := block.EncodeHeader(c.buffer, c.revision); err != nil {
return err
}

for i := range block.Columns {
if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil {
return err
Expand All @@ -242,33 +286,50 @@ func (c *connect) sendData(block *proto.Block, name string) error {
compressionOffset = 0
}
}

if err := c.compressBuffer(compressionOffset); err != nil {
return err
}

if err := c.flush(); err != nil {
switch {
case errors.Is(err, syscall.EPIPE):
c.debugf("[send data] pipe is broken, closing connection")
c.closed = true
c.setClosed()
case errors.Is(err, io.EOF):
c.debugf("[send data] unexpected EOF, closing connection")
c.closed = true
c.setClosed()
default:
c.debugf("[send data] unexpected error: %v", err)
}
return err
}

defer func() {
c.buffer.Reset()
}()

return nil
}

func (c *connect) readData(ctx context.Context, packet byte, compressible bool) (*proto.Block, error) {
if c.isClosed() {
err := errors.New("attempted reading on closed connection")
c.debugf("[read data] err: %v", err)
return nil, err
}

if c.reader == nil {
err := errors.New("attempted reading on nil reader")
c.debugf("[read data] err: %v", err)
return nil, err
}

if _, err := c.reader.Str(); err != nil {
c.debugf("[read data] str error: %v", err)
return nil, err
}

if compressible && c.compression != CompressionNone {
c.reader.EnableCompression()
defer c.reader.DisableCompression()
Expand All @@ -285,6 +346,7 @@ func (c *connect) readData(ctx context.Context, packet byte, compressible bool)
c.debugf("[read data] decode error: %v", err)
return nil, err
}

block.Packet = packet
c.debugf("[read data] compression=%q. block: columns=%d, rows=%d", c.compression, len(block.Columns), block.Rows())
return &block, nil
Expand All @@ -295,10 +357,12 @@ func (c *connect) flush() error {
// Nothing to flush.
return nil
}

n, err := c.conn.Write(c.buffer.Buf)
if err != nil {
return errors.Wrap(err, "write")
}

if n != len(c.buffer.Buf) {
return errors.New("wrote less than expected")
}
Expand Down
Loading

0 comments on commit 014acb5

Please sign in to comment.