diff --git a/server/handlers.go b/server/handlers.go index 9f4f954..5863467 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -18,12 +18,10 @@ package server import ( - "fmt" - "net/http" - "path" - "github.com/cyverse/go-irodsclient/irods/types" "github.com/rs/zerolog" + "net/http" + "path" ) // HandleHomePage is a handler for the static home page. @@ -31,6 +29,17 @@ func HandleHomePage(logger zerolog.Logger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Trace().Msg("HomeHandler called") + requestPath := r.URL.Path + + if r.URL.Path != "/" { + redirect := path.Join(EndpointAPI, requestPath) + logger.Trace(). + Str("from", requestPath). + Str("to", redirect). + Msg("Redirecting to API") + http.Redirect(w, r, redirect, http.StatusPermanentRedirect) + } + type customData struct { URL string Version string @@ -51,18 +60,12 @@ func HandleIRODSGet(logger zerolog.Logger, account *types.IRODSAccount) http.Han return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Trace().Msg("iRODS get handler called") - if !r.URL.Query().Has(HTTPParamPath) { - writeErrorResponse(logger, w, http.StatusBadRequest, - fmt.Sprintf("'%s' parameter is missing", HTTPParamPath)) - return - } - var corrID string if val := r.Context().Value(correlationIDKey); val != nil { corrID = val.(string) } - dirtyPath := r.URL.Query().Get(HTTPParamPath) + dirtyPath := r.URL.Path sanPath := userInputPolicy.Sanitize(dirtyPath) if sanPath != dirtyPath { logger.Warn(). @@ -76,6 +79,9 @@ func HandleIRODSGet(logger zerolog.Logger, account *types.IRODSAccount) http.Han Str("correlation_id", corrID). Str("irods", "get").Logger() - getFileRange(rodsLogger, w, r, account, path.Clean(sanPath)) + sanPath = path.Clean("/" + sanPath) + logger.Debug().Str("path", sanPath).Msg("Getting iRODS data object") + + getFileRange(rodsLogger, w, r, account, sanPath) }) } diff --git a/server/handlers_test.go b/server/handlers_test.go index b41a0eb..5c0e8c2 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -21,6 +21,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "path" "path/filepath" @@ -63,16 +64,21 @@ var _ = Describe("iRODS Get Handler", func() { When("a non-existent path is given", func() { var r *http.Request - var err error + var handler http.Handler BeforeEach(func() { - url := fmt.Sprintf("/get?%s=/no/such/file.txt", server.HTTPParamPath) - r, err = http.NewRequest("GET", url, nil) + handler = http.StripPrefix(server.EndpointAPI, + server.HandleIRODSGet(suiteLogger, account)) + + objPath := path.Join(workColl, "no", "such", "file.txt") + getURL, err := url.JoinPath(server.EndpointAPI, objPath) + Expect(err).NotTo(HaveOccurred()) + + r, err = http.NewRequest("GET", getURL, nil) Expect(err).NotTo(HaveOccurred()) }) It("should return NotFound", func() { - handler := server.HandleIRODSGet(suiteLogger, account) rec := httptest.NewRecorder() handler.ServeHTTP(rec, r) @@ -82,18 +88,22 @@ var _ = Describe("iRODS Get Handler", func() { When("a valid data object path is given", func() { var r *http.Request - var err error + var handler http.Handler BeforeEach(func() { - path := path.Join(workColl, testFile) - url := fmt.Sprintf("/get?%s=%s", server.HTTPParamPath, path) - r, err = http.NewRequest("GET", url, nil) + handler = http.StripPrefix(server.EndpointAPI, + server.HandleIRODSGet(suiteLogger, account)) + + objPath := path.Join(workColl, testFile) + getURL, err := url.JoinPath(server.EndpointAPI, objPath) + Expect(err).NotTo(HaveOccurred()) + + r, err = http.NewRequest("GET", getURL, nil) Expect(err).NotTo(HaveOccurred()) }) When("the data object file does not have public read permissions", func() { It("should return Forbidden", func() { - handler := server.HandleIRODSGet(suiteLogger, account) rec := httptest.NewRecorder() handler.ServeHTTP(rec, r) @@ -101,12 +111,15 @@ var _ = Describe("iRODS Get Handler", func() { }) }) - When("the data object does have public read permissions", func() { + When("the data object has public read permissions", func() { var conn *connection.IRODSConnection var acl []*types.IRODSAccess var err error BeforeEach(func() { + handler = http.StripPrefix(server.EndpointAPI, + server.HandleIRODSGet(suiteLogger, account)) + conn, err = irodsFS.GetIOConnection() Expect(err).NotTo(HaveOccurred()) @@ -130,7 +143,7 @@ var _ = Describe("iRODS Get Handler", func() { if ac.UserName == server.PublicUser && ac.UserZone == testZone && - server.LevelsEqual(ac.AccessLevel, types.IRODSAccessLevelReadObject) { + ac.AccessLevel == types.IRODSAccessLevelReadObject { publicAccess = true } } @@ -142,7 +155,6 @@ var _ = Describe("iRODS Get Handler", func() { }) It("should return OK", func() { - handler := server.HandleIRODSGet(suiteLogger, account) rec := httptest.NewRecorder() handler.ServeHTTP(rec, r) @@ -150,7 +162,6 @@ var _ = Describe("iRODS Get Handler", func() { }) It("should serve the correct body content", func() { - handler := server.HandleIRODSGet(suiteLogger, account) rec := httptest.NewRecorder() handler.ServeHTTP(rec, r) diff --git a/server/irods.go b/server/irods.go index a750fb9..4483639 100644 --- a/server/irods.go +++ b/server/irods.go @@ -22,7 +22,6 @@ import ( "net/http" "os" "path/filepath" - "strings" "time" "github.com/cyverse/go-irodsclient/fs" @@ -60,9 +59,12 @@ func IRODSEnvFilePath() string { // InitIRODS initialises the iRODS environment by creating a populated auth file if it // does not already exist. This avoids the need to have `iinit` present on the server // host. -func InitIRODS(manager *icommands.ICommandsEnvironmentManager, password string) error { +func InitIRODS(logger zerolog.Logger, manager *icommands.ICommandsEnvironmentManager, password string) error { authFile := manager.GetPasswordFilePath() if _, err := os.Stat(authFile); err != nil && errors.Is(err, os.ErrNotExist) { + logger.Info(). + Str("path", authFile). + Msg("Creating an iRODS auth file because one does not exist") return icommands.EncodePasswordFile(authFile, password, os.Getuid()) } return nil @@ -121,23 +123,19 @@ func NewIRODSAccount(logger zerolog.Logger, Str("zone", account.ClientZone). Str("user", account.ClientUser). Str("auth_scheme", string(account.AuthenticationScheme)). + Bool("cs_neg_required", account.ClientServerNegotiation). + Str("cs_neg_policy", string(account.CSNegotiationPolicy)). + Str("ca_cert_path", account.SSLConfiguration.CACertificatePath). + Str("ca_cert_file", account.SSLConfiguration.CACertificateFile). + Str("enc_alg", account.SSLConfiguration.EncryptionAlgorithm). + Int("key_size", account.SSLConfiguration.EncryptionKeySize). + Int("salt_size", account.SSLConfiguration.SaltSize). + Int("hash_rounds", account.SSLConfiguration.HashRounds). Msg("iRODS account created") return account, nil } -// LevelsEqual compares two iRODS access levels for equality, normalising for the -// differences between iRODS 4.2.x and 4.3.x. This is a workaround for the issue until -// it's addressed upstream in go-irodsclient. -// -// See https://github.com/cyverse/go-irodsclient/issues/38 -func LevelsEqual(a types.IRODSAccessLevelType, b types.IRODSAccessLevelType) bool { - normalise := func(lvl types.IRODSAccessLevelType) string { - return strings.ReplaceAll(string(lvl), " ", "_") - } - return normalise(a) == normalise(b) -} - // isPublicReadable checks if the data object at the given path is readable by the // public user of the zone hosting the file. // @@ -169,7 +167,7 @@ func isPublicReadable(logger zerolog.Logger, filesystem *fs.FileSystem, if effectiveUserZone == pathZone && ac.UserName == PublicUser && - LevelsEqual(ac.AccessLevel, types.IRODSAccessLevelReadObject) { + ac.AccessLevel == types.IRODSAccessLevelReadObject { logger.Trace(). Str("path", path). Msg("Public read access found") diff --git a/server/routes.go b/server/routes.go index 28dabf1..8aa8e01 100644 --- a/server/routes.go +++ b/server/routes.go @@ -21,19 +21,30 @@ import ( "net/http" ) -const HTTPHeaderCorrelationID = "X-Correlation-ID" -const HTTPForwardedFor = "X-Forwarded-For" +const ( + HeaderCorrelationID = "X-Correlation-ID" + HeaderForwardedFor = "X-Forwarded-For" +) -const HTTPParamPath = "path" +const ( + EndpointRoot = "/" + EndPointFavicon = "/favicon.ico" + EndpointAPI = EndpointRoot + "api/v1/" +) func (server *SqyrrlServer) addRoutes(mux *http.ServeMux) { - logRequests := addRequestLogger(server.logger) - correlate := addCorrelationID(server.logger) - getter := HandleIRODSGet(server.logger, server.account) - - // The /get endpoint is used to retrieve files from iRODS - mux.Handle("/get", correlate(logRequests(getter))) + logRequest := AddRequestLogger(server.logger) + correlate := AddCorrelationID(server.logger) + getObject := http.StripPrefix(EndpointAPI, HandleIRODSGet(server.logger, server.account)) // The home page is currently a placeholder static page showing the version - mux.Handle("/", correlate(logRequests(HandleHomePage(server.logger)))) + // + // Any requests relative to the root are redirected to the API endpoint + mux.Handle("GET "+EndpointRoot, correlate(logRequest(HandleHomePage(server.logger)))) + + // There is no favicon, this is just to log requests + mux.Handle("GET "+EndPointFavicon, logRequest(http.NotFoundHandler())) + + // The API endpoint is used to access files in iRODS + mux.Handle("GET "+EndpointAPI, correlate(logRequest(getObject))) } diff --git a/server/server.go b/server/server.go index a0d1166..14c51eb 100644 --- a/server/server.go +++ b/server/server.go @@ -120,7 +120,7 @@ func NewSqyrrlServer(logger zerolog.Logger, config Config) (*SqyrrlServer, error return nil, err } - sublogger := logger.With(). + subLogger := logger.With(). Str("hostname", hostname). Str("component", "server").Logger() @@ -131,13 +131,13 @@ func NewSqyrrlServer(logger zerolog.Logger, config Config) (*SqyrrlServer, error } if err := manager.SetEnvironmentFilePath(config.EnvFilePath); err != nil { - sublogger.Err(err). + subLogger.Err(err). Str("path", config.EnvFilePath). Msg("Failed to set the iRODS environment file path") return nil, err } - account, err := NewIRODSAccount(sublogger, manager) + account, err := NewIRODSAccount(subLogger, manager) if err != nil { logger.Err(err).Msg("Failed to get an iRODS account") return nil, err @@ -155,7 +155,7 @@ func NewSqyrrlServer(logger zerolog.Logger, config Config) (*SqyrrlServer, error return serverCtx }}, serverCtx, - sublogger, + subLogger, manager, account, } @@ -269,10 +269,10 @@ func (server *SqyrrlServer) waitAndShutdown() { logger.Info().Msg("Server shutdown cleanly") } -// addRequestLogger adds an HTTP request suiteLogger to the handler chain. +// AddRequestLogger adds an HTTP request suiteLogger to the handler chain. // // If a correlation ID is present in the request context, it is logged. -func addRequestLogger(logger zerolog.Logger) HandlerChain { +func AddRequestLogger(logger zerolog.Logger) HandlerChain { return func(next http.Handler) http.Handler { lh := hlog.NewHandler(logger) @@ -290,7 +290,7 @@ func addRequestLogger(logger zerolog.Logger) HandlerChain { Str("method", r.Method). Str("url", r.URL.RequestURI()). Str("remote_addr", r.RemoteAddr). - Str("forwarded_for", r.Header.Get(HTTPForwardedFor)). + Str("forwarded_for", r.Header.Get(HeaderForwardedFor)). Str("user_agent", r.UserAgent()). Msg("Request served") }) @@ -298,20 +298,18 @@ func addRequestLogger(logger zerolog.Logger) HandlerChain { } } -// Check the "Accept" header - -// addCorrelationID adds a correlation ID to the request context and response headers. -func addCorrelationID(logger zerolog.Logger) HandlerChain { +// AddCorrelationID adds a correlation ID to the request context and response headers. +func AddCorrelationID(logger zerolog.Logger) HandlerChain { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var corrID string - if corrID = r.Header.Get(HTTPHeaderCorrelationID); corrID == "" { + if corrID = r.Header.Get(HeaderCorrelationID); corrID == "" { corrID = xid.New().String() logger.Trace(). Str("correlation_id", corrID). Str("url", r.URL.RequestURI()). Msg("Creating a new correlation ID") - w.Header().Add(HTTPHeaderCorrelationID, corrID) + w.Header().Add(HeaderCorrelationID, corrID) } else { logger.Trace(). Str("correlation_id", corrID). diff --git a/server/server_suite_test.go b/server/server_suite_test.go index 1601238..69693c0 100644 --- a/server/server_suite_test.go +++ b/server/server_suite_test.go @@ -63,7 +63,7 @@ var _ = BeforeSuite(func() { err = manager.SetEnvironmentFilePath(iRODSEnvFile) Expect(err).NotTo(HaveOccurred()) - err = server.InitIRODS(manager, "irods") + err = server.InitIRODS(suiteLogger, manager, "irods") Expect(err).NotTo(HaveOccurred()) account, err = server.NewIRODSAccount(suiteLogger, manager)