Skip to content

Commit

Permalink
[idp] Adress review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
csweichel committed Mar 3, 2023
1 parent 7218c8b commit 0f47870
Show file tree
Hide file tree
Showing 42 changed files with 934 additions and 542 deletions.
2 changes: 1 addition & 1 deletion components/gitpod-cli/cmd/idp-token.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func idpToken(ctx context.Context, audience []string) (idToken string, err error
if err != nil {
return "", err
}
tkn, err := c.IDP.GetIDToken(ctx, &connect.Request[v1.GetIDTokenRequest]{
tkn, err := c.IdentityProvider.GetIDToken(ctx, &connect.Request[v1.GetIDTokenRequest]{
Msg: &v1.GetIDTokenRequest{
Audience: audience,
WorkspaceId: wsInfo.WorkspaceId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,31 @@ type IDTokenSource interface {
IDToken(ctx context.Context, org string, audience []string, userInfo oidc.UserInfo) (string, error)
}

func NewIDPService(serverConnPool proxy.ServerConnectionPool, source IDTokenSource) *IDPService {
return &IDPService{
func NewIdentityProviderService(serverConnPool proxy.ServerConnectionPool, source IDTokenSource) *IdentityProviderService {
return &IdentityProviderService{
connectionPool: serverConnPool,
idTokenSource: source,
}
}

type IDPService struct {
type IdentityProviderService struct {
connectionPool proxy.ServerConnectionPool
idTokenSource IDTokenSource

v1connect.UnimplementedWorkspacesServiceHandler
}

var _ v1connect.IDPServiceHandler = ((*IDPService)(nil))
var _ v1connect.IdentityProviderServiceHandler = ((*IdentityProviderService)(nil))

// GetIDToken implements v1connect.IDPServiceHandler
func (srv *IDPService) GetIDToken(ctx context.Context, req *connect.Request[v1.GetIDTokenRequest]) (*connect.Response[v1.GetIDTokenResponse], error) {
func (srv *IdentityProviderService) GetIDToken(ctx context.Context, req *connect.Request[v1.GetIDTokenRequest]) (*connect.Response[v1.GetIDTokenResponse], error) {
workspaceID, err := validateWorkspaceID(req.Msg.GetWorkspaceId())
if err != nil {
return nil, err
}

if len(req.Msg.Audience) < 1 {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("must have at least one audience entry"))
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("Must have at least one audience entry"))
}

logger := ctxlogrus.Extract(ctx).WithField("workspace_id", workspaceID)
Expand Down Expand Up @@ -72,6 +72,11 @@ func (srv *IDPService) GetIDToken(ctx context.Context, req *connect.Request[v1.G
return nil, proxy.ConvertError(err)
}

if workspace.Workspace == nil {
logger.WithError(err).Error("Server did not return a workspace.")
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("workspace not found"))
}

subject := workspace.Workspace.ContextURL
userInfo := oidc.NewUserInfo()
userInfo.SetName(user.Name)
Expand Down
203 changes: 203 additions & 0 deletions components/public-api-server/pkg/apiv1/identityprovider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package apiv1

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"

connect "github.com/bufbuild/connect-go"
v1 "github.com/gitpod-io/gitpod/components/public-api/go/experimental/v1"
"github.com/gitpod-io/gitpod/components/public-api/go/experimental/v1/v1connect"
protocol "github.com/gitpod-io/gitpod/gitpod-protocol"
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/golang/mock/gomock"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/sourcegraph/jsonrpc2"
"github.com/zitadel/oidc/pkg/oidc"
)

func TestGetIDToken(t *testing.T) {
const workspaceID = "gitpodio-gitpod-te23l4bjejv"
type Expectation struct {
Error string
Response *v1.GetIDTokenResponse
}
tests := []struct {
Name string
TokenSource IDTokenSource
ServerSetup func(*protocol.MockAPIInterface)
Request *v1.GetIDTokenRequest

Expectation Expectation
}{
{
Name: "happy path",
TokenSource: functionIDTokenSource(func(ctx context.Context, org string, audience []string, userInfo oidc.UserInfo) (string, error) {
return "foobar", nil
}),
ServerSetup: func(ma *protocol.MockAPIInterface) {
ma.EXPECT().GetIDToken(gomock.Any()).MinTimes(1).Return(nil)
ma.EXPECT().GetWorkspace(gomock.Any(), workspaceID).MinTimes(1).Return(
&protocol.WorkspaceInfo{
Workspace: &protocol.Workspace{
ContextURL: "https://github.com/gitpod-io/gitpod",
},
},
nil,
)
ma.EXPECT().GetLoggedInUser(gomock.Any()).Return(
&protocol.User{
Name: "foobar",
},
nil,
)
},
Request: &v1.GetIDTokenRequest{
WorkspaceId: workspaceID,
Audience: []string{"some.audience.com"},
},
Expectation: Expectation{
Response: &v1.GetIDTokenResponse{
Token: "foobar",
},
},
},
{
Name: "workspace not found",
TokenSource: functionIDTokenSource(func(ctx context.Context, org string, audience []string, userInfo oidc.UserInfo) (string, error) {
return "foobar", nil
}),
ServerSetup: func(ma *protocol.MockAPIInterface) {
ma.EXPECT().GetIDToken(gomock.Any()).MinTimes(1).Return(nil)
ma.EXPECT().GetWorkspace(gomock.Any(), workspaceID).MinTimes(1).Return(
nil,
&jsonrpc2.Error{Code: 400, Message: "workspace not found"},
)
},
Request: &v1.GetIDTokenRequest{
WorkspaceId: workspaceID,
Audience: []string{"some.audience.com"},
},
Expectation: Expectation{
Error: connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("workspace not found")).Error(),
},
},
{
Name: "no logged in user",
TokenSource: functionIDTokenSource(func(ctx context.Context, org string, audience []string, userInfo oidc.UserInfo) (string, error) {
return "foobar", nil
}),
ServerSetup: func(ma *protocol.MockAPIInterface) {
ma.EXPECT().GetIDToken(gomock.Any()).MinTimes(1).Return(nil)
ma.EXPECT().GetWorkspace(gomock.Any(), workspaceID).MinTimes(1).Return(
&protocol.WorkspaceInfo{
Workspace: &protocol.Workspace{
ContextURL: "https://github.com/gitpod-io/gitpod",
},
},
nil,
)
ma.EXPECT().GetLoggedInUser(gomock.Any()).Return(
nil,
&jsonrpc2.Error{Code: 401, Message: "User is not authenticated. Please login."},
)
},
Request: &v1.GetIDTokenRequest{
WorkspaceId: workspaceID,
Audience: []string{"some.audience.com"},
},
Expectation: Expectation{
Error: connect.NewError(connect.CodePermissionDenied, fmt.Errorf("User is not authenticated. Please login.")).Error(),
},
},
{
Name: "no audience",
TokenSource: functionIDTokenSource(func(ctx context.Context, org string, audience []string, userInfo oidc.UserInfo) (string, error) {
return "foobar", nil
}),
Request: &v1.GetIDTokenRequest{
WorkspaceId: workspaceID,
},
Expectation: Expectation{
Error: connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("Must have at least one audience entry")).Error(),
},
},
{
Name: "token source error",
TokenSource: functionIDTokenSource(func(ctx context.Context, org string, audience []string, userInfo oidc.UserInfo) (string, error) {
return "", fmt.Errorf("cannot produce token")
}),
ServerSetup: func(ma *protocol.MockAPIInterface) {
ma.EXPECT().GetIDToken(gomock.Any()).MinTimes(1).Return(nil)
ma.EXPECT().GetWorkspace(gomock.Any(), workspaceID).MinTimes(1).Return(
&protocol.WorkspaceInfo{
Workspace: &protocol.Workspace{
ContextURL: "https://github.com/gitpod-io/gitpod",
},
},
nil,
)
ma.EXPECT().GetLoggedInUser(gomock.Any()).Return(
&protocol.User{
Name: "foobar",
},
nil,
)
},
Request: &v1.GetIDTokenRequest{
WorkspaceId: workspaceID,
Audience: []string{"some.audience.com"},
},
Expectation: Expectation{
Error: connect.NewError(connect.CodeInternal, fmt.Errorf("cannot produce token")).Error(),
},
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
serverMock := protocol.NewMockAPIInterface(ctrl)
if test.ServerSetup != nil {
test.ServerSetup(serverMock)
}

svc := NewIdentityProviderService(&FakeServerConnPool{api: serverMock}, test.TokenSource)
_, handler := v1connect.NewIdentityProviderServiceHandler(svc, connect.WithInterceptors(auth.NewServerInterceptor()))
srv := httptest.NewServer(handler)
t.Cleanup(srv.Close)

client := v1connect.NewIdentityProviderServiceClient(http.DefaultClient, srv.URL, connect.WithInterceptors(
auth.NewClientInterceptor("auth-token"),
))

resp, err := client.GetIDToken(context.Background(), &connect.Request[v1.GetIDTokenRequest]{
Msg: test.Request,
})
var act Expectation
if err != nil {
act.Error = err.Error()
} else {
act.Response = resp.Msg
}

if diff := cmp.Diff(test.Expectation, act, cmpopts.IgnoreUnexported(v1.GetIDTokenResponse{})); diff != "" {
t.Errorf("GetIDToken() mismatch (-want +got):\n%s", diff)
}
})
}
}

type functionIDTokenSource func(ctx context.Context, org string, audience []string, userInfo oidc.UserInfo) (string, error)

func (f functionIDTokenSource) IDToken(ctx context.Context, org string, audience []string, userInfo oidc.UserInfo) (string, error) {
return f(ctx, org, audience, userInfo)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package idp
package identityprovider

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package idp
package identityprovider

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package idp
package identityprovider

import (
"context"
Expand Down Expand Up @@ -62,7 +62,7 @@ func (kp *Service) Router() http.Handler {
h.ServeHTTP(w, r)
})
})
mux.Handle(oidc.DiscoveryEndpoint, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mux.Get(oidc.DiscoveryEndpoint, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
keysURL, err := url.JoinPath(kp.IssuerBaseURL, "keys")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -129,13 +129,16 @@ func (kp *Service) Router() http.Handler {
return
}
}))
mux.Handle("/keys", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mux.Get("/keys", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
keys, err := kp.keys.PublicKeys(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write(keys)
_, err = w.Write(keys)
if err != nil {
log.WithError(err).Error("cannot repond to /keys")
}
}))

return mux
Expand Down Expand Up @@ -163,5 +166,10 @@ func (kp *Service) IDToken(ctx context.Context, org string, audience []string, u
return "", err
}

return crypto.Sign(claims, signer)
token, err := crypto.Sign(claims, signer)
if err != nil {
log.WithError(err).Error("cannot sign OIDC ID token")
return "", fmt.Errorf("cannot sign ID token")
}
return token, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package idp
package identityprovider

import (
"context"
Expand Down
Loading

0 comments on commit 0f47870

Please sign in to comment.