Skip to content

Commit

Permalink
adding support for statement list in cassandra
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Hoffman committed Feb 15, 2018
1 parent 0b70524 commit fd13249
Showing 1 changed file with 29 additions and 35 deletions.
64 changes: 29 additions & 35 deletions plugins/database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,14 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
return "", "", err
}

var creationCQL string
switch len(statements.CreationStatements) {
case 0:
creationCQL = defaultUserCreationCQL
case 1:
creationCQL = statements.CreationStatements[0]
default:
return "", "", fmt.Errorf("expected 0 or 1 creation statements, got %d", len(statements.CreationStatements))
creationCQL := statements.CreationStatements
if len(creationCQL) == 0 {
creationCQL = []string{defaultUserCreationCQL}
}

var rollbackCQL string
switch len(statements.RollbackStatements) {
case 0:
rollbackCQL = defaultUserDeletionCQL
case 1:
rollbackCQL = statements.RollbackStatements[0]
default:
return "", "", fmt.Errorf("expected 0 or 1 rollback statements, got %d", len(statements.RollbackStatements))
rollbackCQL := statements.CreationStatements
if len(rollbackCQL) == 0 {
rollbackCQL = []string{defaultUserDeletionCQL}
}

username, err = c.GenerateUsername(usernameConfig)
Expand All @@ -124,28 +114,32 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
}

// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
for _, stmt := range creationCQL {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}

err = session.Query(dbutil.QueryHelper(query, map[string]string{
"username": username,
"password": password,
})).Exec()
if err != nil {
for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
err = session.Query(dbutil.QueryHelper(query, map[string]string{
"username": username,
"password": password,
})).Exec()
if err != nil {
for _, stmt := range rollbackCQL {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}

session.Query(dbutil.QueryHelper(query, map[string]string{
"username": username,
})).Exec()
}
}

session.Query(dbutil.QueryHelper(query, map[string]string{
"username": username,
})).Exec()
return "", "", err
}
return "", "", err
}
}

Expand Down

0 comments on commit fd13249

Please sign in to comment.