Skip to content

Commit

Permalink
Add an API endpoint and redirect requests to it
Browse files Browse the repository at this point in the history
Remove workaround for iRODS 4.3.x go-irodsclient support
  • Loading branch information
kjsanger committed Apr 2, 2024
1 parent 42aae1c commit 8e52c06
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 64 deletions.
30 changes: 18 additions & 12 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,28 @@
package server

import (
"fmt"
"net/http"
"path"

"github.com/cyverse/go-irodsclient/irods/types"
"github.com/rs/zerolog"
"net/http"
"path"
)

// HandleHomePage is a handler for the static home page.
func HandleHomePage(logger zerolog.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Trace().Msg("HomeHandler called")

requestPath := r.URL.Path

if r.URL.Path != "/" {
redirect := path.Join(EndpointAPI, requestPath)
logger.Trace().
Str("from", requestPath).
Str("to", redirect).
Msg("Redirecting to API")
http.Redirect(w, r, redirect, http.StatusPermanentRedirect)
}

type customData struct {
URL string
Version string
Expand All @@ -51,18 +60,12 @@ func HandleIRODSGet(logger zerolog.Logger, account *types.IRODSAccount) http.Han
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Trace().Msg("iRODS get handler called")

if !r.URL.Query().Has(HTTPParamPath) {
writeErrorResponse(logger, w, http.StatusBadRequest,
fmt.Sprintf("'%s' parameter is missing", HTTPParamPath))
return
}

var corrID string
if val := r.Context().Value(correlationIDKey); val != nil {
corrID = val.(string)
}

dirtyPath := r.URL.Query().Get(HTTPParamPath)
dirtyPath := r.URL.Path
sanPath := userInputPolicy.Sanitize(dirtyPath)
if sanPath != dirtyPath {
logger.Warn().
Expand All @@ -76,6 +79,9 @@ func HandleIRODSGet(logger zerolog.Logger, account *types.IRODSAccount) http.Han
Str("correlation_id", corrID).
Str("irods", "get").Logger()

getFileRange(rodsLogger, w, r, account, path.Clean(sanPath))
sanPath = path.Clean("/" + sanPath)
logger.Debug().Str("path", sanPath).Msg("Getting iRODS data object")

getFileRange(rodsLogger, w, r, account, sanPath)
})
}
37 changes: 24 additions & 13 deletions server/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"path"
"path/filepath"

Expand Down Expand Up @@ -63,16 +64,21 @@ var _ = Describe("iRODS Get Handler", func() {

When("a non-existent path is given", func() {
var r *http.Request
var err error
var handler http.Handler

BeforeEach(func() {
url := fmt.Sprintf("/get?%s=/no/such/file.txt", server.HTTPParamPath)
r, err = http.NewRequest("GET", url, nil)
handler = http.StripPrefix(server.EndpointAPI,
server.HandleIRODSGet(suiteLogger, account))

objPath := path.Join(workColl, "no", "such", "file.txt")
getURL, err := url.JoinPath(server.EndpointAPI, objPath)
Expect(err).NotTo(HaveOccurred())

r, err = http.NewRequest("GET", getURL, nil)
Expect(err).NotTo(HaveOccurred())
})

It("should return NotFound", func() {
handler := server.HandleIRODSGet(suiteLogger, account)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, r)

Expand All @@ -82,31 +88,38 @@ var _ = Describe("iRODS Get Handler", func() {

When("a valid data object path is given", func() {
var r *http.Request
var err error
var handler http.Handler

BeforeEach(func() {
path := path.Join(workColl, testFile)
url := fmt.Sprintf("/get?%s=%s", server.HTTPParamPath, path)
r, err = http.NewRequest("GET", url, nil)
handler = http.StripPrefix(server.EndpointAPI,
server.HandleIRODSGet(suiteLogger, account))

objPath := path.Join(workColl, testFile)
getURL, err := url.JoinPath(server.EndpointAPI, objPath)
Expect(err).NotTo(HaveOccurred())

r, err = http.NewRequest("GET", getURL, nil)
Expect(err).NotTo(HaveOccurred())
})

When("the data object file does not have public read permissions", func() {
It("should return Forbidden", func() {
handler := server.HandleIRODSGet(suiteLogger, account)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, r)

Expect(rec.Code).To(Equal(http.StatusForbidden))
})
})

When("the data object does have public read permissions", func() {
When("the data object has public read permissions", func() {
var conn *connection.IRODSConnection
var acl []*types.IRODSAccess
var err error

BeforeEach(func() {
handler = http.StripPrefix(server.EndpointAPI,
server.HandleIRODSGet(suiteLogger, account))

conn, err = irodsFS.GetIOConnection()
Expect(err).NotTo(HaveOccurred())

Expand All @@ -130,7 +143,7 @@ var _ = Describe("iRODS Get Handler", func() {

if ac.UserName == server.PublicUser &&
ac.UserZone == testZone &&
server.LevelsEqual(ac.AccessLevel, types.IRODSAccessLevelReadObject) {
ac.AccessLevel == types.IRODSAccessLevelReadObject {
publicAccess = true
}
}
Expand All @@ -142,15 +155,13 @@ var _ = Describe("iRODS Get Handler", func() {
})

It("should return OK", func() {
handler := server.HandleIRODSGet(suiteLogger, account)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, r)

Expect(rec.Code).To(Equal(http.StatusOK))
})

It("should serve the correct body content", func() {
handler := server.HandleIRODSGet(suiteLogger, account)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, r)

Expand Down
28 changes: 13 additions & 15 deletions server/irods.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"time"

"github.com/cyverse/go-irodsclient/fs"
Expand Down Expand Up @@ -60,9 +59,12 @@ func IRODSEnvFilePath() string {
// InitIRODS initialises the iRODS environment by creating a populated auth file if it
// does not already exist. This avoids the need to have `iinit` present on the server
// host.
func InitIRODS(manager *icommands.ICommandsEnvironmentManager, password string) error {
func InitIRODS(logger zerolog.Logger, manager *icommands.ICommandsEnvironmentManager, password string) error {
authFile := manager.GetPasswordFilePath()
if _, err := os.Stat(authFile); err != nil && errors.Is(err, os.ErrNotExist) {
logger.Info().
Str("path", authFile).
Msg("Creating an iRODS auth file because one does not exist")
return icommands.EncodePasswordFile(authFile, password, os.Getuid())
}
return nil
Expand Down Expand Up @@ -121,23 +123,19 @@ func NewIRODSAccount(logger zerolog.Logger,
Str("zone", account.ClientZone).
Str("user", account.ClientUser).
Str("auth_scheme", string(account.AuthenticationScheme)).
Bool("cs_neg_required", account.ClientServerNegotiation).
Str("cs_neg_policy", string(account.CSNegotiationPolicy)).
Str("ca_cert_path", account.SSLConfiguration.CACertificatePath).
Str("ca_cert_file", account.SSLConfiguration.CACertificateFile).
Str("enc_alg", account.SSLConfiguration.EncryptionAlgorithm).
Int("key_size", account.SSLConfiguration.EncryptionKeySize).
Int("salt_size", account.SSLConfiguration.SaltSize).
Int("hash_rounds", account.SSLConfiguration.HashRounds).
Msg("iRODS account created")

return account, nil
}

// LevelsEqual compares two iRODS access levels for equality, normalising for the
// differences between iRODS 4.2.x and 4.3.x. This is a workaround for the issue until
// it's addressed upstream in go-irodsclient.
//
// See https://github.com/cyverse/go-irodsclient/issues/38
func LevelsEqual(a types.IRODSAccessLevelType, b types.IRODSAccessLevelType) bool {
normalise := func(lvl types.IRODSAccessLevelType) string {
return strings.ReplaceAll(string(lvl), " ", "_")
}
return normalise(a) == normalise(b)
}

// isPublicReadable checks if the data object at the given path is readable by the
// public user of the zone hosting the file.
//
Expand Down Expand Up @@ -169,7 +167,7 @@ func isPublicReadable(logger zerolog.Logger, filesystem *fs.FileSystem,

if effectiveUserZone == pathZone &&
ac.UserName == PublicUser &&
LevelsEqual(ac.AccessLevel, types.IRODSAccessLevelReadObject) {
ac.AccessLevel == types.IRODSAccessLevelReadObject {
logger.Trace().
Str("path", path).
Msg("Public read access found")
Expand Down
31 changes: 21 additions & 10 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,30 @@ import (
"net/http"
)

const HTTPHeaderCorrelationID = "X-Correlation-ID"
const HTTPForwardedFor = "X-Forwarded-For"
const (
HeaderCorrelationID = "X-Correlation-ID"
HeaderForwardedFor = "X-Forwarded-For"
)

const HTTPParamPath = "path"
const (
EndpointRoot = "/"
EndPointFavicon = "/favicon.ico"
EndpointAPI = EndpointRoot + "api/v1/"
)

func (server *SqyrrlServer) addRoutes(mux *http.ServeMux) {
logRequests := addRequestLogger(server.logger)
correlate := addCorrelationID(server.logger)
getter := HandleIRODSGet(server.logger, server.account)

// The /get endpoint is used to retrieve files from iRODS
mux.Handle("/get", correlate(logRequests(getter)))
logRequest := AddRequestLogger(server.logger)
correlate := AddCorrelationID(server.logger)
getObject := http.StripPrefix(EndpointAPI, HandleIRODSGet(server.logger, server.account))

// The home page is currently a placeholder static page showing the version
mux.Handle("/", correlate(logRequests(HandleHomePage(server.logger))))
//
// Any requests relative to the root are redirected to the API endpoint
mux.Handle("GET "+EndpointRoot, correlate(logRequest(HandleHomePage(server.logger))))

// There is no favicon, this is just to log requests
mux.Handle("GET "+EndPointFavicon, logRequest(http.NotFoundHandler()))

// The API endpoint is used to access files in iRODS
mux.Handle("GET "+EndpointAPI, correlate(logRequest(getObject)))
}
24 changes: 11 additions & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func NewSqyrrlServer(logger zerolog.Logger, config Config) (*SqyrrlServer, error
return nil, err
}

sublogger := logger.With().
subLogger := logger.With().
Str("hostname", hostname).
Str("component", "server").Logger()

Expand All @@ -131,13 +131,13 @@ func NewSqyrrlServer(logger zerolog.Logger, config Config) (*SqyrrlServer, error
}

if err := manager.SetEnvironmentFilePath(config.EnvFilePath); err != nil {
sublogger.Err(err).
subLogger.Err(err).
Str("path", config.EnvFilePath).
Msg("Failed to set the iRODS environment file path")
return nil, err
}

account, err := NewIRODSAccount(sublogger, manager)
account, err := NewIRODSAccount(subLogger, manager)
if err != nil {
logger.Err(err).Msg("Failed to get an iRODS account")
return nil, err
Expand All @@ -155,7 +155,7 @@ func NewSqyrrlServer(logger zerolog.Logger, config Config) (*SqyrrlServer, error
return serverCtx
}},
serverCtx,
sublogger,
subLogger,
manager,
account,
}
Expand Down Expand Up @@ -269,10 +269,10 @@ func (server *SqyrrlServer) waitAndShutdown() {
logger.Info().Msg("Server shutdown cleanly")
}

// addRequestLogger adds an HTTP request suiteLogger to the handler chain.
// AddRequestLogger adds an HTTP request suiteLogger to the handler chain.
//
// If a correlation ID is present in the request context, it is logged.
func addRequestLogger(logger zerolog.Logger) HandlerChain {
func AddRequestLogger(logger zerolog.Logger) HandlerChain {
return func(next http.Handler) http.Handler {
lh := hlog.NewHandler(logger)

Expand All @@ -290,28 +290,26 @@ func addRequestLogger(logger zerolog.Logger) HandlerChain {
Str("method", r.Method).
Str("url", r.URL.RequestURI()).
Str("remote_addr", r.RemoteAddr).
Str("forwarded_for", r.Header.Get(HTTPForwardedFor)).
Str("forwarded_for", r.Header.Get(HeaderForwardedFor)).
Str("user_agent", r.UserAgent()).
Msg("Request served")
})
return lh(ah(next))
}
}

// Check the "Accept" header

// addCorrelationID adds a correlation ID to the request context and response headers.
func addCorrelationID(logger zerolog.Logger) HandlerChain {
// AddCorrelationID adds a correlation ID to the request context and response headers.
func AddCorrelationID(logger zerolog.Logger) HandlerChain {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var corrID string
if corrID = r.Header.Get(HTTPHeaderCorrelationID); corrID == "" {
if corrID = r.Header.Get(HeaderCorrelationID); corrID == "" {
corrID = xid.New().String()
logger.Trace().
Str("correlation_id", corrID).
Str("url", r.URL.RequestURI()).
Msg("Creating a new correlation ID")
w.Header().Add(HTTPHeaderCorrelationID, corrID)
w.Header().Add(HeaderCorrelationID, corrID)
} else {
logger.Trace().
Str("correlation_id", corrID).
Expand Down
2 changes: 1 addition & 1 deletion server/server_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ var _ = BeforeSuite(func() {
err = manager.SetEnvironmentFilePath(iRODSEnvFile)
Expect(err).NotTo(HaveOccurred())

err = server.InitIRODS(manager, "irods")
err = server.InitIRODS(suiteLogger, manager, "irods")
Expect(err).NotTo(HaveOccurred())

account, err = server.NewIRODSAccount(suiteLogger, manager)
Expand Down

0 comments on commit 8e52c06

Please sign in to comment.