Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(session): Improve list response time #2160

Merged
merged 9 commits into from
Jun 9, 2022
13 changes: 13 additions & 0 deletions api/sessions/option.gen.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sessions

import (
"fmt"
"strconv"
"strings"

Expand Down Expand Up @@ -86,3 +87,15 @@ func WithRecursive(recurse bool) Option {
o.withRecursive = true
}
}

func WithIncludeTerminated(inIncludeTerminated bool) Option {
return func(o *options) {
o.queryMap["include_terminated"] = fmt.Sprintf("%v", inIncludeTerminated)
}
}

func DefaultIncludeTerminated() Option {
return func(o *options) {
o.postMap["include_terminated"] = nil
}
}
8 changes: 8 additions & 0 deletions internal/api/genapi/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,14 @@ var inputStructs = []*structInfo{
readTemplate,
listTemplate,
},
extraFields: []fieldInfo{
{
Name: "IncludeTerminated",
ProtoName: "include_terminated",
FieldType: "bool",
Query: true,
},
},
pluralResourceName: "sessions",
createResponseTypes: true,
fieldFilter: []string{"private_key"},
Expand Down
31 changes: 31 additions & 0 deletions internal/cmd/commands/sessionscmd/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,46 @@ import (
"github.com/hashicorp/boundary/internal/cmd/base"
)

const (
flagIncludeTerminated = "include-terminated"
)

func init() {
extraActionsFlagsMapFunc = extraActionsFlagsMapFuncImpl
extraFlagsFunc = extraFlagsFuncImpl
extraFlagsHandlingFunc = extraFlagsHandlingFuncImpl
executeExtraActions = executeExtraActionsImpl
}

func extraActionsFlagsMapFuncImpl() map[string][]string {
return map[string][]string{
"cancel": {"id"},
"list": {flagIncludeTerminated},
}
}

type extraCmdVars struct {
flagIncludeTerminated bool
}

func extraFlagsFuncImpl(c *Command, set *base.FlagSets, f *base.FlagSet) {
for _, name := range flagsMap[c.Func] {
switch name {
case flagIncludeTerminated:
f.BoolVar(&base.BoolVar{
Name: flagIncludeTerminated,
Target: &c.flagIncludeTerminated,
Usage: "If set, terminated sessions will be included in the results.",
})
}
}
}

func extraFlagsHandlingFuncImpl(c *Command, _ *base.FlagSets, opts *[]sessions.Option) bool {
if c.flagIncludeTerminated {
*opts = append(*opts, sessions.WithIncludeTerminated(c.flagIncludeTerminated))
}
return true
}

func (c *Command) extraHelpFunc(helpMap map[string]func() string) string {
Expand Down
2 changes: 2 additions & 0 deletions internal/cmd/commands/sessionscmd/sessions.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions internal/cmd/gencli/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,13 +408,14 @@ var inputStructs = map[string][]*cmdInfo{
},
"sessions": {
{
ResourceType: resource.Session.String(),
Pkg: "sessions",
StdActions: []string{"read", "list"},
Container: "Scope",
HasExtraHelpFunc: true,
HasId: true,
VersionedActions: []string{"cancel"},
ResourceType: resource.Session.String(),
Pkg: "sessions",
StdActions: []string{"read", "list"},
Container: "Scope",
HasExtraCommandVars: true,
HasExtraHelpFunc: true,
HasId: true,
VersionedActions: []string{"cancel"},
},
},
"targets": {
Expand Down
2 changes: 1 addition & 1 deletion internal/daemon/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func (c *Controller) registerJobs() error {
if err := pluginhost.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.conf.HostPlugins); err != nil {
return err
}
if err := session.RegisterJobs(c.baseContext, c.scheduler, rw, c.conf.StatusGracePeriodDuration); err != nil {
if err := session.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.conf.StatusGracePeriodDuration); err != nil {
return err
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest)
RootScopeId: req.GetScopeId(),
Type: resource.Session,
Recursive: req.GetRecursive(),
AuthzProtectedEntityProvider: repo,
AuthzProtectedEntityProvider: session.ListForAuthzCheck(repo, session.WithTerminated(req.IncludeTerminated)),
ActionSet: IdActions,
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ func TestList(t *testing.T) {
conn, _ := db.TestSetup(t, "postgres")
wrap := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrap)
ctx := context.Background()

iamRepo := iam.TestRepo(t, conn, wrap)

Expand Down Expand Up @@ -260,16 +261,17 @@ func TestList(t *testing.T) {
hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
tar := tcp.TestTarget(context.Background(), t, conn, pWithSessions.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()}))
tar := tcp.TestTarget(ctx, t, conn, pWithSessions.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()}))

hcOther := static.TestCatalogs(t, conn, pWithOtherSessions.GetPublicId(), 1)[0]
hsOther := static.TestSets(t, conn, hcOther.GetPublicId(), 1)[0]
hOther := static.TestHosts(t, conn, hcOther.GetPublicId(), 1)[0]
static.TestSetMembers(t, conn, hsOther.GetPublicId(), []*static.Host{hOther})
tarOther := tcp.TestTarget(context.Background(), t, conn, pWithOtherSessions.GetPublicId(), "test", target.WithHostSources([]string{hsOther.GetPublicId()}))
tarOther := tcp.TestTarget(ctx, t, conn, pWithOtherSessions.GetPublicId(), "test", target.WithHostSources([]string{hsOther.GetPublicId()}))

var wantSession []*pb.Session
var totalSession []*pb.Session
var wantIncludeTerminatedSessions []*pb.Session
for i := 0; i < 10; i++ {
sess := session.TestSession(t, conn, wrap, session.ComposedOf{
UserId: uId,
Expand All @@ -281,7 +283,7 @@ func TestList(t *testing.T) {
Endpoint: "tcp://127.0.0.1:22",
})

c := session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22, "127.0.0.2", 23, "127.0.0.1")
session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22, "127.0.0.2", 23, "127.0.0.1")

status, states := convertStates(sess.States)

Expand All @@ -304,17 +306,11 @@ func TestList(t *testing.T) {
Certificate: sess.Certificate,
Type: tcp.Subtype.String(),
AuthorizedActions: testAuthorizedActions,
Connections: []*pb.Connection{
{
ClientTcpAddress: c.ClientTcpAddress,
ClientTcpPort: c.ClientTcpPort,
EndpointTcpAddress: c.EndpointTcpAddress,
EndpointTcpPort: c.EndpointTcpPort,
},
},
Connections: []*pb.Connection{}, // connections should not be returned for list
})

totalSession = append(totalSession, wantSession[i])
wantIncludeTerminatedSessions = append(wantIncludeTerminatedSessions, wantSession[i])

sess = session.TestSession(t, conn, wrap, session.ComposedOf{
UserId: uIdOther,
Expand All @@ -326,7 +322,7 @@ func TestList(t *testing.T) {
Endpoint: "tcp://127.0.0.1:22",
})

c = session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22, "127.0.0.2", 23, "127.0.0.1")
session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22, "127.0.0.2", 23, "127.0.0.1")

status, states = convertStates(sess.States)

Expand All @@ -349,15 +345,55 @@ func TestList(t *testing.T) {
Certificate: sess.Certificate,
Type: tcp.Subtype.String(),
AuthorizedActions: testAuthorizedActions,
Connections: []*pb.Connection{
{
ClientTcpAddress: c.ClientTcpAddress,
ClientTcpPort: c.ClientTcpPort,
EndpointTcpAddress: c.EndpointTcpAddress,
EndpointTcpPort: c.EndpointTcpPort,
},
},
Connections: []*pb.Connection{}, // connections should not be returned for list
})
}

{
sess := session.TestSession(t, conn, wrap, session.ComposedOf{
UserId: uId,
HostId: h.GetPublicId(),
TargetId: tar.GetPublicId(),
HostSetId: hs.GetPublicId(),
AuthTokenId: at.GetPublicId(),
ScopeId: pWithSessions.GetPublicId(),
Endpoint: "tcp://127.0.0.1:22",
})

sess, err := sessRepo.CancelSession(ctx, sess.PublicId, sess.Version)
require.NoError(t, err)
terminated, err := sessRepo.TerminateCompletedSessions(ctx)
require.NoError(t, err)
require.Equal(t, 1, terminated)

sess, _, err = sessRepo.LookupSession(ctx, sess.PublicId)
require.NoError(t, err)
status, states := convertStates(sess.States)

expected := &pb.Session{
Id: sess.GetPublicId(),
ScopeId: pWithSessions.GetPublicId(),
AuthTokenId: at.GetPublicId(),
UserId: at.GetIamUserId(),
TargetId: sess.TargetId,
Endpoint: sess.Endpoint,
HostSetId: sess.HostSetId,
HostId: sess.HostId,
Version: sess.Version,
UpdatedTime: sess.UpdateTime.GetTimestamp(),
CreatedTime: sess.CreateTime.GetTimestamp(),
ExpirationTime: sess.ExpirationTime.GetTimestamp(),
Scope: &scopes.ScopeInfo{Id: pWithSessions.GetPublicId(), Type: scope.Project.String(), ParentScopeId: o.GetPublicId()},
Status: status,
States: states,
Certificate: sess.Certificate,
TerminationReason: sess.TerminationReason,
Type: tcp.Subtype.String(),
AuthorizedActions: testAuthorizedActions,
Connections: []*pb.Connection{}, // connections should not be returned for list
}

wantIncludeTerminatedSessions = append(wantIncludeTerminatedSessions, expected)
}

cases := []struct {
Expand All @@ -371,6 +407,11 @@ func TestList(t *testing.T) {
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId()},
res: &pbs.ListSessionsResponse{Items: wantSession},
},
{
name: "List Many Include Terminated",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), IncludeTerminated: true},
res: &pbs.ListSessionsResponse{Items: wantIncludeTerminatedSessions},
},
{
name: "List No Sessions",
req: &pbs.ListSessionsRequest{ScopeId: pNoSessions.GetPublicId()},
Expand Down Expand Up @@ -423,14 +464,8 @@ func TestList(t *testing.T) {
require.Equal(len(tc.res.GetItems()), len(got.GetItems()), "Didn't get expected number of sessions: %v", got.GetItems())
for i, wantSess := range tc.res.GetItems() {
assert.True(got.GetItems()[i].GetExpirationTime().AsTime().Sub(wantSess.GetExpirationTime().AsTime()) < 10*time.Millisecond)
assert.Equal(1, len(wantSess.GetConnections()))
assert.Equal(0, len(wantSess.GetConnections())) // no connections on list
wantSess.ExpirationTime = got.GetItems()[i].GetExpirationTime()
for _, c := range got.GetItems()[i].GetConnections() {
assert.Equal("127.0.0.1", c.ClientTcpAddress)
assert.Equal(uint32(22), c.ClientTcpPort)
assert.Equal("127.0.0.2", c.EndpointTcpAddress)
assert.Equal(uint32(23), c.EndpointTcpPort)
}
}
}
assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "ListSessions(%q) got response %q, wanted %q", tc.req, got, tc.res)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
begin;

-- Replaces the view created in 1/01 to include connections
-- Replaced in 31/0_session_list_no_connections
drop view session_with_state;
create view session_list as
select
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
begin;
delete from session
using session_state
where
session.public_id = session_state.session_id
and
session_state.state = 'terminated'
and
session_state.start_time < wt_sub_seconds_from_now(3600);

analyze;
commit;
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
begin;

-- Replaces the view created in 2/09_session_list_view
drop view session_list;
create view session_list as
select
s.public_id,
s.user_id,
s.host_id,
s.server_id,
s.server_type,
s.target_id,
s.host_set_id,
s.auth_token_id,
s.scope_id,
s.certificate,
s.expiration_time,
s.connection_limit,
s.tofu_token,
s.key_id,
s.termination_reason,
s.version,
s.create_time,
s.update_time,
s.endpoint,
s.worker_filter,
ss.state,
ss.previous_end_time,
ss.start_time,
ss.end_time
from
session s
join
session_state ss
on
s.public_id = ss.session_id;

commit;
7 changes: 7 additions & 0 deletions internal/gen/controller.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -2523,6 +2523,13 @@
"in": "query",
"required": false,
"type": "string"
},
{
"name": "include_terminated",
"description": "Experimental. By default only non-terminated (i.e. pending, active, canceling) are returned.\nSet this option to include terminated sessions as well.",
"in": "query",
"required": false,
"type": "boolean"
}
],
"tags": [
Expand Down
Loading