Skip to content

Commit

Permalink
Add WithSession helper for Cassandra driver
Browse files Browse the repository at this point in the history
Also changes method receivers names for the Cassandra struct to "c"
since "p" makes no sense in this context.
  • Loading branch information
mixedCase committed Mar 13, 2018
1 parent 22f2495 commit 78c4707
Showing 1 changed file with 50 additions and 27 deletions.
77 changes: 50 additions & 27 deletions database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoKeyspace = fmt.Errorf("no keyspace provided")
ErrDatabaseDirty = fmt.Errorf("database is dirty")
ErrClosedSession = fmt.Errorf("session is closed")
)

type Config struct {
Expand All @@ -39,7 +40,29 @@ type Cassandra struct {
config *Config
}

func (p *Cassandra) Open(url string) (database.Driver, error) {
func WithSession(session *gocql.Session, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
} else if isClosed := session.Closed(); isClosed {
return nil, ErrClosedSession
} else if len(config.KeyspaceName) == 0 {
return nil, ErrNoKeyspace
}

if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
c := &Cassandra{
session: session,
config: config,
}
if err := c.ensureVersionTable(); err != nil {
return nil, err
}
return c, nil
}

func (c *Cassandra) Open(url string) (database.Driver, error) {
u, err := nurl.Parse(url)
if err != nil {
return nil, err
Expand All @@ -55,7 +78,7 @@ func (p *Cassandra) Open(url string) (database.Driver, error) {
migrationsTable = DefaultMigrationsTable
}

p.config = &Config{
c.config = &Config{
KeyspaceName: u.Path,
MigrationsTable: migrationsTable,
}
Expand Down Expand Up @@ -100,60 +123,60 @@ func (p *Cassandra) Open(url string) (database.Driver, error) {
cluster.Timeout = timeout
}

p.session, err = cluster.CreateSession()
c.session, err = cluster.CreateSession()

if err != nil {
return nil, err
}

if err := p.ensureVersionTable(); err != nil {
if err := c.ensureVersionTable(); err != nil {
return nil, err
}

return p, nil
return c, nil
}

func (p *Cassandra) Close() error {
p.session.Close()
func (c *Cassandra) Close() error {
c.session.Close()
return nil
}

func (p *Cassandra) Lock() error {
func (c *Cassandra) Lock() error {
if dbLocked {
return database.ErrLocked
}
dbLocked = true
return nil
}

func (p *Cassandra) Unlock() error {
func (c *Cassandra) Unlock() error {
dbLocked = false
return nil
}

func (p *Cassandra) Run(migration io.Reader) error {
func (c *Cassandra) Run(migration io.Reader) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
}
// run migration
query := string(migr[:])
if err := p.session.Query(query).Exec(); err != nil {
if err := c.session.Query(query).Exec(); err != nil {
// TODO: cast to Cassandra error and get line number
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}

return nil
}

func (p *Cassandra) SetVersion(version int, dirty bool) error {
query := `TRUNCATE "` + p.config.MigrationsTable + `"`
if err := p.session.Query(query).Exec(); err != nil {
func (c *Cassandra) SetVersion(version int, dirty bool) error {
query := `TRUNCATE "` + c.config.MigrationsTable + `"`
if err := c.session.Query(query).Exec(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if version >= 0 {
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
if err := p.session.Query(query, version, dirty).Exec(); err != nil {
query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
if err := c.session.Query(query, version, dirty).Exec(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
Expand All @@ -162,9 +185,9 @@ func (p *Cassandra) SetVersion(version int, dirty bool) error {
}

// Return current keyspace version
func (p *Cassandra) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
err = p.session.Query(query).Scan(&version, &dirty)
func (c *Cassandra) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
err = c.session.Query(query).Scan(&version, &dirty)
switch {
case err == gocql.ErrNotFound:
return database.NilVersion, false, nil
Expand All @@ -180,31 +203,31 @@ func (p *Cassandra) Version() (version int, dirty bool, err error) {
}
}

func (p *Cassandra) Drop() error {
func (c *Cassandra) Drop() error {
// select all tables in current schema
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character
iter := p.session.Query(query).Iter()
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName[1:]) // Skip '/' character
iter := c.session.Query(query).Iter()
var tableName string
for iter.Scan(&tableName) {
err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
if err != nil {
return err
}
}
// Re-create the version table
if err := p.ensureVersionTable(); err != nil {
if err := c.ensureVersionTable(); err != nil {
return err
}
return nil
}

// Ensure version table exists
func (p *Cassandra) ensureVersionTable() error {
err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec()
func (c *Cassandra) ensureVersionTable() error {
err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
if err != nil {
return err
}
if _, _, err = p.Version(); err != nil {
if _, _, err = c.Version(); err != nil {
return err
}
return nil
Expand Down

0 comments on commit 78c4707

Please sign in to comment.