Skip to content

Commit

Permalink
Improve the performance of QueryContext by reusing the result channel
Browse files Browse the repository at this point in the history
This commit improves the performance of QueryContext by changing it to
reuse the result channel instead of creating a new one for each query.
This is particularly impactful for queries that scan more than one row.

It also adds a test that actually exercises the sqlite3_interrupt logic
since the existing tests did not. Those tests cancelled the context
before scanning any of the rows and could be made to pass without ever
calling sqlite3_interrupt. The below version of SQLiteRows.Next passes
the previous tests:

```go
func (rc *SQLiteRows) Next(dest []driver.Value) error {
	rc.s.mu.Lock()
	defer rc.s.mu.Unlock()
	if rc.s.closed {
		return io.EOF
	}
	if err := rc.ctx.Err(); err != nil {
		return err
	}
	return rc.nextSyncLocked(dest)
}
```

Benchmark results:
```
goos: darwin
goarch: arm64
pkg: github.com/mattn/go-sqlite3
cpu: Apple M1 Max
                                          │    b.txt    │               n.txt                │
                                          │   sec/op    │   sec/op     vs base               │
Suite/BenchmarkQueryContext/Background-10   4.088µ ± 1%   4.154µ ± 3%  +1.60% (p=0.011 n=10)
Suite/BenchmarkQueryContext/WithCancel-10   12.84µ ± 3%   11.67µ ± 3%  -9.08% (p=0.000 n=10)
geomean                                     7.245µ        6.963µ       -3.89%

                                          │    b.txt     │                 n.txt                  │
                                          │     B/op     │     B/op      vs base                  │
Suite/BenchmarkQueryContext/Background-10     400.0 ± 0%     400.0 ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkQueryContext/WithCancel-10   2.547Ki ± 0%   1.282Ki ± 0%  -49.67% (p=0.000 n=10)
geomean                                      1021.4          724.6       -29.06%
¹ all samples are equal

                                          │   b.txt    │                n.txt                 │
                                          │ allocs/op  │ allocs/op   vs base                  │
Suite/BenchmarkQueryContext/Background-10   12.00 ± 0%   12.00 ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkQueryContext/WithCancel-10   49.00 ± 0%   28.00 ± 0%  -42.86% (p=0.000 n=10)
geomean                                     24.25        18.33       -24.41%
¹ all samples are equal
```
  • Loading branch information
charlievieth committed Nov 7, 2024
1 parent 41871ea commit d3c66c9
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 15 deletions.
22 changes: 14 additions & 8 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ type SQLiteRows struct {
cls bool
closed bool
ctx context.Context // no better alternative to pass context into Next() method
resultCh chan error
}

type functionInfo struct {
Expand Down Expand Up @@ -2172,24 +2173,29 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return io.EOF
}

if rc.ctx.Done() == nil {
done := rc.ctx.Done()
if done == nil {
return rc.nextSyncLocked(dest)
}
resultCh := make(chan error)
defer close(resultCh)
if err := rc.ctx.Err(); err != nil {
return err // Fast check if the channel is closed
}
if rc.resultCh == nil {
rc.resultCh = make(chan error)
}
go func() {
resultCh <- rc.nextSyncLocked(dest)
rc.resultCh <- rc.nextSyncLocked(dest)
}()
select {
case err := <-resultCh:
case err := <-rc.resultCh:
return err
case <-rc.ctx.Done():
case <-done:
select {
case <-resultCh: // no need to interrupt
case <-rc.resultCh: // no need to interrupt
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
C.sqlite3_interrupt(rc.s.c.db)
<-resultCh // ensure goroutine completed
<-rc.resultCh // ensure goroutine completed
}
return rc.ctx.Err()
}
Expand Down
147 changes: 147 additions & 0 deletions sqlite3_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"os"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
}
}

// Test that we can successfully interrupt a long running query when
// the context is canceled. The previous two QueryRowContext tests
// only test that we handle a previously cancelled context and thus
// do not call sqlite3_interrupt.
func TestQueryRowContextCancelInterrupt(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	// Test that we have the unixepoch function and if not skip the test.
	if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil {
		libVersion, libVersionNumber, sourceID := Version()
		if strings.Contains(err.Error(), "no such function: unixepoch") {
			t.Skip("Skipping the 'unixepoch' function is not implemented in "+
				"this version of sqlite3:", libVersion, libVersionNumber, sourceID)
		}
		t.Fatal(err)
	}

	const createTableStmt = `
CREATE TABLE timestamps (
	ts TIMESTAMP NOT NULL
);`
	if _, err := db.Exec(createTableStmt); err != nil {
		t.Fatal(err)
	}

	stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()

	// Computationally expensive query that consumes many rows. This is needed
	// to test cancellation because queries are not interrupted immediately.
	// Instead, queries are only halted at certain checkpoints where the
	// sqlite3.isInterrupted is checked and true.
	queryStmt := `
SELECT
	SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
	SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
	SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
	SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
FROM
	timestamps
WHERE datetime(ts, 'unixepoch', 'localtime')
LIKE
	?;`

	// query runs queryStmt with the provided timeout and returns the number
	// of rows scanned along with the first error encountered, if any
	// (including rows.Err()).
	query := func(t *testing.T, timeout time.Duration) (int, error) {
		// Create a complicated pattern to match timestamps
		const pattern = "%2%0%2%4%-%-%:%:%"
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()
		rows, err := db.QueryContext(ctx, queryStmt, pattern)
		if err != nil {
			return 0, err
		}
		var count int
		for rows.Next() {
			var n int64
			if err := rows.Scan(&n, &n, &n, &n); err != nil {
				return count, err
			}
			count++
		}
		return count, rows.Err()
	}

	// average returns the mean wall-clock duration of n invocations of fn.
	average := func(n int, fn func()) time.Duration {
		start := time.Now()
		for i := 0; i < n; i++ {
			fn()
		}
		return time.Since(start) / time.Duration(n)
	}

	// createRows replaces the contents of the timestamps table with n rows.
	// The fixed rand seed keeps the generated data deterministic across runs.
	createRows := func(n int) {
		t.Logf("Creating %d rows", n)
		if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
			t.Fatal(err)
		}
		ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
		rr := rand.New(rand.NewSource(1234))
		for i := 0; i < n; i++ {
			if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
				t.Fatal(err)
			}
		}
	}

	const TargetRuntime = 200 * time.Millisecond
	const N = 5_000 // Number of rows to insert at a time

	// Create enough rows that the query takes ~200ms to run.
	// First measure the per-N cost with a small table, then scale the row
	// count so that one query takes at least TargetRuntime.
	start := time.Now()
	createRows(N)
	baseAvg := average(4, func() {
		if _, err := query(t, time.Hour); err != nil {
			t.Fatal(err)
		}
	})
	t.Log("Base average:", baseAvg)
	rowCount := N * (int(TargetRuntime/baseAvg) + 1)
	createRows(rowCount)
	t.Log("Table setup time:", time.Since(start))

	// Set the timeout to 1/10 of the average query time.
	avg := average(2, func() {
		n, err := query(t, time.Hour)
		if err != nil {
			t.Fatal(err)
		}
		if n == 0 {
			t.Fatal("scanned zero rows")
		}
	})
	// Guard against the timeout being too short to reliably test.
	if avg < TargetRuntime/2 {
		t.Fatalf("Average query runtime should be around %s got: %s ",
			TargetRuntime, avg)
	}
	timeout := (avg / 10).Round(100 * time.Microsecond)
	t.Logf("Average: %s Timeout: %s", avg, timeout)

	// Each query should now be interrupted well before it can finish and
	// must fail with context.DeadlineExceeded.
	for i := 0; i < 10; i++ {
		tt := time.Now()
		n, err := query(t, timeout)
		if !errors.Is(err, context.DeadlineExceeded) {
			// A nil error (the query completed) is reported but the test
			// continues; any other unexpected error aborts immediately.
			fn := t.Errorf
			if err != nil {
				fn = t.Fatalf
			}
			fn("expected error %v got %v", context.DeadlineExceeded, err)
		}
		d := time.Since(tt)
		t.Logf("%d: rows: %d duration: %s", i, n, d)
		// Interruption is not instant, but it should not take drastically
		// longer than the timeout to take effect.
		if d > timeout*4 {
			t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
		}
	}
}

func TestExecCancel(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
Expand Down
82 changes: 75 additions & 7 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package sqlite3

import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"errors"
Expand Down Expand Up @@ -2030,7 +2031,7 @@ func BenchmarkCustomFunctions(b *testing.B) {
}

func TestSuite(t *testing.T) {
initializeTestDB(t)
initializeTestDB(t, false)
defer freeTestDB()

for _, test := range tests {
Expand All @@ -2039,7 +2040,7 @@ func TestSuite(t *testing.T) {
}

func BenchmarkSuite(b *testing.B) {
initializeTestDB(b)
initializeTestDB(b, true)
defer freeTestDB()

for _, benchmark := range benchmarks {
Expand Down Expand Up @@ -2068,8 +2069,13 @@ type TestDB struct {

var db *TestDB

func initializeTestDB(t testing.TB) {
tempFilename := TempFilename(t)
func initializeTestDB(t testing.TB, memory bool) {
var tempFilename string
if memory {
tempFilename = ":memory:"
} else {
tempFilename = TempFilename(t)
}
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil {
os.Remove(tempFilename)
Expand All @@ -2084,9 +2090,11 @@ func freeTestDB() {
if err != nil {
panic(err)
}
err = os.Remove(db.tempFilename)
if err != nil {
panic(err)
if db.tempFilename != "" && db.tempFilename != ":memory:" {
err := os.Remove(db.tempFilename)
if err != nil {
panic(err)
}
}
}

Expand All @@ -2107,6 +2115,7 @@ var tests = []testing.InternalTest{
var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkExec", F: benchmarkExec},
{Name: "BenchmarkQuery", F: benchmarkQuery},
{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
{Name: "BenchmarkParams", F: benchmarkParams},
{Name: "BenchmarkStmt", F: benchmarkStmt},
{Name: "BenchmarkRows", F: benchmarkRows},
Expand Down Expand Up @@ -2479,6 +2488,65 @@ func benchmarkQuery(b *testing.B) {
}
}

// benchmarkQueryContext is a benchmark for QueryContext. It measures
// stmt.QueryContext plus a full scan of the result set, once with a
// context that has no Done channel (context.Background) and once with
// a cancelable context (context.WithCancel), since the two take
// different code paths in SQLiteRows.Next.
func benchmarkQueryContext(b *testing.B) {
	const createTableStmt = `
CREATE TABLE IF NOT EXISTS query_context(
	id INTEGER PRIMARY KEY
);
DELETE FROM query_context;
VACUUM;`
	test := func(ctx context.Context, b *testing.B) {
		if _, err := db.Exec(createTableStmt); err != nil {
			b.Fatal(err)
		}
		for i := 0; i < 10; i++ {
			_, err := db.Exec("INSERT INTO query_context VALUES (?);", int64(i))
			if err != nil {
				// Was db.Fatal: report the failure on this sub-benchmark's
				// *testing.B, consistent with every other error path here.
				b.Fatal(err)
			}
		}
		stmt, err := db.PrepareContext(ctx, `SELECT id FROM query_context;`)
		if err != nil {
			b.Fatal(err)
		}
		b.Cleanup(func() { stmt.Close() })

		// Exclude table creation, inserts, and statement preparation from
		// the measured time.
		b.ResetTimer()

		var n int
		for i := 0; i < b.N; i++ {
			rows, err := stmt.QueryContext(ctx)
			if err != nil {
				b.Fatal(err)
			}
			for rows.Next() {
				if err := rows.Scan(&n); err != nil {
					b.Fatal(err)
				}
			}
			if err := rows.Err(); err != nil {
				b.Fatal(err)
			}
		}
	}

	// When the context does not have a Done channel we should use
	// the fast path that directly handles the query instead of
	// handling it in a goroutine. This benchmark also serves to
	// highlight the performance impact of using a cancelable
	// context.
	b.Run("Background", func(b *testing.B) {
		test(context.Background(), b)
	})

	// Benchmark a query with a context that can be canceled. This
	// requires using a goroutine and is thus much slower.
	b.Run("WithCancel", func(b *testing.B) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		test(ctx, b)
	})
}

// benchmarkParams is benchmark for params
func benchmarkParams(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down

0 comments on commit d3c66c9

Please sign in to comment.