diff --git a/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go b/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go index a10d50b53592e..ae878d8dbaca8 100644 --- a/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go +++ b/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go @@ -40,7 +40,13 @@ func main() { flag.Parse() - srv, err := example.NewSQLiteFlightSQLServer() + db, err := example.CreateDB() + if err != nil { + log.Fatal(err) + } + defer db.Close() + + srv, err := example.NewSQLiteFlightSQLServer(db) if err != nil { log.Fatal(err) } diff --git a/go/arrow/flight/flightsql/example/sqlite_server.go b/go/arrow/flight/flightsql/example/sqlite_server.go index 5dfd6d99a6e6b..0742113000c2a 100644 --- a/go/arrow/flight/flightsql/example/sqlite_server.go +++ b/go/arrow/flight/flightsql/example/sqlite_server.go @@ -140,20 +140,8 @@ func prepareQueryForGetKeys(filter string) string { ` ORDER BY pk_catalog_name, pk_schema_name, pk_table_name, pk_key_name, key_sequence` } -type Statement struct { - stmt *sql.Stmt - params [][]interface{} -} - -type SQLiteFlightSQLServer struct { - flightsql.BaseServer - db *sql.DB - - prepared sync.Map -} - -func NewSQLiteFlightSQLServer() (*SQLiteFlightSQLServer, error) { - db, err := sql.Open("sqlite", ":memory:") +func CreateDB() (*sql.DB, error) { + db, err := sql.Open("sqlite", "file::memory:?cache=shared") if err != nil { return nil, err } @@ -178,10 +166,27 @@ func NewSQLiteFlightSQLServer() (*SQLiteFlightSQLServer, error) { INSERT INTO intTable (keyName, value, foreignId) VALUES ('negative one', -1, 1); INSERT INTO intTable (keyName, value, foreignId) VALUES (NULL, NULL, NULL); `) - if err != nil { + db.Close() return nil, err } + + return db, nil +} + +type Statement struct { + stmt *sql.Stmt + params [][]interface{} +} + +type SQLiteFlightSQLServer struct { + flightsql.BaseServer + db *sql.DB + + prepared sync.Map +} + +func NewSQLiteFlightSQLServer(db *sql.DB) (*SQLiteFlightSQLServer, error) { ret := &SQLiteFlightSQLServer{db: db} ret.Alloc = memory.DefaultAllocator for k, v := range SqlInfoResultMap() { diff --git a/go/arrow/flight/flightsql/sqlite_server_test.go b/go/arrow/flight/flightsql/sqlite_server_test.go index 81966a537bb11..7df9d932ed1a5 100644 --- a/go/arrow/flight/flightsql/sqlite_server_test.go +++ b/go/arrow/flight/flightsql/sqlite_server_test.go @@ -21,6 +21,7 @@ package flightsql_test import ( "context" + "database/sql" "os" "strings" "testing" @@ -42,6 +43,7 @@ import ( type FlightSqliteServerSuite struct { suite.Suite + db *sql.DB srv *example.SQLiteFlightSQLServer s flight.Server cl *flightsql.Client @@ -71,7 +73,9 @@ func (s *FlightSqliteServerSuite) SetupTest() { var err error s.mem = memory.NewCheckedAllocator(memory.DefaultAllocator) s.s = flight.NewServerWithMiddleware(nil) - s.srv, err = example.NewSQLiteFlightSQLServer() + s.db, err = example.CreateDB() + s.Require().NoError(err) + s.srv, err = example.NewSQLiteFlightSQLServer(s.db) s.Require().NoError(err) s.srv.Alloc = s.mem @@ -89,6 +93,8 @@ func (s *FlightSqliteServerSuite) TearDownTest() { s.Require().NoError(s.cl.Close()) s.s.Shutdown() s.srv = nil + err := s.db.Close() + s.Require().NoError(err) s.mem.AssertSize(s.T(), 0) }