diff --git a/codec.go b/codec.go index 5fea24e..8ea8258 100644 --- a/codec.go +++ b/codec.go @@ -18,22 +18,37 @@ type ( type Codec struct { Encode EncodeFunc Decode DecodeFunc + Tags []string } -var Codecs = map[string]*Codec{ - "json": reflectCodec(json.Marshal, json.Unmarshal), - "yaml": reflectCodec(yaml.Marshal, yaml.Unmarshal), - "xml": reflectCodec(xml.Marshal, xml.Unmarshal), - "csv": newCodec(csvEncoder, csvDecoder), - "gob": newCodec(gobEncoder, gobDecoder), +func (c Codec) Tag() string { + return c.Tags[0] } -func reflectCodec(me func(src interface{}) ([]byte, error), md func(buf []byte, dst interface{}) error) *Codec { - return &Codec{Encode: marshalEncoder(me), Decode: unmarshalDecoder(md)} +var Codecs = make(map[string]*Codec) +var CodecTypes []string + +func init() { + registerCodec(reflectCodec(json.Marshal, json.Unmarshal, "application/json", "text/json")) + registerCodec(reflectCodec(yaml.Marshal, yaml.Unmarshal, "application/x-yaml", "text/x-yaml")) + registerCodec(reflectCodec(xml.Marshal, xml.Unmarshal, "application/xml", "text/xml")) + registerCodec(newCodec(csvEncoder, csvDecoder, "application/csv", "text/csv")) + registerCodec(newCodec(gobEncoder, gobDecoder, "application/x-gob", "text/x-gob")) +} + +func registerCodec(codec *Codec) { + for _, tag := range codec.Tags { + Codecs[tag] = codec + CodecTypes = append(CodecTypes, tag) + } +} + +func reflectCodec(me func(src interface{}) ([]byte, error), md func(buf []byte, dst interface{}) error, tags ...string) *Codec { + return &Codec{Encode: marshalEncoder(me), Decode: unmarshalDecoder(md), Tags: tags} } -func newCodec(encoder EncodeFunc, decoder DecodeFunc) *Codec { - return &Codec{Encode: encoder, Decode: decoder} +func newCodec(encoder EncodeFunc, decoder DecodeFunc, tags ...string) *Codec { + return &Codec{Encode: encoder, Decode: decoder, Tags: tags} } func unmarshalDecoder(f func(buf []byte, dst interface{}) error) DecodeFunc { diff --git a/handler.go b/handler.go index d813202..fef8fff 100644 --- a/handler.go +++ b/handler.go @@ -27,14 +27,14 @@ type Context struct { type ContentType struct{} func (h *ContentType) Serve(ctx *Context, w http.ResponseWriter, r *http.Request) { - accept := NegotiateCodec(r, ctx.Config.Codecs, ctx.Config.DefaultCodec) + accept := NegotiateCodec(r, ctx.Config.CodecTypes, ctx.Config.DefaultCodec) codec, available := ctx.Config.Codecs[accept] if !available { httpError( w, codec, http.StatusNotAcceptable, - fmt.Errorf("only able to accept %v", ctx.Config.Codecs), + fmt.Errorf("wanted codec %q, only able to accept %v", accept, ctx.Config.Codecs), ) return } @@ -42,11 +42,11 @@ func (h *ContentType) Serve(ctx *Context, w http.ResponseWriter, r *http.Request w.Header().Set(HeaderContentType, accept) } -func NegotiateCodec(r *http.Request, codecs map[string]*Codec, defaultCodec string) string { +func NegotiateCodec(r *http.Request, codecs []string, defaultCodec string) string { specs := httputil.ParseAccept(r.Header, "Accept") bestCodec, bestQ, bestWild := defaultCodec, -1.0, 3 - for codec := range codecs { + for _, codec := range codecs { for _, spec := range specs { switch { case spec.Q == 0.0: @@ -83,14 +83,9 @@ type ContentLength struct { } func (c *ContentLength) Serve(ctx *Context, w http.ResponseWriter, r *http.Request) { - var codec *Codec - if ctx.Config.Codecs != nil { - codec = ctx.Config.Codecs[w.Header().Get(HeaderContentType)] - } + codec := ctx.Config.Codecs[w.Header().Get(HeaderContentType)] switch { - case r.ContentLength == 0: - httpError(w, codec, http.StatusNoContent, nil) case r.ContentLength < c.Min: httpError( w, codec, @@ -109,10 +104,7 @@ func (c *ContentLength) Serve(ctx *Context, w http.ResponseWriter, r *http.Reque type ContentDecode struct{} func (h *ContentDecode) Serve(ctx *Context, w http.ResponseWriter, r *http.Request) { - var codec *Codec - if ctx.Config.Codecs != nil { - codec = ctx.Config.Codecs[w.Header().Get(HeaderContentType)] - } + codec := ctx.Config.Codecs[r.Header.Get(HeaderContentType)] err := getHeaderParams(r, ctx.In) if err != nil && !errors.Is(err, http.ErrBodyNotAllowed) { @@ -188,3 +180,26 @@ func getBodyParams(r *http.Request, codec *Codec, values map[string]interface{}) return nil } + +type ContentEncode struct{} + +func (h *ContentEncode) Serve(ctx *Context, w http.ResponseWriter, _ *http.Request) { + codec := ctx.Config.Codecs[w.Header().Get(HeaderContentType)] + + if codec == nil { + httpError( + w, codec, + http.StatusNotAcceptable, + fmt.Errorf("only able to accept %v", ctx.Config.Codecs), + ) + return + } + + buf, err := codec.Encode(ctx.Out) + if err != nil && !errors.Is(err, http.ErrBodyNotAllowed) { + httpError(w, codec, http.StatusInternalServerError, err) + return + } + + w.Write(buf) +} diff --git a/http.go b/http.go index 610b722..afee54e 100644 --- a/http.go +++ b/http.go @@ -2,7 +2,6 @@ package flatend import ( "context" - "crypto/tls" "net" "net/http" "sync" @@ -13,6 +12,7 @@ var _ http.Handler = (*Server)(nil) type Config struct { Codecs map[string]*Codec + CodecTypes []string DefaultCodec string Handlers []Handler @@ -21,25 +21,26 @@ type Config struct { func NewDefaultConfig() *Config { return &Config{ Codecs: Codecs, - DefaultCodec: "json", + CodecTypes: CodecTypes, + DefaultCodec: CodecTypes[0], Handlers: []Handler{ &ContentType{}, &ContentLength{Max: 10 * 1024 * 1024}, &ContentDecode{}, + &ContentEncode{}, }, } } type Server struct { - ReadTimeout time.Duration - WriteTimeout time.Duration - IdleTimeout time.Duration - ReadHeaderTimeout time.Duration + IdleTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration - MaxHeaderBytes int + MaxHeaderBytes int + ReadHeaderTimeout time.Duration - TLS *tls.Config Config *Config once sync.Once @@ -51,14 +52,12 @@ func (s *Server) init() { s.http = &http.Server{ Handler: s, - TLSConfig: s.TLS, + IdleTimeout: s.IdleTimeout, + ReadTimeout: s.ReadTimeout, + WriteTimeout: s.WriteTimeout, - ReadTimeout: s.ReadTimeout, - WriteTimeout: s.WriteTimeout, - IdleTimeout: s.IdleTimeout, + MaxHeaderBytes: s.MaxHeaderBytes, ReadHeaderTimeout: s.ReadHeaderTimeout, - - MaxHeaderBytes: s.MaxHeaderBytes, } if s.Config == nil { diff --git a/httputil/accept.go b/httputil/accept.go index 3106d1c..87aa980 100644 --- a/httputil/accept.go +++ b/httputil/accept.go @@ -18,21 +18,21 @@ func ParseAccept(header http.Header, key string) (specs []AcceptSpec) { for _, s := range header[key] { for { - spec.Value, s = expectTokenSlash(s) + spec.Value, s = ExpectTokenSlash(s) if spec.Value == "" { break } - s = skipSpace(s) + s = SkipSpace(s) spec.Q = 1.0 if len(s) > 0 && s[0] == ';' { - s = skipSpace(s[1:]) + s = SkipSpace(s[1:]) if !strings.HasPrefix(s, "q=") { break } - spec.Q, s = expectQuality(s[2:]) + spec.Q, s = ExpectQuality(s[2:]) if spec.Q < 0.0 { break } @@ -40,17 +40,17 @@ func ParseAccept(header http.Header, key string) (specs []AcceptSpec) { specs = append(specs, spec) - s = skipSpace(s) + s = SkipSpace(s) if len(s) == 0 || s[0] != ',' { break } - s = skipSpace(s[1:]) + s = SkipSpace(s[1:]) } } return specs } -func skipSpace(s string) (rest string) { +func SkipSpace(s string) (rest string) { i := 0 for ; i < len(s); i++ { if octetTypes[s[i]]&isSpace == 0 { @@ -60,7 +60,7 @@ func skipSpace(s string) (rest string) { return s[i:] } -func expectTokenSlash(s string) (token, rest string) { +func ExpectTokenSlash(s string) (token, rest string) { i := 0 for ; i < len(s); i++ { b := s[i] @@ -71,7 +71,7 @@ func expectTokenSlash(s string) (token, rest string) { return s[:i], s[i:] } -func expectQuality(s string) (q float64, rest string) { +func ExpectQuality(s string) (q float64, rest string) { switch { case len(s) == 0: return -1, ""