-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
361 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// Package cas provides authentication strategies using CAS. | ||
package cas | ||
|
||
import ( | ||
"fmt" | ||
"log/slog" | ||
"net/http" | ||
"net/url" | ||
|
||
"github.com/dexidp/dex/connector" | ||
"github.com/pkg/errors" | ||
"gopkg.in/cas.v2" | ||
) | ||
|
||
// Config holds configuration options for CAS logins. | ||
type Config struct { | ||
Portal string `json:"portal"` | ||
Mapping map[string]string `json:"mapping"` | ||
} | ||
|
||
// Open returns a strategy for logging in through CAS. | ||
func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) { | ||
casURL, err := url.Parse(c.Portal) | ||
if err != nil { | ||
return "", fmt.Errorf("failed to parse casURL %q: %v", c.Portal, err) | ||
} | ||
return &casConnector{ | ||
client: http.DefaultClient, | ||
portal: casURL, | ||
mapping: c.Mapping, | ||
logger: logger.With(slog.Group("connector", "type", "cas", "id", id)), | ||
pathSuffix: "/" + id, | ||
}, nil | ||
} | ||
|
||
var ( | ||
_ connector.CallbackConnector = (*casConnector)(nil) | ||
) | ||
|
||
type casConnector struct { | ||
client *http.Client | ||
portal *url.URL | ||
mapping map[string]string | ||
logger *slog.Logger | ||
pathSuffix string | ||
} | ||
|
||
// LoginURL returns the URL to redirect the user to login with. | ||
func (m *casConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { | ||
u, err := url.Parse(callbackURL) | ||
if err != nil { | ||
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) | ||
} | ||
u.Path += m.pathSuffix | ||
// context = $callbackURL + $m.pathSuffix | ||
v := u.Query() | ||
v.Set("context", u.String()) // without query params | ||
v.Set("state", state) | ||
u.RawQuery = v.Encode() | ||
|
||
loginURL := *m.portal | ||
loginURL.Path += "/login" | ||
// service = $callbackURL + $m.pathSuffix ? state=$state & context=$callbackURL + $m.pathSuffix | ||
q := loginURL.Query() | ||
q.Set("service", u.String()) // service = ...?state=...&context=... | ||
loginURL.RawQuery = q.Encode() | ||
return loginURL.String(), nil | ||
} | ||
|
||
// HandleCallback parses the request and returns the user's identity | ||
func (m *casConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { | ||
|
||
state := r.URL.Query().Get("state") | ||
ticket := r.URL.Query().Get("ticket") | ||
|
||
// service=context = $callbackURL + $m.pathSuffix | ||
serviceURL, err := url.Parse(r.URL.Query().Get("context")) | ||
if err != nil { | ||
return connector.Identity{}, fmt.Errorf("failed to parse serviceURL %q: %v", r.URL.Query().Get("ext"), err) | ||
} | ||
// service = $callbackURL + $m.pathSuffix ? state=$state & context=$callbackURL + $m.pathSuffix | ||
q := serviceURL.Query() | ||
q.Set("context", serviceURL.String()) | ||
q.Set("state", state) | ||
serviceURL.RawQuery = q.Encode() | ||
|
||
user, err := m.getCasUserByTicket(ticket, serviceURL) | ||
if err != nil { | ||
return connector.Identity{}, err | ||
} | ||
m.logger.Info("cas user", "user", user) | ||
return user, nil | ||
} | ||
|
||
func (m *casConnector) getCasUserByTicket(ticket string, serviceURL *url.URL) (id connector.Identity, err error) { | ||
|
||
var ( | ||
resp *cas.AuthenticationResponse | ||
) | ||
|
||
// validate ticket | ||
validator := cas.NewServiceTicketValidator(m.client, m.portal) | ||
if resp, err = validator.ValidateTicket(serviceURL, ticket); err != nil { | ||
err = errors.Wrapf(err, "failed to validate ticket via %q with ticket %q", serviceURL, ticket) | ||
return | ||
} | ||
|
||
// fill identity | ||
id.UserID = resp.User | ||
id.Groups = resp.MemberOf | ||
if len(m.mapping) == 0 { | ||
return | ||
} | ||
if username, ok := m.mapping["username"]; ok { | ||
id.Username = resp.Attributes.Get(username) | ||
if id.Username == "" && username == "userid" { | ||
id.Username = resp.User | ||
} | ||
} | ||
if preferredUsername, ok := m.mapping["preferred_username"]; ok { | ||
id.PreferredUsername = resp.Attributes.Get(preferredUsername) | ||
if id.PreferredUsername == "" && preferredUsername == "userid" { | ||
id.PreferredUsername = resp.User | ||
} | ||
} | ||
if email, ok := m.mapping["email"]; ok { | ||
id.Email = resp.Attributes.Get(email) | ||
if id.Email != "" { | ||
id.EmailVerified = true | ||
} | ||
} | ||
// override memberOf | ||
if groups, ok := m.mapping["groups"]; ok { | ||
id.Groups = resp.Attributes[groups] | ||
} | ||
return | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
package cas | ||
|
||
import ( | ||
"fmt" | ||
"log/slog" | ||
"math/rand" | ||
"net/http" | ||
"net/url" | ||
"os" | ||
"reflect" | ||
"testing" | ||
"time" | ||
|
||
"github.com/dexidp/dex/connector" | ||
"github.com/pkg/errors" | ||
"gopkg.in/yaml.v3" | ||
) | ||
|
||
type tcase struct { | ||
xml string | ||
mapping map[string]string | ||
id connector.Identity | ||
err string | ||
} | ||
|
||
func TestOpen(t *testing.T) { | ||
configSection := ` | ||
portal: https://example.org/cas | ||
mapping: | ||
username: name | ||
preferred_username: username | ||
email: email | ||
groups: affiliation | ||
` | ||
|
||
var config Config | ||
if err := yaml.Unmarshal([]byte(configSection), &config); err != nil { | ||
t.Errorf("parse config: %v", err) | ||
return | ||
} | ||
|
||
conn, err := config.Open("cas", slog.Default()) | ||
if err != nil { | ||
t.Errorf("open connector: %v", err) | ||
return | ||
} | ||
|
||
casConnector, _ := conn.(*casConnector) | ||
if casConnector.portal.String() != config.Portal { | ||
t.Errorf("expected portal %q, got %q", config.Portal, casConnector.portal.String()) | ||
return | ||
} | ||
if !reflect.DeepEqual(casConnector.mapping, config.Mapping) { | ||
t.Errorf("expected mapping %v, got %v", config.Mapping, casConnector.mapping) | ||
return | ||
} | ||
} | ||
|
||
func TestCAS(t *testing.T) { | ||
|
||
callback := "https://dex.example.org/dex/callback" | ||
casURL, _ := url.Parse("https://example.org/cas") | ||
scope := connector.Scopes{Groups: true} | ||
|
||
cases := []tcase{{ | ||
xml: "testdata/cas_success.xml", | ||
mapping: map[string]string{ | ||
"username": "name", | ||
"preferred_username": "username", | ||
"email": "email", | ||
}, | ||
id: connector.Identity{ | ||
UserID: "123456", | ||
Username: "jdoe", | ||
PreferredUsername: "jdoe", | ||
Email: "[email protected]", | ||
EmailVerified: true, | ||
Groups: []string{"A", "B"}, | ||
ConnectorData: nil, | ||
}, | ||
err: "", | ||
}, { | ||
xml: "testdata/cas_success.xml", | ||
mapping: map[string]string{ | ||
"username": "name", | ||
"preferred_username": "username", | ||
"email": "email", | ||
"groups": "affiliation", | ||
}, | ||
id: connector.Identity{ | ||
UserID: "123456", | ||
Username: "jdoe", | ||
PreferredUsername: "jdoe", | ||
Email: "[email protected]", | ||
EmailVerified: true, | ||
Groups: []string{"staff", "faculty"}, | ||
ConnectorData: nil, | ||
}, | ||
err: "", | ||
}, { | ||
xml: "testdata/cas_failure.xml", | ||
mapping: map[string]string{}, | ||
id: connector.Identity{}, | ||
err: "INVALID_TICKET: Ticket ST-1856339-aA5Yuvrxzpv8Tau1cYQ7 not recognized", | ||
}} | ||
|
||
seed := rand.NewSource(time.Now().UnixNano()) | ||
for _, tc := range cases { | ||
|
||
ticket := fmt.Sprintf("ST-%d", seed.Int63()) | ||
state := fmt.Sprintf("%d", seed.Int63()) | ||
|
||
conn := &casConnector{ | ||
portal: casURL, | ||
mapping: tc.mapping, | ||
logger: slog.Default(), | ||
pathSuffix: "/cas", | ||
client: &http.Client{ | ||
Transport: &mockTransport{ | ||
ticket: ticket, | ||
file: tc.xml, | ||
}, | ||
}, | ||
} | ||
|
||
// login | ||
login, err := conn.LoginURL(scope, callback, state) | ||
if err != nil { | ||
t.Errorf("get login url: %v", err) | ||
return | ||
} | ||
loginURL, err := url.Parse(login) | ||
if err != nil { | ||
t.Errorf("parse login url: %v", err) | ||
return | ||
} | ||
|
||
// cas server | ||
queryService := loginURL.Query().Get("service") | ||
serviceURL, err := url.Parse(queryService) | ||
if err != nil { | ||
t.Errorf("parse service url: %v", err) | ||
return | ||
} | ||
serviceQueryState := serviceURL.Query().Get("state") | ||
if serviceQueryState != state { | ||
t.Errorf("state: expected %#v, got %#v", state, serviceQueryState) | ||
return | ||
} | ||
req, _ := http.NewRequest(http.MethodGet, queryService, nil) | ||
q := req.URL.Query() | ||
q.Set("ticket", ticket) | ||
req.URL.RawQuery = q.Encode() | ||
|
||
// validate | ||
id, err := conn.HandleCallback(scope, req) | ||
if err != nil { | ||
if c := errors.Cause(err); c != nil && tc.err != "" && c.Error() == tc.err { | ||
continue | ||
} | ||
t.Errorf("handle callback: %v", err) | ||
return | ||
} | ||
if !reflect.DeepEqual(id, tc.id) { | ||
t.Errorf("identity: expected %#v, got %#v", tc.id, id) | ||
return | ||
} | ||
} | ||
} | ||
|
||
type mockTransport struct { | ||
ticket string | ||
file string | ||
} | ||
|
||
func (f *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||
file, err := os.Open(f.file) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if ticket := req.URL.Query().Get("ticket"); ticket != f.ticket { | ||
return nil, fmt.Errorf("ticket: expected %#v, got %#v", f.ticket, ticket) | ||
} | ||
|
||
return &http.Response{ | ||
StatusCode: http.StatusOK, | ||
Body: file, | ||
Header: http.Header{ | ||
"Content-Type": []string{"text/xml"}, | ||
}, | ||
Request: req, | ||
}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas"> | ||
<cas:authenticationFailure code="INVALID_TICKET"> | ||
Ticket ST-1856339-aA5Yuvrxzpv8Tau1cYQ7 not recognized | ||
</cas:authenticationFailure> | ||
</cas:serviceResponse> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas"> | ||
<cas:authenticationSuccess> | ||
<cas:user>123456</cas:user> | ||
<cas:attributes> | ||
<cas:name>jdoe</cas:name> | ||
<cas:username>jdoe</cas:username> | ||
<cas:email>[email protected]</cas:email> | ||
<cas:affiliation>staff</cas:affiliation> | ||
<cas:affiliation>faculty</cas:affiliation> | ||
<cas:memberOf>A</cas:memberOf> | ||
<cas:memberOf>B</cas:memberOf> | ||
</cas:attributes> | ||
<cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket> | ||
</cas:authenticationSuccess> | ||
</cas:serviceResponse> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.