Skip to content

Commit

Permalink
g
Browse files Browse the repository at this point in the history
  • Loading branch information
komuw committed Aug 17, 2024
1 parent cbb1ea9 commit 6e311a6
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 38 deletions.
37 changes: 20 additions & 17 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config

import (
"crypto/x509"
"errors"
"fmt"
"log/slog"
"net/http"
Expand Down Expand Up @@ -187,16 +188,17 @@ func (o Opts) GoString() string {
// domain is the domain name of your website. It can be an exact domain, subdomain or wildcard.
// port is the TLS port where the server will listen on. Http requests will also redirected to that port.
//
// logger is an [slog.Logger] that will be used for logging.
//
// secretKey is used for securing signed data. It should be unique & kept secret.
// If it becomes compromised, generate a new one and restart your application using the new one.
//
// strategy is the algorithm to use when fetching the client's IP address; see [ClientIPstrategy].
// It is important to choose your strategy carefully, see the warning in [ClientIPstrategy].
//
// logger is an [slog.Logger] that will be used for logging.
//
// logFunc is a function that dictates what/how middleware is going to log. It is also used to log any recovered panics in the middleware.
// If it is nil, no logging happens in the middleware. For server logging, see [logger]
// This function is not used in the server.
// If it is nil, a suitable default is used. To disable logging, use a function that does nothing.
//
// rateLimit is the maximum requests allowed (from one IP address) per second. If it is les than 1.0, [DefaultRateLimit] is used instead.
//
Expand All @@ -219,8 +221,6 @@ func (o Opts) GoString() string {
// sessionAntiReplayFunc is the function used to return a token that will be used to try and mitigate against [replay attacks]. This mitigation not foolproof.
// If it is nil, [DefaultSessionAntiReplayFunc] is used instead.
//
// logger is an [slog.Logger] that will be used for logging in the server but not middleware. For middleware see [logFunc]
//
// maxBodyBytes is the maximum size in bytes for incoming request bodies. If this is zero, a reasonable default is used.
//
// serverLogLevel is the log level of the logger that will be passed into [http.Server.ErrorLog]
Expand Down Expand Up @@ -254,6 +254,7 @@ func New(
// common
domain string,
port uint16,
logger *slog.Logger,

// middleware
secretKey string,
Expand All @@ -272,7 +273,6 @@ func New(
sessionCookieDuration time.Duration,
sessionAntiReplayFunc func(r http.Request) string,
// server
logger *slog.Logger,
maxBodyBytes uint64,
serverLogLevel slog.Level,
readHeaderTimeout time.Duration,
Expand All @@ -290,6 +290,7 @@ func New(
middlewareOpts, err := newMiddlewareOpts(
domain,
port,
logger,
secretKey,
strategy,
logFunc,
Expand All @@ -315,7 +316,6 @@ func New(
serverOpts: newServerOpts(
domain,
port,
logger,
maxBodyBytes,
serverLogLevel,
readHeaderTimeout,
Expand Down Expand Up @@ -352,6 +352,7 @@ func WithOpts(
// common
domain,
httpsPort,
logger,

// middleware
secretKey,
Expand All @@ -370,7 +371,6 @@ func WithOpts(
DefaultSessionCookieDuration,
DefaultSessionAntiReplayFunc,
// server
logger,
DefaultMaxBodyBytes,
DefaultServerLogLevel,
defaultReadHeaderTimeout,
Expand Down Expand Up @@ -401,6 +401,7 @@ func DevOpts(logger *slog.Logger, secretKey string) Opts {
// common
domain,
httpsPort,
logger,

// middleware
secretKey,
Expand All @@ -419,7 +420,6 @@ func DevOpts(logger *slog.Logger, secretKey string) Opts {
DefaultSessionCookieDuration,
DefaultSessionAntiReplayFunc,
// server
logger,
DefaultMaxBodyBytes,
DefaultServerLogLevel,
defaultReadHeaderTimeout,
Expand Down Expand Up @@ -457,6 +457,7 @@ func CertOpts(
// common
domain,
httpsPort,
logger,

// middleware
secretKey,
Expand All @@ -475,7 +476,6 @@ func CertOpts(
DefaultSessionCookieDuration,
DefaultSessionAntiReplayFunc,
// server
logger,
DefaultMaxBodyBytes,
DefaultServerLogLevel,
defaultReadHeaderTimeout,
Expand Down Expand Up @@ -516,6 +516,7 @@ func AcmeOpts(
// common
domain,
httpsPort,
logger,

// middleware
secretKey,
Expand All @@ -534,7 +535,6 @@ func AcmeOpts(
DefaultSessionCookieDuration,
DefaultSessionAntiReplayFunc,
// server
logger,
DefaultMaxBodyBytes,
DefaultServerLogLevel,
defaultReadHeaderTimeout,
Expand Down Expand Up @@ -574,6 +574,7 @@ func LetsEncryptOpts(
// common
domain,
httpsPort,
logger,

// middleware
secretKey,
Expand All @@ -592,7 +593,6 @@ func LetsEncryptOpts(
DefaultSessionCookieDuration,
DefaultSessionAntiReplayFunc,
// server
logger,
DefaultMaxBodyBytes,
DefaultServerLogLevel,
defaultReadHeaderTimeout,
Expand Down Expand Up @@ -629,6 +629,8 @@ func (s secureKey) GoString() string {
type middlewareOpts struct {
Domain string
HttpsPort uint16
Logger *slog.Logger

// When printing a struct, fmt does not invoke custom formatting methods on unexported fields.
// We thus need to make this field to be exported.
// - https://pkg.go.dev/fmt#:~:text=When%20printing%20a%20struct
Expand Down Expand Up @@ -707,6 +709,7 @@ func (m middlewareOpts) GoString() string {
func newMiddlewareOpts(
domain string,
httpsPort uint16,
logger *slog.Logger,
secretKey string,
strategy ClientIPstrategy,
logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any),
Expand All @@ -732,6 +735,10 @@ func newMiddlewareOpts(
domain = domain[2:]
}

if logger == nil && logFunc == nil {
return middlewareOpts{}, errors.New("both logger and logFunc should not be nil at the same time")
}

if err := key.IsSecure(secretKey); err != nil {
return middlewareOpts{}, err
}
Expand Down Expand Up @@ -760,6 +767,7 @@ func newMiddlewareOpts(
return middlewareOpts{
Domain: domain,
HttpsPort: httpsPort,
Logger: logger,
SecretKey: secureKey(secretKey),
Strategy: strategy,
LogFunc: logFunc,
Expand Down Expand Up @@ -831,7 +839,6 @@ func (t tlsOpts) GoString() string {
// serverOpts are the various parameters(optionals) that can be used to configure a HTTP server.
type serverOpts struct {
port uint16 // tcp port is a 16bit unsigned integer.
Logger *slog.Logger
MaxBodyBytes uint64 // max size of request body allowed.
ServerLogLevel slog.Level
ReadHeaderTimeout time.Duration
Expand All @@ -853,7 +860,6 @@ type serverOpts struct {
func newServerOpts(
domain string,
port uint16,
logger *slog.Logger,
maxBodyBytes uint64,
serverLogLevel slog.Level,
readHeaderTimeout time.Duration,
Expand Down Expand Up @@ -899,7 +905,6 @@ func newServerOpts(

return serverOpts{
port: port,
Logger: logger,
MaxBodyBytes: maxBodyBytes,
ServerLogLevel: serverLogLevel,
ReadHeaderTimeout: readHeaderTimeout,
Expand Down Expand Up @@ -930,7 +935,6 @@ func newServerOpts(
func (s serverOpts) String() string {
return fmt.Sprintf(`serverOpts{
port: %v,
Logger: %v,
MaxBodyBytes: %v,
ServerLogLevel: %v,
ReadHeaderTimeout: %v,
Expand All @@ -946,7 +950,6 @@ func (s serverOpts) String() string {
HttpPort: %v,
}`,
s.port,
s.Logger,
s.MaxBodyBytes,
s.ServerLogLevel,
s.ReadHeaderTimeout,
Expand Down
7 changes: 5 additions & 2 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func validOpts(t *testing.T) Opts {
"example.com",
// The https port that our application will be listening on.
443,
// Logger.
l,
// The security key to use for securing signed data.
"super-h@rd-Pas1word",
// In this case, the actual client IP address is fetched from the given http header.
Expand Down Expand Up @@ -63,8 +65,6 @@ func validOpts(t *testing.T) Opts {
// Use a given header to try and mitigate against replay-attacks.
func(r http.Request) string { return r.Header.Get("Anti-Replay") },
//
// Logger.
l,
// The maximum size in bytes for incoming request bodies.
2*1024*1024,
// Log level of the logger that will be passed into [http.Server.ErrorLog]
Expand Down Expand Up @@ -135,6 +135,7 @@ func TestNewMiddlewareOpts(t *testing.T) {
o, err := newMiddlewareOpts(
opt.Domain,
opt.HttpsPort,
slog.Default(),
string(opt.SecretKey),
opt.Strategy,
nil,
Expand Down Expand Up @@ -195,6 +196,7 @@ func TestNewMiddlewareOptsDomain(t *testing.T) {
_, err := newMiddlewareOpts(
tt.domain,
443,
slog.Default(),
tst.SecretKey(),
clientip.DirectIpStrategy,
nil,
Expand All @@ -216,6 +218,7 @@ func TestNewMiddlewareOptsDomain(t *testing.T) {
_, err := newMiddlewareOpts(
tt.domain,
443,
slog.Default(),
tst.SecretKey(),
clientip.DirectIpStrategy,
nil,
Expand Down
4 changes: 2 additions & 2 deletions config/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ func ExampleNew() {
"example.com",
// The https port that our application will be listening on.
443,
// Logger.
l,
// The security key to use for securing signed data.
"super-h@rd-Pas1word",
// In this case, the actual client IP address is fetched from the given http header.
Expand Down Expand Up @@ -60,8 +62,6 @@ func ExampleNew() {
// Use a given header to try and mitigate against replay-attacks.
func(r http.Request) string { return r.Header.Get("Anti-Replay") },
//
// Logger.
l,
// The maximum size in bytes for incoming request bodies.
2*1024*1024,
// Log level of the logger that will be passed into [http.Server.ErrorLog]
Expand Down
54 changes: 51 additions & 3 deletions middleware/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,28 @@ import (
"bufio"
"fmt"
"io"
"log/slog"
mathRand "math/rand/v2"
"net"
"net/http"
"time"

"github.com/komuw/ong/log"
)

// logger is a middleware that logs http requests and responses using [log.Logger].
func logger(
wrappedHandler http.Handler,
logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any),
l *slog.Logger,
) http.HandlerFunc {
// The middleware should ideally share the same logger as the app.
// That way, if the app logs an error, the middleware logs are also flushed.
// That's one reason why we pass in the logFunc.This makes debugging easier for developers.
// Another reason is so that app developers are in control of what(and how) exactly gets logged.
if logFunc == nil {
logFunc = defaultLogFunc(l)
}

return func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
Expand Down Expand Up @@ -47,9 +55,7 @@ func logger(
// 1xx class or the modified headers are trailers.
lrw.Header().Del(ongMiddlewareErrorHeader)

if logFunc != nil {
logFunc(w, *r, lrw.code, flds)
}
logFunc(w, *r, lrw.code, flds)
}()

wrappedHandler.ServeHTTP(lrw, r)
Expand Down Expand Up @@ -146,3 +152,45 @@ func (lrw *logRW) ReadFrom(src io.Reader) (n int64, err error) {
func (lrw *logRW) Unwrap() http.ResponseWriter {
return lrw.ResponseWriter
}

// defaultLogFunc is the logging function used if the user did not explicitly provide one.
func defaultLogFunc(l *slog.Logger) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {
const (
msg = "http_server"
// rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default.
rateShedSamplePercent = 10
)

if l == nil {
return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {}
}

return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {
// Each request should get its own context. That's why we call `log.WithID` for every request.
reqL := log.WithID(r.Context(), l)

if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" {
// We are either in load shedding or rate-limiting.
// Only log (rateShedSamplePercent)% of the errors.
shouldLog := mathRand.IntN(100) <= rateShedSamplePercent
if shouldLog {
reqL.Error(msg, fields...)
return
}
}

if statusCode >= http.StatusBadRequest {
// Both client and server errors.
if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot {
// These ones are more of an annoyance, than been actual errors.
reqL.Info(msg, fields...)
return
}

reqL.Error(msg, fields...)
return
}

reqL.Info(msg, fields...)
}
}
Loading

0 comments on commit 6e311a6

Please sign in to comment.