Skip to content

Commit

Permalink
add WithStmtCfg - tackles #135
Browse files Browse the repository at this point in the history
  • Loading branch information
tgulacsi committed Dec 1, 2016
1 parent 937a4b6 commit 7c9d054
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 12 deletions.
26 changes: 16 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ora
--
import "gopkg.in/rana/ora.v3"
import "gopkg.in/rana/ora.v4"

Package ora implements an Oracle database driver.

Expand All @@ -11,18 +11,24 @@ Package ora implements an Oracle database driver.
import (
"database/sql"

_ "gopkg.in/rana/ora.v3"
_ "gopkg.in/rana/ora.v4"
)

func main() {
db, err := sql.Open("ora", "user/passw@host:port/sid")
defer db.Close()

// Set timeout (Go 1.8)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// Set prefetch count (Go 1.8)
ctx = ora.WithStmtCfg(ctx, ora.Cfg().StmtCfg.SetPrefetchCount(50000))
rows, err := db.QueryContext(ctx, "SELECT * FROM user_objects")
}

Call stored procedure with OUT parameters:

import (
"gopkg.in/rana/ora.v3"
"gopkg.in/rana/ora.v4"
)

func main() {
Expand Down Expand Up @@ -95,7 +101,7 @@ your system, maybe tailored to your specific locations) to a folder in
The ora package has no external Go dependencies and is available on GitHub and
gopkg.in:

go get gopkg.in/rana/ora.v3
go get gopkg.in/rana/ora.v4


### Data Types
Expand Down Expand Up @@ -1015,8 +1021,8 @@ ora driver methods. For example:
To use the standard Go log package:

import (
"gopkg.in/rana/ora.v3"
"gopkg.in/rana/ora.v3/lg"
"gopkg.in/rana/ora.v4"
"gopkg.in/rana/ora.v4/lg"
)

func main() {
Expand Down Expand Up @@ -1047,8 +1053,8 @@ To use the glog package:

import (
"flag"
"gopkg.in/rana/ora.v3"
"gopkg.in/rana/ora.v3/glg"
"gopkg.in/rana/ora.v4"
"gopkg.in/rana/ora.v4/glg"
)

func main() {
Expand Down Expand Up @@ -1081,8 +1087,8 @@ which produces a sample log of:
To use the log15 package:

import (
"gopkg.in/rana/ora.v3"
"gopkg.in/rana/ora.v3/lg15"
"gopkg.in/rana/ora.v4"
"gopkg.in/rana/ora.v4/lg15"
)
func main() {
// use the optional log15 package for ora logging
Expand Down
27 changes: 27 additions & 0 deletions ctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2016 Tamás Gulácsi. All rights reserved.
// Use of this source code is governed by The MIT License
// found in the accompanying LICENSE file.

package ora

import "context"

const (
stmtCfgKey = "stmtCfg"
)

// ctxStmtCfg returns the StmtCfg from the context, and
// whether it exist at all.
func ctxStmtCfg(ctx context.Context) (StmtCfg, bool) {
cfg, ok := ctx.Value(stmtCfgKey).(StmtCfg)
return cfg, ok
}

// WithStmtCfg returns a new context, with the given cfg that
// can be used to configure several parameters.
//
// WARNING: the StmtCfg must be derived from Cfg(), or NewStmtCfg(),
// as an empty StmtCfg is not usable!
func WithStmtCfg(ctx context.Context, cfg StmtCfg) context.Context {
return context.WithValue(ctx, stmtCfgKey, cfg)
}
7 changes: 7 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ Package ora implements an Oracle database driver.
func main() {
db, err := sql.Open("ora", "user/passw@host:port/sid")
defer db.Close()
// Set timeout (Go 1.8)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// Set prefetch count (Go 1.8)
ctx = ora.WithStmtCfg(ctx, ora.Cfg().StmtCfg.SetPrefetchCount(50000))
rows, err := db.QueryContext(ctx, "SELECT * FROM user_objects")
defer rows.Close()
}
Call stored procedure with OUT parameters:
Expand Down
4 changes: 2 additions & 2 deletions drvStmt_go1_8.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (ds *DrvStmt) ExecContext(ctx context.Context, values []driver.NamedValue)
grp, ctx := errgroup.WithContext(ctx)
grp.Go(func() error {
var err error
res.rowsAffected, res.lastInsertId, err = ds.stmt.exe(params, false)
res.rowsAffected, res.lastInsertId, err = ds.stmt.exeC(ctx, params, false)
if err != nil {
return errE(err)
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func (ds *DrvStmt) QueryContext(ctx context.Context, values []driver.NamedValue)
grp, ctx := errgroup.WithContext(ctx)
grp.Go(func() error {
var err error
rset, err = ds.stmt.qry(params)
rset, err = ds.stmt.qryC(ctx, params)
if err != nil {
return errE(err)
}
Expand Down
19 changes: 19 additions & 0 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import "C"
import (
"bytes"
"container/list"
"context"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -219,9 +220,15 @@ var spcRpl = strings.NewReplacer("\t", " ", " ", " ", " ", " ")

// exe executes a SQL statement on an Oracle server returning rowsAffected, lastInsertId and error.
func (stmt *Stmt) exe(params []interface{}, isAssocArray bool) (rowsAffected uint64, lastInsertId int64, err error) {
return stmt.exeC(context.Background(), params, isAssocArray)
}
func (stmt *Stmt) exeC(ctx context.Context, params []interface{}, isAssocArray bool) (rowsAffected uint64, lastInsertId int64, err error) {
if stmt == nil {
return 0, 0, er("stmt may not be nil.")
}
if err = ctx.Err(); err != nil {
return
}
defer func() {
if value := recover(); value != nil {
err = errR(value)
Expand All @@ -232,6 +239,9 @@ func (stmt *Stmt) exe(params []interface{}, isAssocArray bool) (rowsAffected uin
if err != nil {
return 0, 0, errE(err)
}
if cfg, ok := ctxStmtCfg(ctx); ok {
stmt.SetCfg(cfg)
}
// for case of inserting and returning identity for database/sql package
stmt.RLock()
pkgEnvInsert := stmt.env.isPkgEnv && stmt.stmtType == C.OCI_STMT_INSERT
Expand Down Expand Up @@ -316,12 +326,21 @@ func (stmt *Stmt) Qry(params ...interface{}) (*Rset, error) {

// qry runs a SQL query on an Oracle server returning a *Rset and possible error.
func (stmt *Stmt) qry(params []interface{}) (rset *Rset, err error) {
return stmt.qryC(context.Background(), params)
}
func (stmt *Stmt) qryC(ctx context.Context, params []interface{}) (rset *Rset, err error) {
defer func() {
if value := recover(); value != nil {
err = errR(value)
}
}()
stmt.log(_drv.Cfg().Log.Stmt.Qry)
if err := ctx.Err(); err != nil {
return nil, err
}
if cfg, ok := ctxStmtCfg(ctx); ok {
stmt.SetCfg(cfg)
}
err = stmt.checkClosed()
if err != nil {
return nil, errE(err)
Expand Down

0 comments on commit 7c9d054

Please sign in to comment.