Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(v9) Security fixes #13301

Merged
merged 4 commits into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,14 +746,22 @@ func (c *Client) GetBotUsers(ctx context.Context) ([]types.User, error) {

// GetAccessRequests retrieves a list of all access requests matching the provided filter.
func (c *Client) GetAccessRequests(ctx context.Context, filter types.AccessRequestFilter) ([]types.AccessRequest, error) {
rsp, err := c.grpc.GetAccessRequests(ctx, &filter, c.callOpts...)
stream, err := c.grpc.GetAccessRequestsV2(ctx, &filter, c.callOpts...)
if err != nil {
return nil, trail.FromGRPC(err)
}
reqs := make([]types.AccessRequest, 0, len(rsp.AccessRequests))
for _, req := range rsp.AccessRequests {
var reqs []types.AccessRequest
for {
req, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return nil, trail.FromGRPC(err)
}
reqs = append(reqs, req)
}

return reqs, nil
}

Expand Down
1,286 changes: 677 additions & 609 deletions api/client/proto/authservice.pb.go

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions api/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1718,7 +1718,10 @@ service AuthService {
rpc IsMFARequired(IsMFARequiredRequest) returns (IsMFARequiredResponse);

// GetAccessRequests gets all pending access requests.
// DEPRECATED, DELETE IN 11.0.0: Use GetAccessRequestsV2 instead.
rpc GetAccessRequests(types.AccessRequestFilter) returns (AccessRequests);
// GetAccessRequestsV2 gets all pending access requests.
rpc GetAccessRequestsV2(types.AccessRequestFilter) returns (stream types.AccessRequestV3);
// CreateAccessRequest creates a new access request.
rpc CreateAccessRequest(types.AccessRequestV3) returns (google.protobuf.Empty);
// DeleteAccessRequest deletes an access request.
Expand Down
82 changes: 82 additions & 0 deletions integration/agent_forwarding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright 2022 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package integration

import (
"os"
"os/user"
"runtime"
"syscall"
"testing"

"github.com/gravitational/teleport/lib/teleagent"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)

func TestAgentSocketPermissions(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test can run as a root. The name needs to be changed: https://github.com/gravitational/teleport/blob/master/Makefile#L627-L638

if !isRoot() {
t.Skip("This test will be skipped because tests are not being run as root.")
}

agentServer := teleagent.NewServer(nil)

agentServer.SetTestPermissions(func() {
// ListenUnixSocket should not have its uid changed from root
require.True(t, isRoot())

done := make(chan struct{})

// Start goroutine to attempt privilege escalation during
// permission updates on the unix socket.
//
// For each step of permission updating, it should be impossible
// for the user to unlink/remove the socket. If they can unlink
// or remove the socket, then it could be replaced with a symlink
// which can be used to acquire the permissions of the original socket.
go func() {
defer close(done)

// Update uid to nonroot
_, _, serr := syscall.Syscall(syscall.SYS_SETUID, 1000, 0, 0)
require.Zero(t, serr)
require.True(t, !isRoot())

err := unix.Unlink(agentServer.Path)
require.Error(t, err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably check what error is returned

err = os.Remove(agentServer.Path)
require.Error(t, err)
err = os.Rename(agentServer.Path, agentServer.Path)
require.Error(t, err)
}()
<-done

// ListenUnixSocket should not have its uid changed from root
require.True(t, isRoot())
})

nonRoot, err := user.LookupId("1000")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is system dependent. Not every OS has a user with ID 1000.

require.NoError(t, err)

// lock goroutine to root so that ListenUnixSocket doesn't
// pick up the syscall in the testPermissions func
runtime.LockOSThread()
defer runtime.UnlockOSThread()

err = agentServer.ListenUnixSocket("test", "sock.agent", nonRoot)
require.NoError(t, err)
}
1 change: 1 addition & 0 deletions integration/app_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,7 @@ func (p *pack) makeWebapiRequest(method, endpoint string, payload []byte) (int,
Value: p.webCookie,
})
req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", p.webToken))
req.Header.Add("Content-Type", "application/json")

statusCode, body, err := p.sendRequest(req, nil)
return statusCode, []byte(body), trace.Wrap(err)
Expand Down
20 changes: 4 additions & 16 deletions integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1695,20 +1695,8 @@ func externalSSHCommand(o commandOptions) (*exec.Cmd, error) {
// clobber your system agent.
func createAgent(me *user.User, privateKeyByte []byte, certificateBytes []byte) (*teleagent.AgentServer, string, string, error) {
// create a path to the unix socket
sockDir, err := ioutil.TempDir("", "int-test")
if err != nil {
return nil, "", "", trace.Wrap(err)
}
sockPath := filepath.Join(sockDir, "agent.sock")

uid, err := strconv.Atoi(me.Uid)
if err != nil {
return nil, "", "", trace.Wrap(err)
}
gid, err := strconv.Atoi(me.Gid)
if err != nil {
return nil, "", "", trace.Wrap(err)
}
sockDirName := "int-test"
sockName := "agent.sock"

// transform the key and certificate bytes into something the agent can understand
publicKey, _, _, _, err := ssh.ParseAuthorizedKey(certificateBytes)
Expand Down Expand Up @@ -1737,13 +1725,13 @@ func createAgent(me *user.User, privateKeyByte []byte, certificateBytes []byte)
})

// start the SSH agent
err = teleAgent.ListenUnixSocket(sockPath, uid, gid, 0600)
err = teleAgent.ListenUnixSocket(sockDirName, sockName, me)
if err != nil {
return nil, "", "", trace.Wrap(err)
}
go teleAgent.Serve()

return teleAgent, sockDir, sockPath, nil
return teleAgent, teleAgent.Dir, teleAgent.Path, nil
}

func closeAgent(teleAgent *teleagent.AgentServer, socketDirPath string) error {
Expand Down
1 change: 1 addition & 0 deletions lib/auth/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ func TestUpsertServer(t *testing.T) {
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "http://localhost", bytes.NewReader(body))
req.RemoteAddr = remoteAddr
req.Header.Add("Content-Type", "application/json")

_, err = new(APIServer).upsertServer(s, tt.role, req, httprouter.Params{httprouter.Param{Key: "namespace", Value: apidefaults.Namespace}})
tt.assertErr(t, err)
Expand Down
29 changes: 29 additions & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ func (g *GRPCServer) GetUsers(req *proto.GetUsersRequest, stream proto.AuthServi
return nil
}

// DEPRECATED, DELETE IN 11.0.0: Use GetAccessRequestsV2 instead.
func (g *GRPCServer) GetAccessRequests(ctx context.Context, f *types.AccessRequestFilter) (*proto.AccessRequests, error) {
auth, err := g.authenticate(ctx)
if err != nil {
Expand Down Expand Up @@ -541,6 +542,34 @@ func (g *GRPCServer) GetAccessRequests(ctx context.Context, f *types.AccessReque
}, nil
}

func (g *GRPCServer) GetAccessRequestsV2(f *types.AccessRequestFilter, stream proto.AuthService_GetAccessRequestsV2Server) error {
ctx := stream.Context()
auth, err := g.authenticate(ctx)
if err != nil {
return trace.Wrap(err)
}
var filter types.AccessRequestFilter
if f != nil {
filter = *f
}
reqs, err := auth.ServerWithRoles.GetAccessRequests(ctx, filter)
if err != nil {
return trace.Wrap(err)
}
for _, req := range reqs {
r, ok := req.(*types.AccessRequestV3)
if !ok {
err = trace.BadParameter("unexpected access request type %T", req)
return trace.Wrap(err)
}

if err := stream.Send(r); err != nil {
return trace.Wrap(err)
}
}
return nil
}

func (g *GRPCServer) CreateAccessRequest(ctx context.Context, req *types.AccessRequestV3) (*empty.Empty, error) {
auth, err := g.authenticate(ctx)
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions lib/httplib/httplib.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package httplib
import (
"encoding/json"
"errors"
"mime"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -115,6 +116,18 @@ func WithCSRFProtection(fn HandlerFunc) httprouter.Handle {
// ReadJSON reads HTTP json request and unmarshals it
// into passed interface{} obj
func ReadJSON(r *http.Request, val interface{}) error {
// Check content type to mitigate CSRF attack.
contentType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
log.Warningf("Error parsing media type for reading JSON: %v", err)
return trace.BadParameter("invalid request")
}

if contentType != "application/json" {
log.Warningf("Invalid HTTP request header content-type %q for reading JSON", contentType)
return trace.BadParameter("invalid request")
}

data, err := utils.ReadAtMost(r.Body, teleport.MaxHTTPRequestSize)
if err != nil {
return trace.Wrap(err)
Expand Down
60 changes: 60 additions & 0 deletions lib/httplib/httplib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ limitations under the License.
package httplib

import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/require"
. "gopkg.in/check.v1"
)

Expand Down Expand Up @@ -77,3 +81,59 @@ func (h *testHandler) postSessionChunkNamespace(w http.ResponseWriter, r *http.R
h.capturedID = p.ByName("id")
return "ok", nil
}

func TestReadJSON_ContentType(t *testing.T) {
t.Parallel()

type TestJSON struct {
Name string `json:"name"`
Age int `json:"age"`
}

testCases := []struct {
name string
contentType string
wantErr bool
}{
{
name: "empty value",
contentType: "",
wantErr: true,
},
{
name: "invalid type",
contentType: "multipart/form-data",
wantErr: true,
},
{
name: "just type/subtype",
contentType: "application/json",
},
{
name: "type/subtype with params",
contentType: "application/json; charset=utf-8",
},
}

body := TestJSON{Name: "foo", Age: 60}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
payloadBuf := new(bytes.Buffer)
require.NoError(t, json.NewEncoder(payloadBuf).Encode(body))

httpReq, err := http.NewRequest("", "", payloadBuf)
require.NoError(t, err)
httpReq.Header.Add("Content-Type", tc.contentType)

output := TestJSON{}
err = ReadJSON(httpReq, &output)
if tc.wantErr {
require.True(t, strings.Contains(err.Error(), "invalid request"))
require.Empty(t, output)
} else {
require.NoError(t, err)
require.Equal(t, body, output)
}
})
}
}
8 changes: 8 additions & 0 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import (
"github.com/vulcand/predicate"
)

const maxAccessRequestReasonSize = 4096

// ValidateAccessRequest validates the AccessRequest and sets default values
func ValidateAccessRequest(ar types.AccessRequest) error {
if err := ar.CheckAndSetDefaults(); err != nil {
Expand All @@ -41,6 +43,12 @@ func ValidateAccessRequest(ar types.AccessRequest) error {
if err != nil {
return trace.BadParameter("invalid access request id %q", ar.GetName())
}
if len(ar.GetRequestReason()) > maxAccessRequestReasonSize {
return trace.BadParameter("access request reason is too long, max %v bytes", maxAccessRequestReasonSize)
}
if len(ar.GetResolveReason()) > maxAccessRequestReasonSize {
return trace.BadParameter("access request resolve reason is too long, max %v bytes", maxAccessRequestReasonSize)
}
return nil
}

Expand Down
14 changes: 14 additions & 0 deletions lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,20 @@ func TestReviewThresholds(t *testing.T) {
}
}

// TestMaxLength tests that we reject too large access requests.
func TestMaxLength(t *testing.T) {
req, err := types.NewAccessRequest("some-id", "dave", "dictator", "never")
require.NoError(t, err)

var s []byte
for i := 0; i <= maxAccessRequestReasonSize; i++ {
s = append(s, 'a')
}

req.SetRequestReason(string(s))
require.Error(t, ValidateAccessRequest(req))
}

// TestThresholdReviewFilter verifies basic filter syntax.
func TestThresholdReviewFilter(t *testing.T) {

Expand Down
Loading