diff --git a/internal/trace/context.go b/internal/trace/context.go index d55c3812d..1acf25811 100644 --- a/internal/trace/context.go +++ b/internal/trace/context.go @@ -36,7 +36,7 @@ func NewLogger(ctx context.Context, debug bool) (context.Context, logrus.FieldLo } logger := logrus.New() - logger.SetFormatter(&logrus.TextFormatter{DisableQuote: true}) + logger.SetFormatter(&TextFormatter{}) logger.SetLevel(logLevel) entry := logger.WithContext(ctx) return context.WithValue(ctx, loggerKey, entry), entry diff --git a/internal/trace/text_formatter.go b/internal/trace/text_formatter.go new file mode 100644 index 000000000..9c6472464 --- /dev/null +++ b/internal/trace/text_formatter.go @@ -0,0 +1,50 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trace + +import ( + "bytes" + "fmt" + "strings" + "time" + + "github.com/sirupsen/logrus" +) + +// TextFormatter formats logs into text. +type TextFormatter struct{} + +// logEntrySeperator is the separator between log entries. +const logEntrySeperator = "\n\n" // two empty lines + +// Format renders a single log entry. +func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) { + var buf bytes.Buffer + + timestamp := entry.Time.Format(time.RFC3339Nano) + levelText := strings.ToUpper(entry.Level.String()) + fmt.Fprintf(&buf, "[%s][%s]: %s\n", timestamp, levelText, entry.Message) + // print data fields + if len(entry.Data) > 0 { + buf.WriteString("[Data]:\n") + for k, v := range entry.Data { + fmt.Fprintf(&buf, " %s=%v\n", k, v) + } + } + + buf.WriteString(logEntrySeperator) + return buf.Bytes(), nil +} diff --git a/internal/trace/text_formatter_test.go b/internal/trace/text_formatter_test.go new file mode 100644 index 000000000..c7ac5b9c1 --- /dev/null +++ b/internal/trace/text_formatter_test.go @@ -0,0 +1,99 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trace + +import ( + "reflect" + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +func TestTextFormatter_Format(t *testing.T) { + tests := []struct { + name string + f *TextFormatter + entry *logrus.Entry + want []byte + wantErr bool + }{ + { + name: "debug log entry", + f: &TextFormatter{}, + entry: &logrus.Entry{ + Time: time.Date(2024, time.December, 1, 23, 30, 1, 55, time.UTC), + Level: logrus.DebugLevel, + Message: "test debug", + Data: logrus.Fields{}, + }, + want: []byte("[2024-12-01T23:30:01.000000055Z][DEBUG]: test debug\n\n\n"), + wantErr: false, + }, + { + name: "info log entry", + f: &TextFormatter{}, + entry: &logrus.Entry{ + Time: time.Date(2024, time.December, 1, 23, 30, 1, 55, time.UTC), + Level: logrus.InfoLevel, + Message: "test info", + Data: logrus.Fields{}, + }, + want: []byte("[2024-12-01T23:30:01.000000055Z][INFO]: test info\n\n\n"), + wantErr: false, + }, + { + name: "warning log entry with data fields", + f: &TextFormatter{}, + entry: &logrus.Entry{ + Time: time.Date(2024, time.December, 1, 23, 30, 1, 55, time.UTC), + Level: logrus.WarnLevel, + Message: "test warning with fields", + Data: logrus.Fields{ + "testkey": "testval", + }, + }, + want: []byte("[2024-12-01T23:30:01.000000055Z][WARNING]: test warning with fields\n[Data]:\n testkey=testval\n\n\n"), + wantErr: false, + }, + { + name: "error log entry with data fields", + f: &TextFormatter{}, + entry: &logrus.Entry{ + Time: time.Date(2024, time.December, 1, 23, 30, 1, 55, time.UTC), + Level: logrus.ErrorLevel, + Message: "test warning with fields", + Data: logrus.Fields{ + "testkey": 123, + }, + }, + want: []byte("[2024-12-01T23:30:01.000000055Z][ERROR]: test warning with fields\n[Data]:\n testkey=123\n\n\n"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.f.Format(tt.entry) + if (err != nil) != tt.wantErr { + t.Errorf("TextFormatter.Format() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("TextFormatter.Format() = %s, want %s", got, tt.want) + } + }) + } +} diff --git a/internal/trace/transport.go b/internal/trace/transport.go index 76630eb23..fa02f30d0 100644 --- a/internal/trace/transport.go +++ b/internal/trace/transport.go @@ -16,7 +16,10 @@ limitations under the License. package trace import ( + "bytes" "fmt" + "io" + "mime" "net/http" "strings" "sync/atomic" @@ -34,6 +37,9 @@ var ( } ) +// payloadSizeLimit limits the maximum size of the response body to be printed. +const payloadSizeLimit int64 = 16 * 1024 // 16 KiB + // Transport is an http.RoundTripper that keeps track of the in-flight // request and add hooks to report HTTP tracing events. type Transport struct { @@ -54,18 +60,18 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error e := Logger(ctx) // log the request - e.Debugf("Request #%d\n> Request URL: %q\n> Request method: %q\n> Request headers:\n%s", + e.Debugf("--> Request #%d\n> Request URL: %q\n> Request method: %q\n> Request headers:\n%s", id, req.URL, req.Method, logHeader(req.Header)) // log the response resp, err = t.RoundTripper.RoundTrip(req) if err != nil { - e.Errorf("Error in getting response: %w", err) + e.Errorf("<-- Response #%d\nError in getting response: %v", id, err) } else if resp == nil { - e.Errorf("No response obtained for request %s %q", req.Method, req.URL) + e.Errorf("<-- Response #%d\nNo response obtained for request %s %q", id, req.Method, req.URL) } else { - e.Debugf("Response #%d\n< Response Status: %q\n< Response headers:\n%s", - id, resp.Status, logHeader(resp.Header)) + e.Debugf("<-- Response #%d\n< Response Status: %q\n< Response headers:\n%s\n< Response body:\n%s", + id, resp.Status, logHeader(resp.Header), logResponseBody(resp)) } return resp, err } @@ -87,3 +93,67 @@ func logHeader(header http.Header) string { } return " Empty header" } + +// logResponseBody prints out the response body if it is printable and within +// the size limit. +func logResponseBody(resp *http.Response) string { + if resp.Body == nil || resp.Body == http.NoBody { + return " No response body to print" + } + + // non-applicable body is not printed and remains untouched for subsequent processing + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + return " Response body without a content type is not printed" + } + if !isPrintableContentType(contentType) { + return fmt.Sprintf(" Response body of content type %q is not printed", contentType) + } + + buf := bytes.NewBuffer(nil) + body := resp.Body + // restore the body by concatenating the read body with the remaining body + resp.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.MultiReader(buf, body), + Closer: body, + } + // read the body up to limit+1 to check if the body exceeds the limit + if _, err := io.CopyN(buf, body, payloadSizeLimit+1); err != nil && err != io.EOF { + return fmt.Sprintf(" Error reading response body: %v", err) + } + + readBody := buf.String() + if len(readBody) == 0 { + return " Response body is empty" + } + if containsCredentials(readBody) { + return " Response body redacted due to potential credentials" + } + if len(readBody) > int(payloadSizeLimit) { + return readBody[:payloadSizeLimit] + "\n...(truncated)" + } + return readBody +} + +// isPrintableContentType returns true if the content of contentType is printable. +func isPrintableContentType(contentType string) bool { + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return false + } + + switch mediaType { + case "application/json", // JSON types + "text/plain", "text/html": // text types + return true + } + return strings.HasSuffix(mediaType, "+json") +} + +// containsCredentials returns true if the body contains potential credentials. +func containsCredentials(body string) bool { + return strings.Contains(body, `"token"`) || strings.Contains(body, `"access_token"`) +} diff --git a/internal/trace/transport_test.go b/internal/trace/transport_test.go new file mode 100644 index 000000000..764f216bf --- /dev/null +++ b/internal/trace/transport_test.go @@ -0,0 +1,398 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package trace + +import ( + "bytes" + "errors" + "io" + "net/http" + "testing" +) + +var errMockRead = errors.New("mock read error") + +type errorReader struct{} + +func (e *errorReader) Read(p []byte) (n int, err error) { + return 0, errMockRead +} + +func Test_isPrintableContentType(t *testing.T) { + tests := []struct { + name string + contentType string + want bool + }{ + { + name: "Empty content type", + contentType: "", + want: false, + }, + { + name: "General JSON type", + contentType: "application/json", + want: true, + }, + { + name: "General JSON type with charset", + contentType: "application/json; charset=utf-8", + want: true, + }, + { + name: "Random type with application/json prefix", + contentType: "application/jsonwhatever", + want: false, + }, + { + name: "Manifest type in JSON", + contentType: "application/vnd.oci.image.manifest.v1+json", + want: true, + }, + { + name: "Manifest type in JSON with charset", + contentType: "application/vnd.oci.image.manifest.v1+json; charset=utf-8", + want: true, + }, + { + name: "Random content type in JSON", + contentType: "application/whatever+json", + want: true, + }, + { + name: "Plain text type", + contentType: "text/plain", + want: true, + }, + { + name: "Plain text type with charset", + contentType: "text/plain; charset=utf-8", + want: true, + }, + { + name: "Random type with text/plain prefix", + contentType: "text/plainnnnn", + want: false, + }, + { + name: "HTML type", + contentType: "text/html", + want: true, + }, + { + name: "Plain text type with charset", + contentType: "text/html; charset=utf-8", + want: true, + }, + { + name: "Random type with text/html prefix", + contentType: "text/htmlllll", + want: false, + }, + { + name: "Binary type", + contentType: "application/octet-stream", + want: false, + }, + { + name: "Unknown type", + contentType: "unknown/unknown", + want: false, + }, + { + name: "Invalid type", + contentType: "text/", + want: false, + }, + { + name: "Random string", + contentType: "random123!@#", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isPrintableContentType(tt.contentType); got != tt.want { + t.Errorf("isPrintableContentType() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_logResponseBody(t *testing.T) { + tests := []struct { + name string + resp *http.Response + want string + wantData []byte + }{ + { + name: "Nil body", + resp: &http.Response{ + Body: nil, + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, + want: " No response body to print", + }, + { + name: "No body", + wantData: nil, + resp: &http.Response{ + Body: http.NoBody, + ContentLength: 100, // in case of HEAD response, the content length is set but the body is empty + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, + want: " No response body to print", + }, + { + name: "Empty body", + wantData: []byte(""), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte(""))), + ContentLength: 0, + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: " Response body is empty", + }, + { + name: "Unknown content length", + wantData: []byte("whatever"), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("whatever"))), + ContentLength: -1, + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: "whatever", + }, + { + name: "Missing content type header", + wantData: []byte("whatever"), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("whatever"))), + ContentLength: 8, + }, + want: " Response body without a content type is not printed", + }, + { + name: "Empty content type header", + wantData: []byte("whatever"), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("whatever"))), + ContentLength: 8, + Header: http.Header{"Content-Type": []string{""}}, + }, + want: " Response body without a content type is not printed", + }, + { + name: "Non-printable content type", + wantData: []byte("binary data"), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("binary data"))), + ContentLength: 11, + Header: http.Header{"Content-Type": []string{"application/octet-stream"}}, + }, + want: " Response body of content type \"application/octet-stream\" is not printed", + }, + { + name: "Body at the limit", + wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit)), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit)))), + ContentLength: payloadSizeLimit, + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))), + }, + { + name: "Body larger than limit", + wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)))), // 1 byte larger than limit + ContentLength: payloadSizeLimit + 1, + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))) + "\n...(truncated)", + }, + { + name: "Printable content type within limit", + wantData: []byte("data"), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("data"))), + ContentLength: 4, + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: "data", + }, + { + name: "Actual body size is larger than content length", + wantData: []byte("data"), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("data"))), + ContentLength: 3, // mismatched content length + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: "data", + }, + { + name: "Actual body size is larger than content length and exceeds limit", + wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)))), // 1 byte larger than limit + ContentLength: 1, // mismatched content length + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))) + "\n...(truncated)", + }, + { + name: "Actual body size is smaller than content length", + wantData: []byte("data"), + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte("data"))), + ContentLength: 5, // mismatched content length + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: "data", + }, + { + name: "Body contains token", + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte(`{"token":"12345"}`))), + ContentLength: 17, + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, + wantData: []byte(`{"token":"12345"}`), + want: " Response body redacted due to potential credentials", + }, + { + name: "Body contains access_token", + resp: &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte(`{"access_token":"12345"}`))), + ContentLength: 17, + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, + wantData: []byte(`{"access_token":"12345"}`), + want: " Response body redacted due to potential credentials", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := logResponseBody(tt.resp); got != tt.want { + t.Errorf("logResponseBody() = %v, want %v", got, tt.want) + } + // validate the response body + if tt.resp.Body != nil { + readBytes, err := io.ReadAll(tt.resp.Body) + if err != nil { + t.Errorf("failed to read body after logResponseBody(), err= %v", err) + } + if !bytes.Equal(readBytes, tt.wantData) { + t.Errorf("resp.Body after logResponseBody() = %v, want %v", readBytes, tt.wantData) + } + if closeErr := tt.resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close body after logResponseBody(), err= %v", closeErr) + } + } + }) + } +} + +func Test_logResponseBody_error(t *testing.T) { + tests := []struct { + name string + resp *http.Response + want string + }{ + { + name: "Error reading body", + resp: &http.Response{ + Body: io.NopCloser(&errorReader{}), + ContentLength: 10, + Header: http.Header{"Content-Type": []string{"text/plain"}}, + }, + want: " Error reading response body: mock read error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := logResponseBody(tt.resp); got != tt.want { + t.Errorf("logResponseBody() = %v, want %v", got, tt.want) + } + if closeErr := tt.resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close body after logResponseBody(), err= %v", closeErr) + } + }) + } +} + +func Test_containsCredentials(t *testing.T) { + tests := []struct { + name string + body string + want bool + }{ + { + name: "Contains token keyword", + body: `{"token": "12345"}`, + want: true, + }, + { + name: "Contains quoted token keyword", + body: `whatever "token" blah`, + want: true, + }, + { + name: "Contains unquoted token keyword", + body: `whatever token blah`, + want: false, + }, + { + name: "Contains access_token keyword", + body: `{"access_token": "12345"}`, + want: true, + }, + { + name: "Contains quoted access_token keyword", + body: `whatever "access_token" blah`, + want: true, + }, + { + name: "Contains unquoted access_token keyword", + body: `whatever access_token blah`, + want: false, + }, + { + name: "Does not contain credentials", + body: `{"key": "value"}`, + want: false, + }, + { + name: "Empty body", + body: ``, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := containsCredentials(tt.body); got != tt.want { + t.Errorf("containsCredentials() = %v, want %v", got, tt.want) + } + }) + } +}