diff --git a/writer.go b/writer.go index 72e8e3a1..0275e738 100644 --- a/writer.go +++ b/writer.go @@ -4,6 +4,7 @@ import ( "bufio" "io" "runtime" + "strings" ) // Writer at INFO level. See WriterLevel for details. @@ -54,14 +55,37 @@ func (entry *Entry) WriterLevel(level Level) *io.PipeWriter { return writer } +// writerScanner scans the input from the reader and writes it to the logger func (entry *Entry) writerScanner(reader *io.PipeReader, printFunc func(args ...interface{})) { scanner := bufio.NewScanner(reader) + + // Set the buffer size to the maximum token size to avoid buffer overflows + scanner.Buffer(make([]byte, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize) + + // Define a split function to split the input into chunks of up to 64KB + chunkSize := bufio.MaxScanTokenSize // 64KB + splitFunc := func(data []byte, atEOF bool) (int, []byte, error) { + if len(data) >= chunkSize { + return chunkSize, data[:chunkSize], nil + } + + return bufio.ScanLines(data, atEOF) + } + + // Use the custom split function to split the input + scanner.Split(splitFunc) + + // Scan the input and write it to the logger using the specified print function for scanner.Scan() { - printFunc(scanner.Text()) + printFunc(strings.TrimRight(scanner.Text(), "\r\n")) } + + // If there was an error while scanning the input, log an error if err := scanner.Err(); err != nil { entry.Errorf("Error while reading from Writer: %s", err) } + + // Close the reader when we are done reader.Close() } diff --git a/writer_test.go b/writer_test.go index 5c34927d..5b6261bd 100644 --- a/writer_test.go +++ b/writer_test.go @@ -1,10 +1,16 @@ package logrus_test import ( + "bufio" + "bytes" "log" "net/http" + "strings" + "testing" + "time" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" ) func ExampleLogger_Writer_httpServer() { @@ -32,3 +38,61 @@ func ExampleLogger_Writer_stdlib() { // Not logrus imported under the name `log`. log.SetOutput(logger.Writer()) } + +func TestWriterSplitNewlines(t *testing.T) { + buf := bytes.NewBuffer(nil) + logger := logrus.New() + logger.Formatter = &logrus.TextFormatter{ + DisableColors: true, + DisableTimestamp: true, + } + logger.SetOutput(buf) + writer := logger.Writer() + + const logNum = 10 + + for i := 0; i < logNum; i++ { + _, err := writer.Write([]byte("bar\nfoo\n")) + assert.NoError(t, err, "writer.Write failed") + } + writer.Close() + // Test is flaky because it writes in another goroutine, + // we need to make sure to wait a bit so all write are done. + time.Sleep(500 * time.Millisecond) + + lines := strings.Split(strings.TrimRight(buf.String(), "\n"), "\n") + assert.Len(t, lines, logNum*2, "logger printed incorrect number of lines") +} + +func TestWriterSplitsMax64KB(t *testing.T) { + buf := bytes.NewBuffer(nil) + logger := logrus.New() + logger.Formatter = &logrus.TextFormatter{ + DisableColors: true, + DisableTimestamp: true, + } + logger.SetOutput(buf) + writer := logger.Writer() + + // write more than 64KB + const bigWriteLen = bufio.MaxScanTokenSize + 100 + output := make([]byte, bigWriteLen) + // lets not write zero bytes + for i := 0; i < bigWriteLen; i++ { + output[i] = 'A' + } + + for i := 0; i < 3; i++ { + len, err := writer.Write(output) + assert.NoError(t, err, "writer.Write failed") + assert.Equal(t, bigWriteLen, len, "bytes written") + } + writer.Close() + // Test is flaky because it writes in another goroutine, + // we need to make sure to wait a bit so all write are done. + time.Sleep(500 * time.Millisecond) + + lines := strings.Split(strings.TrimRight(buf.String(), "\n"), "\n") + // we should have 4 lines because we wrote more than 64 KB each time + assert.Len(t, lines, 4, "logger printed incorrect number of lines") +}