Skip to content

Commit

Permalink
misc improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed Oct 30, 2023
1 parent fcc9ffa commit 55613f7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 13 deletions.
31 changes: 18 additions & 13 deletions reader/readmax.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"bytes"
"errors"
"io"
"syscall"
)

const (
// although this is more than enough for most cases
MaxReadSize = 1 << 23 // 8MB
BuffAllocSize = 1 << 12 // 4KB
MaxReadSize = 1 << 23 // 8MB
)

var (
Expand All @@ -30,19 +30,24 @@ func ConnReadN(reader io.Reader, N int64) ([]byte, error) {
} else if N > MaxReadSize {
return nil, ErrTooLarge
}
// no need to allocate slice upfront if N is signficantly larger than BuffAllocSize
// since we use this to read from network connection lets read 4KB at a time
// net.Conn has 2KB and 4KB variants while reading http request from net.Conn
allocSize := N
if N > BuffAllocSize {
allocSize = BuffAllocSize
}

buff := bytes.NewBuffer(make([]byte, 0, allocSize))
var buff bytes.Buffer
// read N bytes or until EOF
_, err := io.CopyN(buff, reader, N)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
_, err := io.CopyN(&buff, io.LimitReader(reader, N), N)
if err != nil && !IsAcceptedError(err) {
return nil, err
}
return buff.Bytes(), nil
}

// IsAcceptedError checks if the error is accepted error
// for example: timeout, connection refused, io.EOF, io.ErrUnexpectedEOF
// while reading from connection
func IsAcceptedError(err error) bool {
if err == io.EOF || err == io.ErrUnexpectedEOF {
return true
}
if errors.Is(err, syscall.ECONNREFUSED) {
return true
}
return false
}
15 changes: 15 additions & 0 deletions reader/readmax_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package reader

import (
"bytes"
"crypto/tls"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestConnReadN(t *testing.T) {
Expand Down Expand Up @@ -47,4 +51,15 @@ func TestConnReadN(t *testing.T) {
t.Errorf("Expected 'Hello', got '%s'", string(data))
}
})
t.Run("Read From Connection", func(t *testing.T) {
conn, err := tls.Dial("tcp", "projectdiscovery.io:443", &tls.Config{InsecureSkipVerify: true})
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
require.Nil(t, err, "could not connect to projectdiscovery.io over tls")
defer conn.Close()
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: projectdiscovery.io\r\nConnection: close\r\n\r\n"))
require.Nil(t, err, "could not write to connection")
data, err := ConnReadN(conn, -1)
require.Nilf(t, err, "could not read from connection: %s", err)
require.NotEmpty(t, data, "could not read from connection")
})
}

0 comments on commit 55613f7

Please sign in to comment.