Skip to content

Commit

Permalink
detecting Content-Type
Browse files Browse the repository at this point in the history
  • Loading branch information
shogo82148 committed Jun 19, 2023
1 parent 4bbeab2 commit 423bf88
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 25 deletions.
105 changes: 82 additions & 23 deletions ridgenative.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,19 @@ type streamingResponseWriter struct {
wroteHeader bool
header http.Header
statusCode int
err error

// prelude is the first part of the body.
// it is used for detecting content-type.
prelude []byte
}

func newStreamingResponseWriter(w *io.PipeWriter) *streamingResponseWriter {
return &streamingResponseWriter{
w: w,
buf: bufio.NewWriter(w),
header: make(http.Header, 1),
w: w,
buf: bufio.NewWriter(w),
header: make(http.Header, 1),
prelude: make([]byte, 0, 512),
}
}

Expand All @@ -454,49 +460,102 @@ func (rw *streamingResponseWriter) WriteHeader(code int) {
log.Printf("ridgenative: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
return
}
if rw.err != nil {
return
}

if !rw.hasContentType() {
rw.header.Set("Content-Type", http.DetectContentType(rw.prelude))
}

rw.wroteHeader = true
rw.statusCode = code

// build the prelude
h := make(map[string]string, len(rw.header))
for key, value := range rw.header {
if key == "Set-Cookie" {
continue
}
h[key] = strings.Join(value, ", ")
}
cookies := rw.header.Values("Set-Cookie")
r := &streamingResponse{
StatusCode: code,
Headers: h,
Cookies: cookies,
}

data, err := json.Marshal(r)
if err != nil {
log.Printf("ridgenative: %v", err)
rw.err = fmt.Errorf("ridgenative: failed to marshal response: %w", err)
return
}
rw.buf.Write(data)
rw.buf.WriteString("\x00\x00\x00\x00\x00\x00\x00\x00")
rw.buf.Flush()
rw.wroteHeader = true
if _, err := rw.buf.Write(data); err != nil {
rw.err = err
return
}
if _, err := rw.buf.WriteString("\x00\x00\x00\x00\x00\x00\x00\x00"); err != nil {
rw.err = err
return
}
if len(rw.prelude) != 0 {
if _, err := rw.buf.Write(rw.prelude); err != nil {
rw.err = err
return
}
}
if err := rw.buf.Flush(); err != nil {
rw.err = err
}
}

func (rw *streamingResponseWriter) hasContentType() bool {
return rw.header.Get("Content-Type") != ""
}

func (rw *streamingResponseWriter) Write(data []byte) (int, error) {
var m int
if !rw.wroteHeader {
// TODO: detect content type if it is not set.
rw.WriteHeader(http.StatusOK)
if rw.hasContentType() {
rw.WriteHeader(http.StatusOK)
} else {
// save the first part of the body for detecting content-type.
data0 := data
if len(rw.prelude)+len(data0) > cap(rw.prelude) {
data0 = data0[:cap(rw.prelude)-len(rw.prelude)]
}
rw.prelude = append(rw.prelude, data0...)

if len(rw.prelude) == cap(rw.prelude) {
rw.WriteHeader(http.StatusOK)
}
m = len(data0)
data = data[m:]
if len(data) == 0 {
return m, nil
}
}
}
return rw.buf.Write(data)
n, err := rw.buf.Write(data)
return n + m, err
}

func (rw *streamingResponseWriter) closeWithError(err error) error {
if !rw.wroteHeader {
rw.WriteHeader(http.StatusOK)
}
err0 := rw.buf.Flush()
if err1 := rw.w.CloseWithError(err); err0 == nil {
err0 = err1
if rw.err != nil {
err = rw.err
}
return err0
if err0 := rw.buf.Flush(); err0 != nil {
err = err0
}
return rw.w.CloseWithError(err)
}

func (rw *streamingResponseWriter) close() error {
if !rw.wroteHeader {
rw.WriteHeader(http.StatusOK)
}
err0 := rw.buf.Flush()
if err1 := rw.w.Close(); err0 == nil {
err0 = err1
}
return err0
return rw.closeWithError(nil)
}

func (rw *streamingResponseWriter) Flush() {
Expand Down
99 changes: 97 additions & 2 deletions ridgenative_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ func TestLambdaHandlerStreaming(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if got, want := string(data), "{\"statusCode\":200}\x00\x00\x00\x00\x00\x00\x00\x00{\"hello\":\"world\"}"; got != want {
if got, want := string(data), "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"application/json\"}}\x00\x00\x00\x00\x00\x00\x00\x00{\"hello\":\"world\"}"; got != want {
t.Errorf("unexpected body: want %q, got %q", want, got)
}
})
Expand Down Expand Up @@ -819,7 +819,7 @@ func TestLambdaHandlerStreaming(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if got, want := string(buf[:n]), "{\"statusCode\":200}\x00\x00\x00\x00\x00\x00\x00\x00"; got != want {
if got, want := string(buf[:n]), "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"application/json\"}}\x00\x00\x00\x00\x00\x00\x00\x00"; got != want {
t.Errorf("unexpected body: want %q, got %q", want, got)
}

Expand All @@ -841,4 +841,99 @@ func TestLambdaHandlerStreaming(t *testing.T) {
t.Errorf("unexpected read size: want %d, got %d", 0, n)
}
})

t.Run("flush", func(t *testing.T) {
l := newLambdaFunction(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
f, ok := w.(http.Flusher)
if !ok {
t.Error("http.ResponseWriter doesn't implement http.Flusher")
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)

io.WriteString(w, `{"hello":`)
f.Flush()
io.WriteString(w, `"world"}`)
}))
r, w := io.Pipe()
contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{
RequestContext: requestContext{
HTTP: &requestContextHTTP{
Path: "/",
},
},
}, w)
if err != nil {
t.Fatal(err)
}
if got, want := contentType, "application/vnd.awslambda.http-integration-response"; got != want {
t.Errorf("unexpected content type: want %q, got %q", want, got)
}

// Reads and Writes on the pipe are matched one to one,
// so we get only the header on first read.
buf := make([]byte, 1024)
n, err := r.Read(buf)
if err != nil {
t.Fatal(err)
}
if got, want := string(buf[:n]), "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"application/json\"}}\x00\x00\x00\x00\x00\x00\x00\x00"; got != want {
t.Errorf("unexpected body: want %q, got %q", want, got)
}

// The second read gets the half of the body.
n, err = r.Read(buf)
if err != nil {
t.Fatal(err)
}
if got, want := string(buf[:n]), "{\"hello\":"; got != want {
t.Errorf("unexpected body: want %q, got %q", want, got)
}

// The third read gets the rest of the body.
n, err = r.Read(buf)
if err != nil {
t.Fatal(err)
}
if got, want := string(buf[:n]), "\"world\"}"; got != want {
t.Errorf("unexpected body: want %q, got %q", want, got)
}

// The forth read gets EOF.
n, err = r.Read(buf)
if err != io.EOF {
t.Errorf("unexpected error: want %v, got %v", io.EOF, err)
}
if n != 0 {
t.Errorf("unexpected read size: want %d, got %d", 0, n)
}
})

t.Run("detect content-type", func(t *testing.T) {
l := newLambdaFunction(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, `<html></html>`)
}))
r, w := io.Pipe()
contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{
RequestContext: requestContext{
HTTP: &requestContextHTTP{
Path: "/",
},
},
}, w)
if err != nil {
t.Fatal(err)
}
if got, want := contentType, "application/vnd.awslambda.http-integration-response"; got != want {
t.Errorf("unexpected content type: want %q, got %q", want, got)
}

data, err := io.ReadAll(r)
if err != nil {
t.Fatal(err)
}
if got, want := string(data), "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"text/html; charset=utf-8\"}}\x00\x00\x00\x00\x00\x00\x00\x00<html></html>"; got != want {
t.Errorf("unexpected body: want %q, got %q", want, got)
}
})
}

0 comments on commit 423bf88

Please sign in to comment.