Skip to content

Commit

Permalink
(v9) Security fixes (#13301)
Browse files Browse the repository at this point in the history
* Rework Agent Forwarding Permissions logic to prevent escalation attacks.

* Add CSRF mitigations

This commit includes two fixes:

1. Enforce an application/json Content-Type server-side.
2. When checking the bearer token, verify that the user
   associated with the token matches the user associated
   with the cookie.

* Fix TEL-Q122-13: Access Requests Denial Of Service Via Request Reason (#125) (#127)

* Ignore input when data flow is off in TermManager

When data flow is disabled in TermManager (at the beginning or when TermManager.Off was called) we should ignore all input we receive (currently we buffer it)

Co-authored-by: joerger <[email protected]>
Co-authored-by: Lisa Kim <[email protected]>
Co-authored-by: Joel <[email protected]>
Co-authored-by: Przemko Robakowski <[email protected]>
  • Loading branch information
5 people authored Jun 8, 2022
1 parent 4d47169 commit 82c446f
Show file tree
Hide file tree
Showing 21 changed files with 1,084 additions and 699 deletions.
14 changes: 11 additions & 3 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,14 +746,22 @@ func (c *Client) GetBotUsers(ctx context.Context) ([]types.User, error) {

// GetAccessRequests retrieves a list of all access requests matching the provided filter.
func (c *Client) GetAccessRequests(ctx context.Context, filter types.AccessRequestFilter) ([]types.AccessRequest, error) {
rsp, err := c.grpc.GetAccessRequests(ctx, &filter, c.callOpts...)
stream, err := c.grpc.GetAccessRequestsV2(ctx, &filter, c.callOpts...)
if err != nil {
return nil, trail.FromGRPC(err)
}
reqs := make([]types.AccessRequest, 0, len(rsp.AccessRequests))
for _, req := range rsp.AccessRequests {
var reqs []types.AccessRequest
for {
req, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return nil, trail.FromGRPC(err)
}
reqs = append(reqs, req)
}

return reqs, nil
}

Expand Down
1,286 changes: 677 additions & 609 deletions api/client/proto/authservice.pb.go

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions api/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1718,7 +1718,10 @@ service AuthService {
rpc IsMFARequired(IsMFARequiredRequest) returns (IsMFARequiredResponse);

// GetAccessRequests gets all pending access requests.
// DEPRECATED, DELETE IN 11.0.0: Use GetAccessRequestsV2 instead.
rpc GetAccessRequests(types.AccessRequestFilter) returns (AccessRequests);
// GetAccessRequestsV2 gets all pending access requests.
rpc GetAccessRequestsV2(types.AccessRequestFilter) returns (stream types.AccessRequestV3);
// CreateAccessRequest creates a new access request.
rpc CreateAccessRequest(types.AccessRequestV3) returns (google.protobuf.Empty);
// DeleteAccessRequest deletes an access request.
Expand Down
82 changes: 82 additions & 0 deletions integration/agent_forwarding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package integration

import (
"os"
"os/user"
"runtime"
"syscall"
"testing"

"github.com/gravitational/teleport/lib/teleagent"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)

func TestAgentSocketPermissions(t *testing.T) {
if !isRoot() {
t.Skip("This test will be skipped because tests are not being run as root.")
}

agentServer := teleagent.NewServer(nil)

agentServer.SetTestPermissions(func() {
// ListenUnixSocket should not have its uid changed from root
require.True(t, isRoot())

done := make(chan struct{})

// Start goroutine to attempt privilege escalation during
// permission updates on the unix socket.
//
// For each step of permission updating, it should be impossible
// for the user to unlink/remove the socket. If they can unlink
// or remove the socket, then it could be replaced with a symlink
// which can be used to acquire the permissions of the original socket.
go func() {
defer close(done)

// Update uid to nonroot
_, _, serr := syscall.Syscall(syscall.SYS_SETUID, 1000, 0, 0)
require.Zero(t, serr)
require.True(t, !isRoot())

err := unix.Unlink(agentServer.Path)
require.Error(t, err)
err = os.Remove(agentServer.Path)
require.Error(t, err)
err = os.Rename(agentServer.Path, agentServer.Path)
require.Error(t, err)
}()
<-done

// ListenUnixSocket should not have its uid changed from root
require.True(t, isRoot())
})

nonRoot, err := user.LookupId("1000")
require.NoError(t, err)

// lock goroutine to root so that ListenUnixSocket doesn't
// pick up the syscall in the testPermissions func
runtime.LockOSThread()
defer runtime.UnlockOSThread()

err = agentServer.ListenUnixSocket("test", "sock.agent", nonRoot)
require.NoError(t, err)
}
1 change: 1 addition & 0 deletions integration/app_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,7 @@ func (p *pack) makeWebapiRequest(method, endpoint string, payload []byte) (int,
Value: p.webCookie,
})
req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", p.webToken))
req.Header.Add("Content-Type", "application/json")

statusCode, body, err := p.sendRequest(req, nil)
return statusCode, []byte(body), trace.Wrap(err)
Expand Down
20 changes: 4 additions & 16 deletions integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1695,20 +1695,8 @@ func externalSSHCommand(o commandOptions) (*exec.Cmd, error) {
// clobber your system agent.
func createAgent(me *user.User, privateKeyByte []byte, certificateBytes []byte) (*teleagent.AgentServer, string, string, error) {
// create a path to the unix socket
sockDir, err := ioutil.TempDir("", "int-test")
if err != nil {
return nil, "", "", trace.Wrap(err)
}
sockPath := filepath.Join(sockDir, "agent.sock")

uid, err := strconv.Atoi(me.Uid)
if err != nil {
return nil, "", "", trace.Wrap(err)
}
gid, err := strconv.Atoi(me.Gid)
if err != nil {
return nil, "", "", trace.Wrap(err)
}
sockDirName := "int-test"
sockName := "agent.sock"

// transform the key and certificate bytes into something the agent can understand
publicKey, _, _, _, err := ssh.ParseAuthorizedKey(certificateBytes)
Expand Down Expand Up @@ -1737,13 +1725,13 @@ func createAgent(me *user.User, privateKeyByte []byte, certificateBytes []byte)
})

// start the SSH agent
err = teleAgent.ListenUnixSocket(sockPath, uid, gid, 0600)
err = teleAgent.ListenUnixSocket(sockDirName, sockName, me)
if err != nil {
return nil, "", "", trace.Wrap(err)
}
go teleAgent.Serve()

return teleAgent, sockDir, sockPath, nil
return teleAgent, teleAgent.Dir, teleAgent.Path, nil
}

func closeAgent(teleAgent *teleagent.AgentServer, socketDirPath string) error {
Expand Down
1 change: 1 addition & 0 deletions lib/auth/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ func TestUpsertServer(t *testing.T) {
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "http://localhost", bytes.NewReader(body))
req.RemoteAddr = remoteAddr
req.Header.Add("Content-Type", "application/json")

_, err = new(APIServer).upsertServer(s, tt.role, req, httprouter.Params{httprouter.Param{Key: "namespace", Value: apidefaults.Namespace}})
tt.assertErr(t, err)
Expand Down
29 changes: 29 additions & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ func (g *GRPCServer) GetUsers(req *proto.GetUsersRequest, stream proto.AuthServi
return nil
}

// DEPRECATED, DELETE IN 11.0.0: Use GetAccessRequestsV2 instead.
func (g *GRPCServer) GetAccessRequests(ctx context.Context, f *types.AccessRequestFilter) (*proto.AccessRequests, error) {
auth, err := g.authenticate(ctx)
if err != nil {
Expand Down Expand Up @@ -541,6 +542,34 @@ func (g *GRPCServer) GetAccessRequests(ctx context.Context, f *types.AccessReque
}, nil
}

func (g *GRPCServer) GetAccessRequestsV2(f *types.AccessRequestFilter, stream proto.AuthService_GetAccessRequestsV2Server) error {
ctx := stream.Context()
auth, err := g.authenticate(ctx)
if err != nil {
return trace.Wrap(err)
}
var filter types.AccessRequestFilter
if f != nil {
filter = *f
}
reqs, err := auth.ServerWithRoles.GetAccessRequests(ctx, filter)
if err != nil {
return trace.Wrap(err)
}
for _, req := range reqs {
r, ok := req.(*types.AccessRequestV3)
if !ok {
err = trace.BadParameter("unexpected access request type %T", req)
return trace.Wrap(err)
}

if err := stream.Send(r); err != nil {
return trace.Wrap(err)
}
}
return nil
}

func (g *GRPCServer) CreateAccessRequest(ctx context.Context, req *types.AccessRequestV3) (*empty.Empty, error) {
auth, err := g.authenticate(ctx)
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions lib/httplib/httplib.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package httplib
import (
"encoding/json"
"errors"
"mime"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -115,6 +116,18 @@ func WithCSRFProtection(fn HandlerFunc) httprouter.Handle {
// ReadJSON reads HTTP json request and unmarshals it
// into passed interface{} obj
func ReadJSON(r *http.Request, val interface{}) error {
// Check content type to mitigate CSRF attack.
contentType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
log.Warningf("Error parsing media type for reading JSON: %v", err)
return trace.BadParameter("invalid request")
}

if contentType != "application/json" {
log.Warningf("Invalid HTTP request header content-type %q for reading JSON", contentType)
return trace.BadParameter("invalid request")
}

data, err := utils.ReadAtMost(r.Body, teleport.MaxHTTPRequestSize)
if err != nil {
return trace.Wrap(err)
Expand Down
60 changes: 60 additions & 0 deletions lib/httplib/httplib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ limitations under the License.
package httplib

import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/require"
. "gopkg.in/check.v1"
)

Expand Down Expand Up @@ -77,3 +81,59 @@ func (h *testHandler) postSessionChunkNamespace(w http.ResponseWriter, r *http.R
h.capturedID = p.ByName("id")
return "ok", nil
}

func TestReadJSON_ContentType(t *testing.T) {
t.Parallel()

type TestJSON struct {
Name string `json:"name"`
Age int `json:"age"`
}

testCases := []struct {
name string
contentType string
wantErr bool
}{
{
name: "empty value",
contentType: "",
wantErr: true,
},
{
name: "invalid type",
contentType: "multipart/form-data",
wantErr: true,
},
{
name: "just type/subtype",
contentType: "application/json",
},
{
name: "type/subtype with params",
contentType: "application/json; charset=utf-8",
},
}

body := TestJSON{Name: "foo", Age: 60}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
payloadBuf := new(bytes.Buffer)
require.NoError(t, json.NewEncoder(payloadBuf).Encode(body))

httpReq, err := http.NewRequest("", "", payloadBuf)
require.NoError(t, err)
httpReq.Header.Add("Content-Type", tc.contentType)

output := TestJSON{}
err = ReadJSON(httpReq, &output)
if tc.wantErr {
require.True(t, strings.Contains(err.Error(), "invalid request"))
require.Empty(t, output)
} else {
require.NoError(t, err)
require.Equal(t, body, output)
}
})
}
}
8 changes: 8 additions & 0 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import (
"github.com/vulcand/predicate"
)

const maxAccessRequestReasonSize = 4096

// ValidateAccessRequest validates the AccessRequest and sets default values
func ValidateAccessRequest(ar types.AccessRequest) error {
if err := ar.CheckAndSetDefaults(); err != nil {
Expand All @@ -41,6 +43,12 @@ func ValidateAccessRequest(ar types.AccessRequest) error {
if err != nil {
return trace.BadParameter("invalid access request id %q", ar.GetName())
}
if len(ar.GetRequestReason()) > maxAccessRequestReasonSize {
return trace.BadParameter("access request reason is too long, max %v bytes", maxAccessRequestReasonSize)
}
if len(ar.GetResolveReason()) > maxAccessRequestReasonSize {
return trace.BadParameter("access request resolve reason is too long, max %v bytes", maxAccessRequestReasonSize)
}
return nil
}

Expand Down
14 changes: 14 additions & 0 deletions lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,20 @@ func TestReviewThresholds(t *testing.T) {
}
}

// TestMaxLength tests that we reject too large access requests.
func TestMaxLength(t *testing.T) {
req, err := types.NewAccessRequest("some-id", "dave", "dictator", "never")
require.NoError(t, err)

var s []byte
for i := 0; i <= maxAccessRequestReasonSize; i++ {
s = append(s, 'a')
}

req.SetRequestReason(string(s))
require.Error(t, ValidateAccessRequest(req))
}

// TestThresholdReviewFilter verifies basic filter syntax.
func TestThresholdReviewFilter(t *testing.T) {

Expand Down
Loading

0 comments on commit 82c446f

Please sign in to comment.