diff --git a/.gitignore b/.gitignore index 2d3f9ae66c..2e8e75bad6 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ _testmain.go coverage.txt .idea/** +.vscode/** dev/* .run/** diff --git a/conn.go b/conn.go index 6d831a0478..2757e12134 100644 --- a/conn.go +++ b/conn.go @@ -25,6 +25,7 @@ import ( "log" "net" "os" + "sync" "syscall" "time" @@ -42,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er conn net.Conn debugf = func(format string, v ...any) {} ) + switch { case opt.DialContext != nil: conn, err = opt.DialContext(ctx, addr) @@ -53,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er conn, err = net.DialTimeout("tcp", addr, opt.DialTimeout) } } + if err != nil { return nil, err } + if opt.Debug { if opt.Debugf != nil { debugf = func(format string, v ...any) { @@ -68,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse][conn=%d][%s]", num, conn.RemoteAddr()), 0).Printf } } + compression := CompressionNone if opt.Compression != nil { switch opt.Compression.Method { @@ -96,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er maxCompressionBuffer: opt.MaxCompressionBuffer, } ) + if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil { return nil, err } + if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM { if err := connect.sendAddendum(); err != nil { return nil, err @@ -109,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) { debugf("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v", connect.server.Version, resources.ClientMeta.SupportedVersions()) } + return connect, nil } @@ -131,6 +139,8 @@ type connect struct { readTimeout time.Duration blockBufferSize uint8 maxCompressionBuffer int + readerMutex sync.Mutex + closeMutex sync.Mutex } func (c *connect) settings(querySettings Settings) []proto.Setting { @@ -153,15 +163,16 @@ func (c *connect) settings(querySettings Settings) []proto.Setting { for k, v := range c.opt.Settings { settings = append(settings, settingToProtoSetting(k, v)) } + for k, v := range querySettings { settings = append(settings, settingToProtoSetting(k, v)) } + return settings } func (c *connect) isBad() bool { - switch { - case c.closed: + if c.isClosed() { return true } @@ -172,19 +183,43 @@ func (c *connect) isBad() bool { if err := c.connCheck(); err != nil { return true } + return false } +func (c *connect) isClosed() bool { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + + return c.closed +} + +func (c *connect) setClosed() { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + + c.closed = true +} + func (c *connect) close() error { + c.closeMutex.Lock() if c.closed { + c.closeMutex.Unlock() return nil } c.closed = true - c.buffer = nil - c.reader = nil + c.closeMutex.Unlock() + if err := c.conn.Close(); err != nil { return err } + + c.buffer = nil + + c.readerMutex.Lock() + c.reader = nil + c.readerMutex.Unlock() + return nil } @@ -193,6 +228,7 @@ func (c *connect) progress() (*Progress, error) { if err := progress.Decode(c.reader, c.revision); err != nil { return nil, err } + c.debugf("[progress] %s", &progress) return &progress, nil } @@ -202,6 +238,7 @@ func (c *connect) exception() error { if err := e.Decode(c.reader); err != nil { return err } + c.debugf("[exception] %s", e.Error()) return &e } @@ -218,6 +255,12 @@ func (c *connect) compressBuffer(start int) error { } func (c *connect) sendData(block *proto.Block, name string) error { + if c.isClosed() { + err := errors.New("attempted sending on closed connection") + c.debugf("[send data] err: %v", err) + return err + } + c.debugf("[send data] compression=%q", c.compression) c.buffer.PutByte(proto.ClientData) c.buffer.PutString(name) @@ -227,6 +270,7 @@ func (c *connect) sendData(block *proto.Block, name string) error { if err := block.EncodeHeader(c.buffer, c.revision); err != nil { return err } + for i := range block.Columns { if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil { return err @@ -242,33 +286,50 @@ func (c *connect) sendData(block *proto.Block, name string) error { compressionOffset = 0 } } + if err := c.compressBuffer(compressionOffset); err != nil { return err } + if err := c.flush(); err != nil { switch { case errors.Is(err, syscall.EPIPE): c.debugf("[send data] pipe is broken, closing connection") - c.closed = true + c.setClosed() case errors.Is(err, io.EOF): c.debugf("[send data] unexpected EOF, closing connection") - c.closed = true + c.setClosed() default: c.debugf("[send data] unexpected error: %v", err) } return err } + defer func() { c.buffer.Reset() }() + return nil } func (c *connect) readData(ctx context.Context, packet byte, compressible bool) (*proto.Block, error) { + if c.isClosed() { + err := errors.New("attempted reading on closed connection") + c.debugf("[read data] err: %v", err) + return nil, err + } + + if c.reader == nil { + err := errors.New("attempted reading on nil reader") + c.debugf("[read data] err: %v", err) + return nil, err + } + if _, err := c.reader.Str(); err != nil { c.debugf("[read data] str error: %v", err) return nil, err } + if compressible && c.compression != CompressionNone { c.reader.EnableCompression() defer c.reader.DisableCompression() @@ -285,6 +346,7 @@ func (c *connect) readData(ctx context.Context, packet byte, compressible bool) c.debugf("[read data] decode error: %v", err) return nil, err } + block.Packet = packet c.debugf("[read data] compression=%q. block: columns=%d, rows=%d", c.compression, len(block.Columns), block.Rows()) return &block, nil @@ -295,10 +357,12 @@ func (c *connect) flush() error { // Nothing to flush. return nil } + n, err := c.conn.Write(c.buffer.Buf) if err != nil { return errors.Wrap(err, "write") } + if n != len(c.buffer.Buf) { return errors.New("wrote less than expected") } diff --git a/conn_process.go b/conn_process.go index 967e2ffc4d..1825d87232 100644 --- a/conn_process.go +++ b/conn_process.go @@ -20,8 +20,10 @@ package clickhouse import ( "context" "fmt" - "github.com/ClickHouse/clickhouse-go/v2/lib/proto" "io" + + "github.com/ClickHouse/clickhouse-go/v2/lib/proto" + "github.com/pkg/errors" ) type onProcess struct { @@ -33,51 +35,137 @@ type onProcess struct { } func (c *connect) firstBlock(ctx context.Context, on *onProcess) (*proto.Block, error) { + // if context is already timedout/cancelled — we're done + select { + case <-ctx.Done(): + c.cancel() + return nil, ctx.Err() + default: + } + + // do reads in background + resultCh := make(chan *proto.Block, 1) + errCh := make(chan error, 1) + + go func() { + block, err := c.firstBlockImpl(ctx, on) + if err != nil { + errCh <- err + return + } + resultCh <- block + }() + + // select on context or read channels (results/errors) + select { + case <-ctx.Done(): + c.cancel() + return nil, ctx.Err() + + case err := <-errCh: + return nil, err + + case block := <-resultCh: + return block, nil + } +} + +func (c *connect) firstBlockImpl(ctx context.Context, on *onProcess) (*proto.Block, error) { + c.readerMutex.Lock() + defer c.readerMutex.Unlock() + for { - select { - case <-ctx.Done(): - c.cancel() - return nil, ctx.Err() - default: + if c.reader == nil { + return nil, errors.New("unexpected state: c.reader is nil") } + packet, err := c.reader.ReadByte() if err != nil { return nil, err } + switch packet { case proto.ServerData: return c.readData(ctx, packet, true) + case proto.ServerEndOfStream: c.debugf("[end of stream]") return nil, io.EOF + default: if err := c.handle(ctx, packet, on); err != nil { + // handling error, return return nil, err } + + // handled okay, read next byte } } } func (c *connect) process(ctx context.Context, on *onProcess) error { + // if context is already timedout/cancelled — we're done + select { + case <-ctx.Done(): + c.cancel() + return ctx.Err() + default: + } + + // do reads in background + errCh := make(chan error, 1) + doneCh := make(chan bool, 1) + + go func() { + err := c.processImpl(ctx, on) + if err != nil { + errCh <- err + return + } + + doneCh <- true + }() + + // select on context or read channel (errors) + select { + case <-ctx.Done(): + c.cancel() + return ctx.Err() + + case err := <-errCh: + return err + + case <-doneCh: + return nil + } +} + +func (c *connect) processImpl(ctx context.Context, on *onProcess) error { + c.readerMutex.Lock() + defer c.readerMutex.Unlock() + for { - select { - case <-ctx.Done(): - c.cancel() - return ctx.Err() - default: + if c.reader == nil { + return errors.New("unexpected state: c.reader is nil") } + packet, err := c.reader.ReadByte() if err != nil { return err } + switch packet { case proto.ServerEndOfStream: c.debugf("[end of stream]") return nil } + if err := c.handle(ctx, packet, on); err != nil { + // handling error, return return err } + + // handled okay, read next byte } } diff --git a/tests/context_cancel_test.go b/tests/context_cancel_test.go new file mode 100644 index 0000000000..45f8f1179f --- /dev/null +++ b/tests/context_cancel_test.go @@ -0,0 +1,221 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tests + +import ( + "context" + "log" + "testing" + "time" + + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/stretchr/testify/assert" +) + +func TestContextCancellationOfHeavyGeneratedInsert(t *testing.T) { + var ( + heavyQuery = `INSERT INTO test_query_cancellation.trips + SELECT + number + 1 AS trip_id, + now() - INTERVAL intDiv(number, 100) SECOND AS pickup_datetime, + now() - INTERVAL intDiv(number, 100) SECOND + INTERVAL rand() % 3600 SECOND AS dropoff_datetime, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 - 74.00) AS pickup_longitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 + 40.50) AS pickup_latitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 - 74.00) AS dropoff_longitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 + 40.50) AS dropoff_latitude, + rand() % 6 + 1 AS passenger_count, + (rand() % 2000) / 100.0 AS trip_distance, + (rand() % 5000) / 100.0 AS fare_amount, + (rand() % 500) / 100.0 AS extra, + (rand() % 1000) / 100.0 AS tip_amount, + (rand() % 300) / 100.0 AS tolls_amount, + (rand() % 6000) / 100.0 AS total_amount, + CAST(rand() % 5 + 1 AS Enum('CSH' = 1, 'CRE' = 2, 'NOC' = 3, 'DIS' = 4, 'UNK' = 5)) AS payment_type, + 'Neighborhood ' || toString(rand() % 100 + 1) AS pickup_ntaname, + 'Neighborhood ' || toString(rand() % 100 + 1) AS dropoff_ntaname + FROM numbers(100000000);` + ) + + conn, err := SetupTestContextCancellationType1(t, false) + assert.Nil(t, err) + assert.NotNil(t, conn) + + ExecuteTestContextCancellation(t, conn, heavyQuery) +} + +func TestContextCancellationOfHeavyOptimizeFinal(t *testing.T) { + var ( + heavyQuery = "OPTIMIZE TABLE test_query_cancellation.trips FINAL" + ) + + conn, err := SetupTestContextCancellationType1(t, true) + assert.Nil(t, err) + assert.NotNil(t, conn) + + ExecuteTestContextCancellation(t, conn, heavyQuery) +} + +func TestContextCancellationOfHeavyInsertFromS3(t *testing.T) { + var ( + heavyQuery = `INSERT INTO test_query_cancellation.trips + SELECT + trip_id, + pickup_datetime, + dropoff_datetime, + pickup_longitude, + pickup_latitude, + dropoff_longitude, + dropoff_latitude, + passenger_count, + trip_distance, + fare_amount, + extra, + tip_amount, + tolls_amount, + total_amount, + payment_type, + pickup_ntaname, + dropoff_ntaname + FROM s3( + 'https://datasets-documentation.s3.eu-west-3.amazonaws.com/nyc-taxi/trips_{0..2}.gz', + 'TabSeparatedWithNames' + );` + ) + + conn, err := SetupTestContextCancellationType1(t, true) + assert.Nil(t, err) + assert.NotNil(t, conn) + + ExecuteTestContextCancellation(t, conn, heavyQuery) +} + +func SetupTestContextCancellationType1(t *testing.T, fillTableWithRandomData bool) (clickhouse.Conn, error) { + var ( + q1 = "CREATE DATABASE IF NOT EXISTS test_query_cancellation" + q2 = "DROP TABLE IF EXISTS test_query_cancellation.trips" + q3 = `CREATE TABLE test_query_cancellation.trips ( + trip_id UInt32, + pickup_datetime DateTime, + dropoff_datetime DateTime, + pickup_longitude Nullable(Float64), + pickup_latitude Nullable(Float64), + dropoff_longitude Nullable(Float64), + dropoff_latitude Nullable(Float64), + passenger_count UInt8, + trip_distance Float32, + fare_amount Float32, + extra Float32, + tip_amount Float32, + tolls_amount Float32, + total_amount Float32, + payment_type Enum('CSH' = 1, 'CRE' = 2, 'NOC' = 3, 'DIS' = 4, 'UNK' = 5), + pickup_ntaname LowCardinality(String), + dropoff_ntaname LowCardinality(String) + ) + ENGINE = MergeTree + PRIMARY KEY (pickup_datetime, dropoff_datetime);` + q4 = `INSERT INTO test_query_cancellation.trips + SELECT + number + 1 AS trip_id, + now() - INTERVAL intDiv(number, 100) SECOND AS pickup_datetime, + now() - INTERVAL intDiv(number, 100) SECOND + INTERVAL rand() % 3600 SECOND AS dropoff_datetime, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 - 74.00) AS pickup_longitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 + 40.50) AS pickup_latitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 - 74.00) AS dropoff_longitude, + if(rand() % 2 = 0, NULL, (rand() % 3600) / 100.0 + 40.50) AS dropoff_latitude, + rand() % 6 + 1 AS passenger_count, + (rand() % 2000) / 100.0 AS trip_distance, + (rand() % 5000) / 100.0 AS fare_amount, + (rand() % 500) / 100.0 AS extra, + (rand() % 1000) / 100.0 AS tip_amount, + (rand() % 300) / 100.0 AS tolls_amount, + (rand() % 6000) / 100.0 AS total_amount, + CAST(rand() % 5 + 1 AS Enum('CSH' = 1, 'CRE' = 2, 'NOC' = 3, 'DIS' = 4, 'UNK' = 5)) AS payment_type, + 'Neighborhood ' || toString(rand() % 100 + 1) AS pickup_ntaname, + 'Neighborhood ' || toString(rand() % 100 + 1) AS dropoff_ntaname + FROM numbers(30000000);` + ) + + prepareQueries := []string{q1, q2, q3} + if fillTableWithRandomData { + prepareQueries = append(prepareQueries, q4) + } + + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + + assert.Nil(t, err) + assert.NotNil(t, conn) + + if err = conn.Ping(context.Background()); err != nil { + return nil, err + } + + t.Log("Connected.") + + // prepare table + for _, query := range prepareQueries { + err = conn.Exec(context.Background(), query) + if err != nil { + log.Printf("Finished with error: %v\n", err) + conn.Close() + return nil, err + } + } + + return conn, nil +} + +func ExecuteTestContextCancellation(t *testing.T, conn clickhouse.Conn, query string) { + // prepare context + ctx, cancelCtx := context.WithCancel(context.Background()) + defer cancelCtx() + + doneCh := make(chan bool, 1) + queryTimeCh := make(chan time.Duration, 1) + + // run query in background + go func() { + // running heavy query... + start := time.Now() + defer func() { + queryTimeCh <- time.Since(start) + doneCh <- true + }() + + if err := conn.Exec(ctx, query); err != nil { + return + } + }() + + cancelBackoff := 3 * time.Second + + // let query run for awhile and stop + go func() { + time.Sleep(3 * time.Second) + cancelCtx() + }() + + <-doneCh + conn.Close() + + queryTime := <-queryTimeCh + + assert.Less(t, queryTime-cancelBackoff, time.Second) +}