Skip to content

Commit

Permalink
add CAS connector
Browse files Browse the repository at this point in the history
  • Loading branch information
mchtech committed Nov 10, 2024
1 parent b211f55 commit a24b6e1
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 1 deletion.
138 changes: 138 additions & 0 deletions connector/cas/cas.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Package cas provides authentication strategies using CAS.
package cas

import (
"fmt"
"log/slog"
"net/http"
"net/url"

"github.com/dexidp/dex/connector"
"github.com/pkg/errors"
"gopkg.in/cas.v2"
)

// Config holds configuration options for CAS logins.
type Config struct {
Portal string `json:"portal"`
Mapping map[string]string `json:"mapping"`
}

// Open returns a strategy for logging in through CAS.
func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
casURL, err := url.Parse(c.Portal)
if err != nil {
return "", fmt.Errorf("failed to parse casURL %q: %v", c.Portal, err)
}
return &casConnector{
client: http.DefaultClient,
portal: casURL,
mapping: c.Mapping,
logger: logger.With(slog.Group("connector", "type", "cas", "id", id)),
pathSuffix: "/" + id,
}, nil
}

var (
_ connector.CallbackConnector = (*casConnector)(nil)
)

type casConnector struct {
client *http.Client
portal *url.URL
mapping map[string]string
logger *slog.Logger
pathSuffix string
}

// LoginURL returns the URL to redirect the user to login with.
func (m *casConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
u, err := url.Parse(callbackURL)
if err != nil {
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
}
u.Path += m.pathSuffix
// context = $callbackURL + $m.pathSuffix
v := u.Query()
v.Set("context", u.String()) // without query params
v.Set("state", state)
u.RawQuery = v.Encode()

loginURL := *m.portal
loginURL.Path += "/login"
// service = $callbackURL + $m.pathSuffix ? state=$state & context=$callbackURL + $m.pathSuffix
q := loginURL.Query()
q.Set("service", u.String()) // service = ...?state=...&context=...
loginURL.RawQuery = q.Encode()
return loginURL.String(), nil
}

// HandleCallback parses the request and returns the user's identity
func (m *casConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {

state := r.URL.Query().Get("state")
ticket := r.URL.Query().Get("ticket")

// service=context = $callbackURL + $m.pathSuffix
serviceURL, err := url.Parse(r.URL.Query().Get("context"))
if err != nil {
return connector.Identity{}, fmt.Errorf("failed to parse serviceURL %q: %v", r.URL.Query().Get("ext"), err)
}
// service = $callbackURL + $m.pathSuffix ? state=$state & context=$callbackURL + $m.pathSuffix
q := serviceURL.Query()
q.Set("context", serviceURL.String())
q.Set("state", state)
serviceURL.RawQuery = q.Encode()

user, err := m.getCasUserByTicket(ticket, serviceURL)
if err != nil {
return connector.Identity{}, err
}
m.logger.Info("cas user", "user", user)
return user, nil
}

func (m *casConnector) getCasUserByTicket(ticket string, serviceURL *url.URL) (id connector.Identity, err error) {

var (
resp *cas.AuthenticationResponse
)

// validate ticket
validator := cas.NewServiceTicketValidator(m.client, m.portal)
if resp, err = validator.ValidateTicket(serviceURL, ticket); err != nil {
err = errors.Wrapf(err, "failed to validate ticket via %q with ticket %q", serviceURL, ticket)
return
}

// fill identity
id.UserID = resp.User
id.Groups = resp.MemberOf
if len(m.mapping) == 0 {
return
}
if username, ok := m.mapping["username"]; ok {
id.Username = resp.Attributes.Get(username)
if id.Username == "" && username == "userid" {
id.Username = resp.User
}
}
if preferredUsername, ok := m.mapping["preferred_username"]; ok {
id.PreferredUsername = resp.Attributes.Get(preferredUsername)
if id.PreferredUsername == "" && preferredUsername == "userid" {
id.PreferredUsername = resp.User
}
}
if email, ok := m.mapping["email"]; ok {
id.Email = resp.Attributes.Get(email)
if id.Email != "" {
id.EmailVerified = true
}
}
// override memberOf
if groups, ok := m.mapping["groups"]; ok {
id.Groups = resp.Attributes[groups]
}
return

}
194 changes: 194 additions & 0 deletions connector/cas/cas_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package cas

import (
"fmt"
"log/slog"
"math/rand"
"net/http"
"net/url"
"os"
"reflect"
"testing"
"time"

"github.com/dexidp/dex/connector"
"github.com/pkg/errors"
"gopkg.in/yaml.v3"
)

type tcase struct {
xml string
mapping map[string]string
id connector.Identity
err string
}

func TestOpen(t *testing.T) {
configSection := `
portal: https://example.org/cas
mapping:
username: name
preferred_username: username
email: email
groups: affiliation
`

var config Config
if err := yaml.Unmarshal([]byte(configSection), &config); err != nil {
t.Errorf("parse config: %v", err)
return
}

conn, err := config.Open("cas", slog.Default())
if err != nil {
t.Errorf("open connector: %v", err)
return
}

casConnector, _ := conn.(*casConnector)
if casConnector.portal.String() != config.Portal {
t.Errorf("expected portal %q, got %q", config.Portal, casConnector.portal.String())
return
}
if !reflect.DeepEqual(casConnector.mapping, config.Mapping) {
t.Errorf("expected mapping %v, got %v", config.Mapping, casConnector.mapping)
return
}
}

func TestCAS(t *testing.T) {

callback := "https://dex.example.org/dex/callback"
casURL, _ := url.Parse("https://example.org/cas")
scope := connector.Scopes{Groups: true}

cases := []tcase{{
xml: "testdata/cas_success.xml",
mapping: map[string]string{
"username": "name",
"preferred_username": "username",
"email": "email",
},
id: connector.Identity{
UserID: "123456",
Username: "jdoe",
PreferredUsername: "jdoe",
Email: "[email protected]",
EmailVerified: true,
Groups: []string{"A", "B"},
ConnectorData: nil,
},
err: "",
}, {
xml: "testdata/cas_success.xml",
mapping: map[string]string{
"username": "name",
"preferred_username": "username",
"email": "email",
"groups": "affiliation",
},
id: connector.Identity{
UserID: "123456",
Username: "jdoe",
PreferredUsername: "jdoe",
Email: "[email protected]",
EmailVerified: true,
Groups: []string{"staff", "faculty"},
ConnectorData: nil,
},
err: "",
}, {
xml: "testdata/cas_failure.xml",
mapping: map[string]string{},
id: connector.Identity{},
err: "INVALID_TICKET: Ticket ST-1856339-aA5Yuvrxzpv8Tau1cYQ7 not recognized",
}}

seed := rand.NewSource(time.Now().UnixNano())
for _, tc := range cases {

ticket := fmt.Sprintf("ST-%d", seed.Int63())
state := fmt.Sprintf("%d", seed.Int63())

conn := &casConnector{
portal: casURL,
mapping: tc.mapping,
logger: slog.Default(),
pathSuffix: "/cas",
client: &http.Client{
Transport: &mockTransport{
ticket: ticket,
file: tc.xml,
},
},
}

// login
login, err := conn.LoginURL(scope, callback, state)
if err != nil {
t.Errorf("get login url: %v", err)
return
}
loginURL, err := url.Parse(login)
if err != nil {
t.Errorf("parse login url: %v", err)
return
}

// cas server
queryService := loginURL.Query().Get("service")
serviceURL, err := url.Parse(queryService)
if err != nil {
t.Errorf("parse service url: %v", err)
return
}
serviceQueryState := serviceURL.Query().Get("state")
if serviceQueryState != state {
t.Errorf("state: expected %#v, got %#v", state, serviceQueryState)
return
}
req, _ := http.NewRequest(http.MethodGet, queryService, nil)
q := req.URL.Query()
q.Set("ticket", ticket)
req.URL.RawQuery = q.Encode()

// validate
id, err := conn.HandleCallback(scope, req)
if err != nil {
if c := errors.Cause(err); c != nil && tc.err != "" && c.Error() == tc.err {
continue
}
t.Errorf("handle callback: %v", err)
return
}
if !reflect.DeepEqual(id, tc.id) {
t.Errorf("identity: expected %#v, got %#v", tc.id, id)
return
}
}
}

type mockTransport struct {
ticket string
file string
}

func (f *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
file, err := os.Open(f.file)
if err != nil {
return nil, err
}

if ticket := req.URL.Query().Get("ticket"); ticket != f.ticket {
return nil, fmt.Errorf("ticket: expected %#v, got %#v", f.ticket, ticket)
}

return &http.Response{
StatusCode: http.StatusOK,
Body: file,
Header: http.Header{
"Content-Type": []string{"text/xml"},
},
Request: req,
}, nil
}
5 changes: 5 additions & 0 deletions connector/cas/testdata/cas_failure.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
<cas:authenticationFailure code="INVALID_TICKET">
Ticket ST-1856339-aA5Yuvrxzpv8Tau1cYQ7 not recognized
</cas:authenticationFailure>
</cas:serviceResponse>
15 changes: 15 additions & 0 deletions connector/cas/testdata/cas_success.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
<cas:authenticationSuccess>
<cas:user>123456</cas:user>
<cas:attributes>
<cas:name>jdoe</cas:name>
<cas:username>jdoe</cas:username>
<cas:email>[email protected]</cas:email>
<cas:affiliation>staff</cas:affiliation>
<cas:affiliation>faculty</cas:affiliation>
<cas:memberOf>A</cas:memberOf>
<cas:memberOf>B</cas:memberOf>
</cas:attributes>
<cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
</cas:authenticationSuccess>
</cas:serviceResponse>
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ require (
google.golang.org/api v0.203.0
google.golang.org/grpc v1.67.1
google.golang.org/protobuf v1.35.1
gopkg.in/cas.v2 v2.2.2
gopkg.in/yaml.v3 v3.0.1
)

require (
Expand All @@ -63,6 +65,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/inflect v0.19.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/glog v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/go-cmp v0.6.0 // indirect
Expand Down Expand Up @@ -101,7 +104,6 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace github.com/dexidp/dex/api/v2 => ./api/v2
Loading

0 comments on commit a24b6e1

Please sign in to comment.