diff --git a/lib/auth/session_access_test.go b/lib/auth/session_access_test.go index b368bf6c8637f..c2794c058609c 100644 --- a/lib/auth/session_access_test.go +++ b/lib/auth/session_access_test.go @@ -25,10 +25,10 @@ import ( type startTestCase struct { name string - host types.Role - sessionKind types.SessionKind + host []types.Role + sessionKinds []types.SessionKind participants []SessionAccessContext - expected bool + expected []bool } func successStartTestCase(t *testing.T) startTestCase { @@ -39,7 +39,7 @@ func successStartTestCase(t *testing.T) startTestCase { hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{{ Filter: "contains(user.roles, \"participant\")", - Kinds: []string{string(types.SSHSessionKind)}, + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, Count: 2, OnLeave: types.OnSessionLeaveTerminate, Modes: []string{"peer"}, @@ -47,14 +47,14 @@ func successStartTestCase(t *testing.T) startTestCase { participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ Roles: []string{hostRole.GetName()}, - Kinds: []string{string(types.SSHSessionKind)}, + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, Modes: []string{string("*")}, }}) return startTestCase{ - name: "success", - host: hostRole, - sessionKind: types.SSHSessionKind, + name: "success", + host: []types.Role{hostRole}, + sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind}, participants: []SessionAccessContext{ { Username: "participant", @@ -67,7 +67,7 @@ func successStartTestCase(t *testing.T) startTestCase { Mode: "peer", }, }, - expected: true, + expected: []bool{true, true}, } } @@ -79,21 +79,21 @@ func failCountStartTestCase(t *testing.T) startTestCase { hostRole.SetSessionRequirePolicies([]*types.SessionRequirePolicy{{ Filter: "contains(user.roles, \"participant\")", - Kinds: []string{string(types.SSHSessionKind)}, + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, Count: 3, Modes: []string{"peer"}, }}) participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ Roles: []string{hostRole.GetName()}, - Kinds: []string{string(types.SSHSessionKind)}, + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, Modes: []string{string("*")}, }}) return startTestCase{ - name: "failCount", - host: hostRole, - sessionKind: types.SSHSessionKind, + name: "failCount", + host: []types.Role{hostRole}, + sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind}, participants: []SessionAccessContext{ { Username: "participant", @@ -106,7 +106,7 @@ func failCountStartTestCase(t *testing.T) startTestCase { Mode: "peer", }, }, - expected: false, + expected: []bool{false, false}, } } @@ -122,10 +122,10 @@ func succeedDiscardPolicySetStartTestCase(t *testing.T) startTestCase { }}) return startTestCase{ - name: "succeedDiscardPolicySet", - host: hostRole, - sessionKind: types.SSHSessionKind, - expected: true, + name: "succeedDiscardPolicySet", + host: []types.Role{hostRole}, + sessionKinds: []types.SessionKind{types.SSHSessionKind}, + expected: []bool{true}, } } @@ -149,9 +149,9 @@ func failFilterStartTestCase(t *testing.T) startTestCase { }}) return startTestCase{ - name: "failFilter", - host: hostRole, - sessionKind: types.SSHSessionKind, + name: "failFilter", + host: []types.Role{hostRole}, + sessionKinds: []types.SessionKind{types.SSHSessionKind}, participants: []SessionAccessContext{ { Username: "participant", @@ -164,7 +164,7 @@ func failFilterStartTestCase(t *testing.T) startTestCase { Mode: "peer", }, }, - expected: false, + expected: []bool{false}, } } @@ -178,21 +178,28 @@ func TestSessionAccessStart(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - policy := testCase.host.GetSessionPolicySet() - evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind) - result, _, err := evaluator.FulfilledFor(testCase.participants) - require.NoError(t, err) - require.Equal(t, testCase.expected, result) + var policies []*types.SessionTrackerPolicySet + for _, role := range testCase.host { + policySet := role.GetSessionPolicySet() + policies = append(policies, &policySet) + } + + for i, kind := range testCase.sessionKinds { + evaluator := NewSessionAccessEvaluator(policies, kind) + result, _, err := evaluator.FulfilledFor(testCase.participants) + require.NoError(t, err) + require.Equal(t, testCase.expected[i], result) + } }) } } type joinTestCase struct { - name string - host types.Role - sessionKind types.SessionKind - participant SessionAccessContext - expected bool + name string + host types.Role + sessionKinds []types.SessionKind + participant SessionAccessContext + expected []bool } func successJoinTestCase(t *testing.T) joinTestCase { @@ -203,19 +210,19 @@ func successJoinTestCase(t *testing.T) joinTestCase { participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ Roles: []string{hostRole.GetName()}, - Kinds: []string{string(types.SSHSessionKind)}, + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, Modes: []string{string("*")}, }}) return joinTestCase{ - name: "success", - host: hostRole, - sessionKind: types.SSHSessionKind, + name: "success", + host: hostRole, + sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind}, participant: SessionAccessContext{ Username: "participant", Roles: []types.Role{participantRole}, }, - expected: true, + expected: []bool{true, true}, } } @@ -227,19 +234,19 @@ func successGlobJoinTestCase(t *testing.T) joinTestCase { participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{ Roles: []string{"*"}, - Kinds: []string{string(types.SSHSessionKind)}, + Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)}, Modes: []string{string("*")}, }}) return joinTestCase{ - name: "success", - host: hostRole, - sessionKind: types.SSHSessionKind, + name: "success", + host: hostRole, + sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind}, participant: SessionAccessContext{ Username: "participant", Roles: []types.Role{participantRole}, }, - expected: true, + expected: []bool{true, true}, } } @@ -250,14 +257,14 @@ func failRoleJoinTestCase(t *testing.T) joinTestCase { require.NoError(t, err) return joinTestCase{ - name: "failRole", - host: hostRole, - sessionKind: types.SSHSessionKind, + name: "failRole", + host: hostRole, + sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind}, participant: SessionAccessContext{ Username: "participant", Roles: []types.Role{participantRole}, }, - expected: false, + expected: []bool{false, false}, } } @@ -274,14 +281,32 @@ func failKindJoinTestCase(t *testing.T) joinTestCase { }}) return joinTestCase{ - name: "failKind", - host: hostRole, - sessionKind: types.SSHSessionKind, + name: "failKind", + host: hostRole, + sessionKinds: []types.SessionKind{types.SSHSessionKind}, + participant: SessionAccessContext{ + Username: "participant", + Roles: []types.Role{participantRole}, + }, + expected: []bool{false}, + } +} + +func versionDefaultJoinTestCase(t *testing.T) joinTestCase { + hostRole, err := types.NewRole("host", types.RoleSpecV5{}) + require.NoError(t, err) + participantRole, err := types.NewRoleV3("participant", types.RoleSpecV5{}) + require.NoError(t, err) + + return joinTestCase{ + name: "failVersion", + host: hostRole, + sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind}, participant: SessionAccessContext{ Username: "participant", Roles: []types.Role{participantRole}, }, - expected: false, + expected: []bool{true, false}, } } @@ -291,14 +316,17 @@ func TestSessionAccessJoin(t *testing.T) { successGlobJoinTestCase(t), failRoleJoinTestCase(t), failKindJoinTestCase(t), + versionDefaultJoinTestCase(t), } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - policy := testCase.host.GetSessionPolicySet() - evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind) - result := evaluator.CanJoin(testCase.participant) - require.Equal(t, testCase.expected, len(result) > 0) + for i, kind := range testCase.sessionKinds { + policy := testCase.host.GetSessionPolicySet() + evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, kind) + result := evaluator.CanJoin(testCase.participant) + require.Equal(t, testCase.expected[i], len(result) > 0) + } }) } }