Skip to content

Commit

Permalink
Merge pull request #368 from sylwiaszunejko/ring-describer-refactor
Browse files Browse the repository at this point in the history
Ring describer refactor
  • Loading branch information
dkropachev authored Dec 11, 2024
2 parents b578328 + b849a70 commit 7432982
Show file tree
Hide file tree
Showing 11 changed files with 664 additions and 366 deletions.
90 changes: 43 additions & 47 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
// Deprecated.
var TimeoutLimit int64 = 0

type ConnInterface interface {
Close()
exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error)
awaitSchemaAgreement(ctx context.Context) error
executeQuery(ctx context.Context, qry *Query) *Iter
querySystem(ctx context.Context, query string) *Iter
getIsSchemaV2() bool
setSchemaV2(s bool)
query(ctx context.Context, statement string, values ...interface{}) (iter *Iter)
getScyllaSupported() scyllaSupported
}

// Conn is a single connection to a Cassandra node. It can be used to execute
// queries, but users are usually advised to use a more reliable, higher
// level API.
Expand Down Expand Up @@ -212,6 +224,18 @@ type Conn struct {
tabletsRoutingV1 int32
}

func (c *Conn) getIsSchemaV2() bool {
return c.isSchemaV2
}

func (c *Conn) setSchemaV2(s bool) {
c.isSchemaV2 = s
}

func (c *Conn) getScyllaSupported() scyllaSupported {
return c.scyllaSupported
}

// connect establishes a connection to a Cassandra node using session's connection config.
func (s *Session) connect(ctx context.Context, host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
return s.dial(ctx, host, s.connCfg, errorHandler)
Expand Down Expand Up @@ -350,6 +374,10 @@ func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error {
c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
}

if c.isScyllaConn() { // ScyllaDB does not support system.peers_v2
c.setSchemaV2(false)
}

go c.serve(ctx)
go c.heartBeat(ctx)

Expand Down Expand Up @@ -1768,52 +1796,19 @@ func (c *Conn) query(ctx context.Context, statement string, values ...interface{
return c.executeQuery(ctx, q)
}

func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter {
func (c *Conn) querySystem(ctx context.Context, query string) *Iter {
usingClause := ""
if c.session.control != nil {
usingClause = c.session.usingTimeoutClause
}
var (
peerSchema = "SELECT * FROM system.peers" + usingClause
peerV2Schemas = "SELECT * FROM system.peers_v2" + usingClause
)

c.mu.Lock()
if isScyllaConn((c)) { // ScyllaDB does not support system.peers_v2
c.isSchemaV2 = false
}

isSchemaV2 := c.isSchemaV2
c.mu.Unlock()

if version.AtLeast(4, 0, 0) && isSchemaV2 {
// Try "system.peers_v2" and fallback to "system.peers" if it's not found
iter := c.query(ctx, peerV2Schemas)

err := iter.checkErrAndNotFound()
if err != nil {
if errFrame, ok := err.(errorFrame); ok && errFrame.code == ErrCodeInvalid { // system.peers_v2 not found, try system.peers
c.mu.Lock()
c.isSchemaV2 = false
c.mu.Unlock()
return c.query(ctx, peerSchema)
} else {
return iter
}
}
return iter
} else {
return c.query(ctx, peerSchema)
}
queryStmt := query + usingClause
return c.query(ctx, queryStmt)
}

func (c *Conn) querySystemLocal(ctx context.Context) *Iter {
usingClause := ""
if c.session.control != nil {
usingClause = c.session.usingTimeoutClause
}
return c.query(ctx, "SELECT * FROM system.local WHERE key='local'"+usingClause)
}
const qrySystemPeers = "SELECT * FROM system.peers"
const qrySystemPeersV2 = "SELECT * FROM system.peers_2"

const qrySystemLocal = "SELECT * FROM system.local WHERE key='local'"

func getSchemaAgreement(queryLocalSchemasRows []string, querySystemPeersRows []map[string]interface{}, connectAddress net.IP, port int, translateAddressPort func(addr net.IP, port int) (net.IP, int), logger StdLogger) (err error) {
versions := make(map[string]struct{})
Expand Down Expand Up @@ -1850,11 +1845,7 @@ func getSchemaAgreement(queryLocalSchemasRows []string, querySystemPeersRows []m
}

func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
usingClause := ""
if c.session.control != nil {
usingClause = c.session.usingTimeoutClause
}
var localSchemas = "SELECT schema_version FROM system.local WHERE key='local'" + usingClause
var localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"

var schemaVersion string

Expand All @@ -1874,7 +1865,12 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
}

for time.Now().Before(endDeadline) {
iter := c.querySystemPeers(ctx, c.host.version)
var iter *Iter
if c.getIsSchemaV2() {
iter = c.querySystem(ctx, qrySystemPeersV2)
} else {
iter = c.querySystem(ctx, qrySystemPeers)
}
var systemPeersRows []map[string]interface{}
systemPeersRows, err = iter.SliceMap()
if err != nil {
Expand All @@ -1886,7 +1882,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {

schemaVersions := []string{}

iter = c.query(ctx, localSchemas)
iter = c.querySystem(ctx, localSchemas)
for iter.Scan(&schemaVersion) {
schemaVersions = append(schemaVersions, schemaVersion)
schemaVersion = ""
Expand Down
2 changes: 1 addition & 1 deletion connectionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func (pool *hostConnPool) initConnPicker(conn *Conn) {
return
}

if isScyllaConn(conn) {
if conn.isScyllaConn() {
pool.connPicker = newScyllaConnPicker(conn)
return
}
Expand Down
64 changes: 23 additions & 41 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ const (
controlConnClosing = -1
)

type controlConnection interface {
getConn() *connHost
awaitSchemaAgreement() error
query(statement string, values ...interface{}) (iter *Iter)
discoverProtocol(hosts []*HostInfo) (int, error)
connect(hosts []*HostInfo) error
close()
getSession() *Session
}

// Ensure that the atomic variable is aligned to a 64bit boundary
// so that atomic operations can be applied on 32bit architectures.
type controlConn struct {
Expand All @@ -49,6 +59,10 @@ type controlConn struct {
quit chan struct{}
}

func (c *controlConn) getSession() *Session {
return c.session
}

func createControlConn(session *Session) *controlConn {

control := &controlConn{
Expand Down Expand Up @@ -264,18 +278,18 @@ func (c *controlConn) connect(hosts []*HostInfo) error {
}

type connHost struct {
conn *Conn
conn ConnInterface
host *HostInfo
}

func (c *controlConn) setupConn(conn *Conn) error {
// we need up-to-date host info for the filterHost call below
iter := conn.querySystemLocal(context.TODO())
iter := conn.querySystem(context.TODO(), qrySystemLocal)
defaultPort := 9042
if tcpAddr, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok {
defaultPort = tcpAddr.Port
}
host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, defaultPort)
host, err := hostInfoFromIter(iter, conn.host.connectAddress, defaultPort, c.session.cfg.translateAddressPort)
if err != nil {
return err
}
Expand Down Expand Up @@ -359,7 +373,7 @@ func (c *controlConn) reconnect() {
return
}

err = c.session.refreshRing()
err = c.session.refreshRingNow()
if err != nil {
c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err)
}
Expand Down Expand Up @@ -462,45 +476,14 @@ func (c *controlConn) writeFrame(w frameBuilder) (frame, error) {
return framer.parseFrame()
}

func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
const maxConnectAttempts = 5
connectAttempts := 0

for i := 0; i < maxConnectAttempts; i++ {
ch := c.getConn()
if ch == nil {
if connectAttempts > maxConnectAttempts {
break
}

connectAttempts++

c.reconnect()
continue
}

return fn(ch)
}

return &Iter{err: errNoControl}
}

func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
return c.withConnHost(func(ch *connHost) *Iter {
return fn(ch.conn)
})
}

// query will return nil if the connection is closed or nil
func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)

for {
iter = c.withConn(func(conn *Conn) *Iter {
// we want to keep the query on the control connection
q.conn = conn
return conn.executeQuery(context.TODO(), q)
})
ch := c.getConn()
q.conn = ch.conn.(*Conn)
iter = ch.conn.executeQuery(context.TODO(), q)

if gocqlDebug && iter.err != nil {
c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err)
Expand All @@ -516,9 +499,8 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
}

func (c *controlConn) awaitSchemaAgreement() error {
return c.withConn(func(conn *Conn) *Iter {
return &Iter{err: conn.awaitSchemaAgreement(context.TODO())}
}).err
ch := c.getConn()
return (&Iter{err: ch.conn.awaitSchemaAgreement(context.TODO())}).err
}

func (c *controlConn) close() {
Expand Down
18 changes: 16 additions & 2 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,12 @@ type FrameHeaderObserver interface {
ObserveFrameHeader(context.Context, ObservedFrameHeader)
}

type framerInterface interface {
ReadBytesInternal() ([]byte, error)
GetCustomPayload() map[string][]byte
GetHeaderWarnings() []string
}

// a framer is responsible for reading, writing and parsing frames on a single stream
type framer struct {
proto byte
Expand Down Expand Up @@ -1866,7 +1872,7 @@ func (f *framer) readStringList() []string {
return l
}

func (f *framer) readBytesInternal() ([]byte, error) {
func (f *framer) ReadBytesInternal() ([]byte, error) {
size := f.readInt()
if size < 0 {
return nil, nil
Expand All @@ -1883,7 +1889,7 @@ func (f *framer) readBytesInternal() ([]byte, error) {
}

func (f *framer) readBytes() []byte {
l, err := f.readBytesInternal()
l, err := f.ReadBytesInternal()
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -2015,6 +2021,14 @@ func (f *framer) writeCustomPayload(customPayload *map[string][]byte) {
}
}

func (f *framer) GetCustomPayload() map[string][]byte {
return f.customPayload
}

func (f *framer) GetHeaderWarnings() []string {
return f.header.warnings
}

// these are protocol level binary types
func (f *framer) writeInt(n int32) {
f.buf = appendInt(f.buf, n)
Expand Down
Loading

0 comments on commit 7432982

Please sign in to comment.