Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Mar 15, 2024
1 parent 39e52e4 commit 0e3ace3
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 50 deletions.
18 changes: 13 additions & 5 deletions compress.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
Expand Down Expand Up @@ -120,10 +128,10 @@ func (c *compressor) uncompressPacket() error {
fmt.Fprintf(os.Stderr, "uncompress cmplen=%v uncomplen=%v seq=%v\n",
comprLength, uncompressedLength, compressionSequence)
}
if compressionSequence != c.mc.compressionSequence {
if compressionSequence != c.mc.compresSequence {
return ErrPktSync
}
c.mc.compressionSequence++
c.mc.compresSequence++

comprData, err := c.mc.buf.readNext(comprLength)
if err != nil {
Expand Down Expand Up @@ -206,15 +214,15 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err
c.mc.cfg.Logger.Print(
fmt.Sprintf(
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
comprLength, uncompressedLen, c.mc.compressionSequence))
comprLength, uncompressedLen, c.mc.compresSequence))
}

// compression header
data[0] = byte(0xff & comprLength)
data[1] = byte(0xff & (comprLength >> 8))
data[2] = byte(0xff & (comprLength >> 16))

data[3] = c.mc.compressionSequence
data[3] = c.mc.compresSequence

// this value is never greater than maxPayloadLength
data[4] = byte(0xff & uncompressedLen)
Expand All @@ -226,6 +234,6 @@ func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) err
return err
}

c.mc.compressionSequence++
c.mc.compresSequence++
return nil
}
28 changes: 18 additions & 10 deletions compress_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
Expand Down Expand Up @@ -30,7 +38,7 @@ func newMockBuf(data []byte) buffer {
func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte {
// get status variables

cs := mc.compressionSequence
cs := mc.compresSequence

var b bytes.Buffer
cw := newCompressor(mc, &b)
Expand All @@ -46,13 +54,13 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by
}

if len(uncompressedPacket) > 0 {
if mc.compressionSequence != (cs + 1) {
t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)
if mc.compresSequence != (cs + 1) {
t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compresSequence)
}

} else {
if mc.compressionSequence != cs {
t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressionSequence)
if mc.compresSequence != cs {
t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compresSequence)
}
}

Expand All @@ -62,7 +70,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by
// uncompressHelper uncompresses compressedPacket and checks state variables
func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte {
// get status variables
cs := mc.compressionSequence
cs := mc.compresSequence

// mocking out buf variable
mc.buf = newMockBuf(compressedPacket)
Expand All @@ -76,12 +84,12 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS
}

if expSize > 0 {
if mc.compressionSequence != (cs + 1) {
t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)
if mc.compresSequence != (cs + 1) {
t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compresSequence)
}
} else {
if mc.compressionSequence != cs {
t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressionSequence)
if mc.compresSequence != cs {
t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compresSequence)
}
}
return uncompressedPacket
Expand Down
50 changes: 24 additions & 26 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ import (
)

type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket().
packetReader packetReader
packetWriter io.Writer
cfg *Config
connector *connector
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
compressionSequence uint8
parseTime bool
compress bool
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket().
packetReader packetReader
packetWriter io.Writer
cfg *Config
connector *connector
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
compresSequence uint8
parseTime bool
compress bool

// for context support (Go 1.8+)
watching bool
Expand All @@ -52,21 +52,19 @@ type packetReader interface {
readNext(need int) ([]byte, error)
}

func (mc *mysqlConn) resetSeqNo() {
func (mc *mysqlConn) resetSequenceNr() {
mc.sequence = 0
mc.compressionSequence = 0
mc.compresSequence = 0
}

// syncSeqNo must be called when:
// - at least one large packet is sent (split packet happend), and
// - finished writing, before start reading.
func (mc *mysqlConn) syncSeqNo() {
// This syncs compressionSequence to sequence.
// This is done in `net_flush()` in MySQL and MariaDB.
// syncSequenceNr must be called when finished writing some packet and before start reading.
func (mc *mysqlConn) syncSequenceNr() {
// Syncs compressionSequence to sequence.
// This is not documented but done in `net_flush()` in MySQL and MariaDB.
// https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171
// https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293
if mc.compress {
mc.sequence = mc.compressionSequence
mc.sequence = mc.compresSequence
}
}

Expand Down
2 changes: 1 addition & 1 deletion infile.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) {
if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
return ioErr
}
mc.conn().syncSeqNo()
mc.conn().syncSequenceNr()

// read OK packet
if err == nil {
Expand Down
16 changes: 8 additions & 8 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {

func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.resetSeqNo()
mc.resetSequenceNr()

data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
Expand All @@ -440,7 +440,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {

func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
// Reset Packet Sequence
mc.resetSeqNo()
mc.resetSequenceNr()

pktLen := 1 + len(arg)
data, err := mc.buf.takeBuffer(pktLen + 4)
Expand All @@ -458,13 +458,13 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {

// Send CMD packet
err = mc.writePacket(data)
mc.syncSeqNo()
mc.syncSequenceNr()
return err
}

func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.resetSeqNo()
mc.resetSequenceNr()

data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
Expand Down Expand Up @@ -948,7 +948,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
pktLen = dataOffset + argLen
}

stmt.mc.resetSeqNo()
stmt.mc.resetSequenceNr()
// Add command byte [1 byte]
data[4] = comStmtSendLongData

Expand All @@ -972,7 +972,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
}

// Reset Packet Sequence
stmt.mc.resetSeqNo()
stmt.mc.resetSequenceNr()
return nil
}

Expand All @@ -997,7 +997,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}

// Reset packet-sequence
mc.resetSeqNo()
mc.resetSequenceNr()

var data []byte
var err error
Expand Down Expand Up @@ -1219,7 +1219,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}

err = mc.writePacket(data)
mc.syncSeqNo()
mc.syncSequenceNr()
return err
}

Expand Down

0 comments on commit 0e3ace3

Please sign in to comment.