Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to handle context cancellations for TCP protocol #1389

Merged
merged 13 commits into from
Sep 23, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ _testmain.go

coverage.txt
.idea/**
.vscode/**
dev/*
.run/**

Expand Down
76 changes: 70 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"log"
"net"
"os"
"sync"
"syscall"
"time"

Expand All @@ -42,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
conn net.Conn
debugf = func(format string, v ...any) {}
)

switch {
case opt.DialContext != nil:
conn, err = opt.DialContext(ctx, addr)
Expand All @@ -53,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
conn, err = net.DialTimeout("tcp", addr, opt.DialTimeout)
}
}

if err != nil {
return nil, err
}

if opt.Debug {
if opt.Debugf != nil {
debugf = func(format string, v ...any) {
Expand All @@ -68,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse][conn=%d][%s]", num, conn.RemoteAddr()), 0).Printf
}
}

compression := CompressionNone
if opt.Compression != nil {
switch opt.Compression.Method {
Expand Down Expand Up @@ -96,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
maxCompressionBuffer: opt.MaxCompressionBuffer,
}
)

if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil {
return nil, err
}

if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
if err := connect.sendAddendum(); err != nil {
return nil, err
Expand All @@ -109,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) {
debugf("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v", connect.server.Version, resources.ClientMeta.SupportedVersions())
}

return connect, nil
}

Expand All @@ -131,6 +139,8 @@ type connect struct {
readTimeout time.Duration
blockBufferSize uint8
maxCompressionBuffer int
readerMutex sync.Mutex
closeMutex sync.Mutex
}

func (c *connect) settings(querySettings Settings) []proto.Setting {
Expand All @@ -153,15 +163,16 @@ func (c *connect) settings(querySettings Settings) []proto.Setting {
for k, v := range c.opt.Settings {
settings = append(settings, settingToProtoSetting(k, v))
}

for k, v := range querySettings {
settings = append(settings, settingToProtoSetting(k, v))
}

return settings
}

func (c *connect) isBad() bool {
switch {
case c.closed:
if c.isClosed() {
return true
}

Expand All @@ -172,19 +183,43 @@ func (c *connect) isBad() bool {
if err := c.connCheck(); err != nil {
return true
}

return false
}

func (c *connect) isClosed() bool {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()

return c.closed
}

func (c *connect) setClosed() {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()

c.closed = true
}

func (c *connect) close() error {
c.closeMutex.Lock()
if c.closed {
c.closeMutex.Unlock()
return nil
}
c.closed = true
c.buffer = nil
c.reader = nil
c.closeMutex.Unlock()

if err := c.conn.Close(); err != nil {
return err
}

c.buffer = nil

c.readerMutex.Lock()
c.reader = nil
c.readerMutex.Unlock()

return nil
}

Expand All @@ -193,6 +228,7 @@ func (c *connect) progress() (*Progress, error) {
if err := progress.Decode(c.reader, c.revision); err != nil {
return nil, err
}

c.debugf("[progress] %s", &progress)
return &progress, nil
}
Expand All @@ -202,6 +238,7 @@ func (c *connect) exception() error {
if err := e.Decode(c.reader); err != nil {
return err
}

c.debugf("[exception] %s", e.Error())
return &e
}
Expand All @@ -218,6 +255,12 @@ func (c *connect) compressBuffer(start int) error {
}

func (c *connect) sendData(block *proto.Block, name string) error {
if c.isClosed() {
err := errors.New("attempted sending on closed connection")
c.debugf("[send data] err: %v", err)
return err
}

c.debugf("[send data] compression=%q", c.compression)
c.buffer.PutByte(proto.ClientData)
c.buffer.PutString(name)
Expand All @@ -227,6 +270,7 @@ func (c *connect) sendData(block *proto.Block, name string) error {
if err := block.EncodeHeader(c.buffer, c.revision); err != nil {
return err
}

for i := range block.Columns {
if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil {
return err
Expand All @@ -242,33 +286,50 @@ func (c *connect) sendData(block *proto.Block, name string) error {
compressionOffset = 0
}
}

if err := c.compressBuffer(compressionOffset); err != nil {
return err
}

if err := c.flush(); err != nil {
switch {
case errors.Is(err, syscall.EPIPE):
c.debugf("[send data] pipe is broken, closing connection")
c.closed = true
c.setClosed()
case errors.Is(err, io.EOF):
c.debugf("[send data] unexpected EOF, closing connection")
c.closed = true
c.setClosed()
default:
c.debugf("[send data] unexpected error: %v", err)
}
return err
}

defer func() {
c.buffer.Reset()
}()

return nil
}

func (c *connect) readData(ctx context.Context, packet byte, compressible bool) (*proto.Block, error) {
if c.isClosed() {
err := errors.New("attempted reading on closed connection")
c.debugf("[read data] err: %v", err)
return nil, err
}

if c.reader == nil {
err := errors.New("attempted reading on nil reader")
c.debugf("[read data] err: %v", err)
return nil, err
}

if _, err := c.reader.Str(); err != nil {
c.debugf("[read data] str error: %v", err)
return nil, err
}

if compressible && c.compression != CompressionNone {
c.reader.EnableCompression()
defer c.reader.DisableCompression()
Expand All @@ -285,6 +346,7 @@ func (c *connect) readData(ctx context.Context, packet byte, compressible bool)
c.debugf("[read data] decode error: %v", err)
return nil, err
}

block.Packet = packet
c.debugf("[read data] compression=%q. block: columns=%d, rows=%d", c.compression, len(block.Columns), block.Rows())
return &block, nil
Expand All @@ -295,10 +357,12 @@ func (c *connect) flush() error {
// Nothing to flush.
return nil
}

n, err := c.conn.Write(c.buffer.Buf)
if err != nil {
return errors.Wrap(err, "write")
}

if n != len(c.buffer.Buf) {
return errors.New("wrote less than expected")
}
Expand Down
Loading
Loading