Skip to content

Commit

Permalink
sql: add user_id column to system.web_sessions table
Browse files Browse the repository at this point in the history
This patch adds a new `user_id` column to the `system.web_sessions` table,
which corresponds to the existing `username` column. Migrations are
also added to alter and backfill the table in older clusters.

Release note: None
  • Loading branch information
andyyang890 committed Feb 15, 2023
1 parent 494b909 commit 72a4b69
Show file tree
Hide file tree
Showing 17 changed files with 449 additions and 36 deletions.
2 changes: 1 addition & 1 deletion docs/generated/settings/settings-for-tenants.txt
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,4 @@ trace.jaeger.agent string the address of a Jaeger agent to receive traces using
trace.opentelemetry.collector string address of an OpenTelemetry trace collector to receive traces using the otel gRPC protocol, as <host>:<port>. If no port is specified, 4317 will be used.
trace.span_registry.enabled boolean true if set, ongoing traces can be seen at https://<ui>/#/debug/tracez
trace.zipkin.collector string the address of a Zipkin instance to receive traces, as <host>:<port>. If no port is specified, 9411 will be used.
version version 1000022.2-44 set the active cluster version in the format '<major>.<minor>'
version version 1000022.2-48 set the active cluster version in the format '<major>.<minor>'
2 changes: 1 addition & 1 deletion docs/generated/settings/settings.html
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,6 @@
<tr><td><div id="setting-trace-opentelemetry-collector" class="anchored"><code>trace.opentelemetry.collector</code></div></td><td>string</td><td><code></code></td><td>address of an OpenTelemetry trace collector to receive traces using the otel gRPC protocol, as &lt;host&gt;:&lt;port&gt;. If no port is specified, 4317 will be used.</td></tr>
<tr><td><div id="setting-trace-span-registry-enabled" class="anchored"><code>trace.span_registry.enabled</code></div></td><td>boolean</td><td><code>true</code></td><td>if set, ongoing traces can be seen at https://&lt;ui&gt;/#/debug/tracez</td></tr>
<tr><td><div id="setting-trace-zipkin-collector" class="anchored"><code>trace.zipkin.collector</code></div></td><td>string</td><td><code></code></td><td>the address of a Zipkin instance to receive traces, as &lt;host&gt;:&lt;port&gt;. If no port is specified, 9411 will be used.</td></tr>
<tr><td><div id="setting-version" class="anchored"><code>version</code></div></td><td>version</td><td><code>1000022.2-44</code></td><td>set the active cluster version in the format &#39;&lt;major&gt;.&lt;minor&gt;&#39;</td></tr>
<tr><td><div id="setting-version" class="anchored"><code>version</code></div></td><td>version</td><td><code>1000022.2-48</code></td><td>set the active cluster version in the format &#39;&lt;major&gt;.&lt;minor&gt;&#39;</td></tr>
</tbody>
</table>
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ UNION ALL SELECT create_statement FROM [SHOW CREATE TABLE system.namespace]
"revokedAt" TIMESTAMP NULL,
"lastUsedAt" TIMESTAMP NOT NULL DEFAULT now():::TIMESTAMP,
"auditInfo" STRING NULL,
user_id OID NULL,
crdb_region system.public.crdb_internal_region NOT VISIBLE NOT NULL DEFAULT default_to_database_primary_region(gateway_region())::system.public.crdb_internal_region,
CONSTRAINT "primary" PRIMARY KEY (id ASC),
INDEX "web_sessions_expiresAt_idx" ("expiresAt" ASC),
INDEX "web_sessions_createdAt_idx" ("createdAt" ASC),
INDEX "web_sessions_revokedAt_idx" ("revokedAt" ASC),
INDEX "web_sessions_lastUsedAt_idx" ("lastUsedAt" ASC),
FAMILY "fam_0_id_hashedSecret_username_createdAt_expiresAt_revokedAt_lastUsedAt_auditInfo" (id, "hashedSecret", username, "createdAt", "expiresAt", "revokedAt", "lastUsedAt", "auditInfo"),
FAMILY fam_9_crdb_region (crdb_region)
FAMILY "fam_0_id_hashedSecret_username_createdAt_expiresAt_revokedAt_lastUsedAt_auditInfo" (id, "hashedSecret", username, "createdAt", "expiresAt", "revokedAt", "lastUsedAt", "auditInfo", user_id),
FAMILY fam_10_crdb_region (crdb_region)
) LOCALITY REGIONAL BY ROW`},
{`CREATE TABLE public.namespace (
"parentID" INT8 NOT NULL,
Expand Down
78 changes: 59 additions & 19 deletions pkg/cli/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ package cli

import (
"context"
"database/sql/driver"
"fmt"
"net/http"
"os"

"github.com/cockroachdb/cockroach/pkg/cli/clierrorplus"
"github.com/cockroachdb/cockroach/pkg/cli/clisqlclient"
"github.com/cockroachdb/cockroach/pkg/cli/clisqlexec"
"github.com/cockroachdb/cockroach/pkg/clusterversion"
"github.com/cockroachdb/cockroach/pkg/server"
"github.com/cockroachdb/cockroach/pkg/server/serverpb"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
Expand Down Expand Up @@ -122,28 +124,64 @@ func createAuthSessionToken(
expiration := timeutil.Now().Add(authCtx.validityPeriod)

// Create the session on the server to the server.
insertSessionStmt := `
var id int64
err = sqlConn.ExecTxn(ctx, func(ctx context.Context, conn clisqlclient.TxBoundConn) error {
rows, err := conn.Query(ctx, fmt.Sprintf(
"SELECT crdb_internal.is_at_least_version('%s')",
clusterversion.ByKey(clusterversion.V23_1WebSessionsTableHasUserIDColumn)))
if err != nil {
return err
}
row := make([]driver.Value, 1)
if err := rows.Next(row); err != nil {
return err
}
if err := rows.Close(); err != nil {
return err
}
webSessionsHasUserIDCol, ok := row[0].(bool)
if !ok {
return errors.Newf("expected bool, got %T", row[0])
}
insertSessionStmt := `
INSERT INTO system.web_sessions ("hashedSecret", username, "expiresAt")
VALUES($1, $2, $3)
VALUES ($1, $2, $3)
RETURNING id
`
var id int64
row, err := sqlConn.QueryRow(ctx,
insertSessionStmt,
hashedSecret,
username,
expiration,
)
if webSessionsHasUserIDCol {
insertSessionStmt = `
INSERT INTO system.web_sessions ("hashedSecret", username, "expiresAt", user_id)
VALUES ($1, $2, $3, (SELECT user_id FROM system.users WHERE username = $2))
RETURNING id
`
}
rows, err = conn.Query(ctx,
insertSessionStmt,
hashedSecret,
username,
expiration,
)
if err != nil {
return err
}
if err := rows.Next(row); err != nil {
return err
}
if err := rows.Close(); err != nil {
return err
}
if len(row) != 1 {
return errors.Newf("expected 1 column, got %d", len(row))
}
id, ok = row[0].(int64)
if !ok {
return errors.Newf("expected integer, got %T", row[0])
}
return nil
})
if err != nil {
return -1, nil, err
}
if len(row) != 1 {
return -1, nil, errors.Newf("expected 1 column, got %d", len(row))
}
id, ok := row[0].(int64)
if !ok {
return -1, nil, errors.Newf("expected integer, got %T", row[0])
}

// Spell out the cookie.
sCookie := &serverpb.SessionCookie{ID: id, Secret: secret}
Expand Down Expand Up @@ -204,17 +242,19 @@ func runAuthList(cmd *cobra.Command, args []string) (resErr error) {
}
defer func() { resErr = errors.CombineErrors(resErr, sqlConn.Close()) }()

logoutQuery := clisqlclient.MakeQuery(`
// TODO(yang): Change this to read the user_id directly from the table in 23.2.
authListQuery := clisqlclient.MakeQuery(`
SELECT username,
(SELECT user_id FROM system.users AS u WHERE w.username = u.username) AS "user ID",
id AS "session ID",
"createdAt" as "created",
"expiresAt" as "expires",
"revokedAt" as "revoked",
"lastUsedAt" as "last used"
FROM system.web_sessions`)
FROM system.web_sessions AS w`)
return sqlExecCtx.RunQueryAndFormatResults(
context.Background(),
sqlConn, os.Stdout, os.Stdout, stderr, logoutQuery)
sqlConn, os.Stdout, os.Stdout, stderr, authListQuery)
}

var authCmds = []*cobra.Command{
Expand Down
1 change: 1 addition & 0 deletions pkg/cli/democluster/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ go_library(
"@com_github_cockroachdb_errors//oserror",
"@com_github_cockroachdb_logtags//:logtags",
"@com_github_cockroachdb_redact//:redact",
"@com_github_lib_pq//oid",
"@com_github_nightlyone_lockfile//:lockfile",
"@org_golang_x_time//rate",
],
Expand Down
11 changes: 7 additions & 4 deletions pkg/cli/democluster/session_persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/oserror"
"github.com/cockroachdb/redact"
"github.com/lib/pq/oid"
)

// saveWebSessions persists any currently active web session to disk,
Expand Down Expand Up @@ -114,6 +115,7 @@ type webSessionRow struct {
HashedSecret []byte
Username string
ExpiresAt string
UserID oid.Oid
}

// saveWebSessionsInternal saves the sessions for just one tenant to
Expand All @@ -123,7 +125,7 @@ func (c *transientCluster) saveWebSessionsInternal(
) error {
c.infoLog(ctx, "saving sessions")
rows, err := db.QueryContext(ctx, `
SELECT id, "hashedSecret", username, "expiresAt"
SELECT id, "hashedSecret", username, "expiresAt", user_id
FROM system.web_sessions
WHERE "expiresAt" > now()
AND "revokedAt" IS NULL`)
Expand Down Expand Up @@ -151,7 +153,7 @@ AND "revokedAt" IS NULL`)
numSessions := 0
for rows.Next() {
var row webSessionRow
if err := rows.Scan(&row.ID, &row.HashedSecret, &row.Username, &row.ExpiresAt); err != nil {
if err := rows.Scan(&row.ID, &row.HashedSecret, &row.Username, &row.ExpiresAt, &row.UserID); err != nil {
return err
}
j, err := json.Marshal(row)
Expand Down Expand Up @@ -209,12 +211,13 @@ func (c *transientCluster) restoreWebSessionsInternal(
}

if _, err := db.ExecContext(ctx, `
INSERT INTO system.web_sessions(id, "hashedSecret", username, "expiresAt")
VALUES ($1, $2, $3, $4)`,
INSERT INTO system.web_sessions(id, "hashedSecret", username, "expiresAt", user_id)
VALUES ($1, $2, $3, $4, $5)`,
row.ID,
row.HashedSecret,
row.Username,
row.ExpiresAt,
row.UserID,
); err != nil {
return err
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/clusterversion/cockroach_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,14 @@ const (
// user_id column in the system.privileges table has been backfilled.
V23_1SystemPrivilegesTableUserIDColumnBackfilled

// V23_1WebSessionsTableHasUserIDColumn is the version where the
// user_id column has been added to the system.web_sessions table.
V23_1WebSessionsTableHasUserIDColumn

// V23_1WebSessionsTableUserIDColumnBackfilled is the version where the
// user_id column in the system.web_sessions table has been backfilled.
V23_1WebSessionsTableUserIDColumnBackfilled

// *************************************************
// Step (1): Add new versions here.
// Do not add new versions to a patch release.
Expand Down Expand Up @@ -719,6 +727,14 @@ var rawVersionsSingleton = keyedVersions{
Key: V23_1SystemPrivilegesTableUserIDColumnBackfilled,
Version: roachpb.Version{Major: 22, Minor: 2, Internal: 44},
},
{
Key: V23_1WebSessionsTableHasUserIDColumn,
Version: roachpb.Version{Major: 22, Minor: 2, Internal: 46},
},
{
Key: V23_1WebSessionsTableUserIDColumnBackfilled,
Version: roachpb.Version{Major: 22, Minor: 2, Internal: 48},
},

// *************************************************
// Step (2): Add new versions here.
Expand Down
11 changes: 11 additions & 0 deletions pkg/server/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"time"

"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/clusterversion"
"github.com/cockroachdb/cockroach/pkg/roachpb"
"github.com/cockroachdb/cockroach/pkg/security"
"github.com/cockroachdb/cockroach/pkg/security/password"
Expand Down Expand Up @@ -489,6 +490,9 @@ func CreateAuthSecret() (secret, hashedSecret []byte, err error) {
func (s *authenticationServer) newAuthSession(
ctx context.Context, userName username.SQLUsername,
) (int64, []byte, error) {
webSessionsTableHasUserIDCol := s.sqlServer.execCfg.Settings.Version.IsActive(ctx,
clusterversion.V23_1WebSessionsTableHasUserIDColumn)

secret, hashedSecret, err := CreateAuthSecret()
if err != nil {
return 0, nil, err
Expand All @@ -501,6 +505,13 @@ INSERT INTO system.web_sessions ("hashedSecret", username, "expiresAt")
VALUES($1, $2, $3)
RETURNING id
`
if webSessionsTableHasUserIDCol {
insertSessionStmt = `
INSERT INTO system.web_sessions ("hashedSecret", username, "expiresAt", user_id)
VALUES($1, $2, $3, (SELECT user_id FROM system.users WHERE username = $2))
RETURNING id
`
}
var id int64

row, err := s.sqlServer.internalExecutor.QueryRowEx(
Expand Down
Loading

0 comments on commit 72a4b69

Please sign in to comment.