diff --git a/src/net/textproto/reader.go b/src/net/textproto/reader.go index 8e800088c1fe79..48ae2946a6b0e8 100644 --- a/src/net/textproto/reader.go +++ b/src/net/textproto/reader.go @@ -489,8 +489,11 @@ func readMIMEHeader(r *Reader, lim int64) (MIMEHeader, error) { // large one ahead of time which we'll cut up into smaller // slices. If this isn't big enough later, we allocate small ones. var strs []string - hint := r.upcomingHeaderNewlines() + hint := r.upcomingHeaderKeys() if hint > 0 { + if hint > 1000 { + hint = 1000 // set a cap to avoid overallocation + } strs = make([]string, hint) } @@ -581,9 +584,9 @@ func mustHaveFieldNameColon(line []byte) error { var nl = []byte("\n") -// upcomingHeaderNewlines returns an approximation of the number of newlines +// upcomingHeaderKeys returns an approximation of the number of keys // that will be in this header. If it gets confused, it returns 0. -func (r *Reader) upcomingHeaderNewlines() (n int) { +func (r *Reader) upcomingHeaderKeys() (n int) { // Try to determine the 'hint' size. r.R.Peek(1) // force a buffer load if empty s := r.R.Buffered() @@ -591,7 +594,20 @@ func (r *Reader) upcomingHeaderNewlines() (n int) { return } peek, _ := r.R.Peek(s) - return bytes.Count(peek, nl) + for len(peek) > 0 && n < 1000 { + var line []byte + line, peek, _ = bytes.Cut(peek, nl) + if len(line) == 0 || (len(line) == 1 && line[0] == '\r') { + // Blank line separating headers from the body. + break + } + if line[0] == ' ' || line[0] == '\t' { + // Folded continuation of the previous line. + continue + } + n++ + } + return n } // CanonicalMIMEHeaderKey returns the canonical format of the diff --git a/src/net/textproto/reader_test.go b/src/net/textproto/reader_test.go index 9618b874e6f853..696ae406f3860f 100644 --- a/src/net/textproto/reader_test.go +++ b/src/net/textproto/reader_test.go @@ -10,6 +10,7 @@ import ( "io" "net" "reflect" + "runtime" "strings" "sync" "testing" @@ -129,6 +130,42 @@ func TestReadMIMEHeaderSingle(t *testing.T) { } } +// TestReaderUpcomingHeaderKeys is testing an internal function, but it's very +// difficult to test well via the external API. +func TestReaderUpcomingHeaderKeys(t *testing.T) { + for _, test := range []struct { + input string + want int + }{{ + input: "", + want: 0, + }, { + input: "A: v", + want: 1, + }, { + input: "A: v\r\nB: v\r\n", + want: 2, + }, { + input: "A: v\nB: v\n", + want: 2, + }, { + input: "A: v\r\n continued\r\n still continued\r\nB: v\r\n\r\n", + want: 2, + }, { + input: "A: v\r\n\r\nB: v\r\nC: v\r\n", + want: 1, + }, { + input: "A: v" + strings.Repeat("\n", 1000), + want: 1, + }} { + r := reader(test.input) + got := r.upcomingHeaderKeys() + if test.want != got { + t.Fatalf("upcomingHeaderKeys(%q): %v; want %v", test.input, got, test.want) + } + } +} + func TestReadMIMEHeaderNoKey(t *testing.T) { r := reader(": bar\ntest-1: 1\n\n") m, err := r.ReadMIMEHeader() @@ -271,6 +308,28 @@ func TestReadMIMEHeaderTrimContinued(t *testing.T) { } } +// Test that reading a header doesn't overallocate. Issue 58975. +func TestReadMIMEHeaderAllocations(t *testing.T) { + var totalAlloc uint64 + const count = 200 + for i := 0; i < count; i++ { + r := reader("A: b\r\n\r\n" + strings.Repeat("\n", 4096)) + var m1, m2 runtime.MemStats + runtime.ReadMemStats(&m1) + _, err := r.ReadMIMEHeader() + if err != nil { + t.Fatalf("ReadMIMEHeader: %v", err) + } + runtime.ReadMemStats(&m2) + totalAlloc += m2.TotalAlloc - m1.TotalAlloc + } + // 32k is large and we actually allocate substantially less, + // but prior to the fix for #58975 we allocated ~400k in this case. + if got, want := totalAlloc/count, uint64(32768); got > want { + t.Fatalf("ReadMIMEHeader allocated %v bytes, want < %v", got, want) + } +} + type readResponseTest struct { in string inCode int