Skip to content

Commit

Permalink
feat(cookies): add cookie support
Browse files Browse the repository at this point in the history
Signed-off-by: Felipe Zipitria <[email protected]>
  • Loading branch information
fzipi committed Jun 13, 2021
1 parent deacd50 commit a7ffb4f
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 141 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ require (
github.com/spf13/cobra v1.1.3
github.com/yargevad/filepathx v1.0.0
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 // indirect
)
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
Expand Down
79 changes: 79 additions & 0 deletions http/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package http

import (
"crypto/tls"
"fmt"
"net"
"net/http/cookiejar"
"strings"
"time"

"github.com/rs/zerolog/log"
"golang.org/x/net/publicsuffix"
)

// NewClient initializes the http client, creating the cookiejar
func NewClient() *Client {
// All users of cookiejar should import "golang.org/x/net/publicsuffix"
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
log.Fatal().Err(err)
}
c := &Client{
Jar: jar,
// default Timeout
Timeout: 3 * time.Second,
}
return c
}

// NewConnection creates a new Connection based on a Destination
func (c *Client) NewConnection(d Destination) error {
var err error
var netConn net.Conn

hostPort := fmt.Sprintf("%s:%d", d.DestAddr, d.Port)

// Fatal error: dial tcp 127.0.0.1:80: connect: connection refused
// strings.HasSuffix(err.String(), "connection refused") {
if strings.ToLower(d.Protocol) == "https" {
// Commenting InsecureSkipVerify: true.
netConn, err = tls.DialWithDialer(&net.Dialer{Timeout: c.Timeout}, "tcp", hostPort, &tls.Config{})
} else {
netConn, err = net.DialTimeout("tcp", hostPort, c.Timeout)
}

if err == nil {
c.Transport = &Connection{
connection: netConn,
protocol: d.Protocol,
duration: NewRoundTripTime(),
}
}

return err
}

// Do performs the http request roundtrip
func (c *Client) Do(req Request) (*Response, error) {
var response *Response

err := c.Transport.Request(&req)

if err != nil {
log.Error().Msgf("http/client: error sending request: %s\n", err.Error())
} else {
response, err = c.Transport.Response()
if err != nil {
log.Debug().Msgf("ftw/run: error receiving response: %s\n", err.Error())
// This error might be expected. Let's continue
}
}

return response, err
}

// GetRoundTripTime returns the time taken from the initial send till receiving the full response
func (c *Client) GetRoundTripTime() *RoundTripTime {
return c.Transport.GetRoundTripTime()
}
70 changes: 7 additions & 63 deletions http/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,16 @@
package http

import (
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"net"
"net/url"
"strconv"
"strings"
"time"

"github.com/rs/zerolog/log"
)

// NewConnection creates a new Connection based on a Destination
func NewConnection(d Destination) (*Connection, error) {
var netConn net.Conn
var tlsConn *tls.Conn
var err error
var timeout time.Duration

hostPort := fmt.Sprintf("%s:%d", d.DestAddr, d.Port)
timeout = 3 * time.Second

// Fatal error: dial tcp 127.0.0.1:80: connect: connection refused
// strings.HasSuffix(err.String(), "connection refused") {
if strings.ToLower(d.Protocol) == "https" {
// Commenting InsecureSkipVerify: true.
tlsConn, err = tls.Dial("tcp", hostPort, &tls.Config{})
} else {
netConn, err = net.DialTimeout("tcp", hostPort, timeout)
}
c := &Connection{
netConn: netConn,
tlsConn: tlsConn,
protocol: d.Protocol,
duration: NewRoundTripTime(),
}

return c, err
}

// DestinationFromString create a Destination from String
func DestinationFromString(urlString string) *Destination {
u, _ := url.Parse(urlString)
Expand Down Expand Up @@ -79,21 +48,10 @@ func (c *Connection) send(data []byte) (int, error) {
// Store times for searching in logs, if necessary
c.startTracking()

switch c.protocol {
case "http":
if c.netConn != nil {
sent, err = c.netConn.Write(data)
} else {
err = errors.New("ftw/http: http selected but not connected to http")
}
case "https":
if c.tlsConn != nil {
sent, err = c.tlsConn.Write(data)
} else {
err = errors.New("ftw/http: https selected but not connected to https")
}
default:
err = fmt.Errorf("ftw/http: unsupported protocol %s", c.protocol)
if c.connection != nil {
sent, err = c.connection.Write(data)
} else {
err = errors.New("ftw/http/send: not connected to server")
}

return sent, err
Expand All @@ -111,18 +69,10 @@ func (c *Connection) receive() ([]byte, error) {

// We assume the response body can be handled in memory without problems
// That's why we use ioutil.ReadAll
switch c.protocol {
case "https":
defer c.tlsConn.Close()
if err = c.tlsConn.SetReadDeadline(time.Now().Add(timeoutDuration)); err == nil {
buf, err = ioutil.ReadAll(c.tlsConn)
}
default:
defer c.netConn.Close()
if err = c.netConn.SetReadDeadline(time.Now().Add(timeoutDuration)); err == nil {
buf, err = ioutil.ReadAll(c.netConn)
}
if err = c.connection.SetReadDeadline(time.Now().Add(timeoutDuration)); err == nil {
buf, err = ioutil.ReadAll(c.connection)
}

if neterr, ok := err.(net.Error); ok && !neterr.Timeout() {
log.Error().Msgf("ftw/http: %s\n", err.Error())
} else {
Expand All @@ -133,9 +83,3 @@ func (c *Connection) receive() ([]byte, error) {

return buf, err
}

// All users of cookiejar should import "golang.org/x/net/publicsuffix"
// jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
// if err != nil {
// log.Fatal(err)
// }
4 changes: 2 additions & 2 deletions http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (r *Request) AddStandardHeaders(size int) {
}

// Request will use all the inputs and send a raw http request to the destination
func (c *Connection) Request(request *Request) (*Connection, error) {
func (c *Connection) Request(request *Request) error {
// Build request first, then connect and send, so timers are accurate
data, err := buildRequest(request)
if err != nil {
Expand All @@ -148,7 +148,7 @@ func (c *Connection) Request(request *Request) (*Connection, error) {
log.Error().Msgf("ftw/http: error writing data: %s", err.Error())
}

return c, err
return err
}

// isRaw is a helper that returns true if raw or encoded data
Expand Down
88 changes: 81 additions & 7 deletions http/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,50 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
)

func generateRequestForTesting() *Request {
func generateRequestForTesting(keepalive bool) *Request {
var req *Request
var connection string

rl := &RequestLine{
Method: "GET",
URI: "/",
Version: "HTTP/1.1",
}

h := Header{"Host": "localhost", "User-Agent": "Go Tests"}
if keepalive {
connection = "keep-alive"
} else {
connection = "close"
}
h := Header{
"Host": "localhost",
"User-Agent": "Go Tests",
"Connection": connection,
}

req = NewRequest(rl, h, nil, true)

return req
}

func generateRequestWithCookiesForTesting() *Request {
var req *Request

rl := &RequestLine{
Method: "GET",
URI: "/",
Version: "HTTP/1.1",
}

h := Header{
"Host": "localhost",
"User-Agent": "Go Tests",
"Cookie": "THISISACOOKIE",
"Connection": "Keep-Alive",
}

req = NewRequest(rl, h, nil, true)

Expand All @@ -33,34 +65,76 @@ func testServer() (server *httptest.Server) {
return ts
}

func testServerWithCookies() (server *httptest.Server) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expiration := time.Now().Add(365 * 24 * time.Hour)
cookie := http.Cookie{Name: "username", Value: "go-ftw", Expires: expiration}
http.SetCookie(w, &cookie)
fmt.Fprintln(w, "Setting Cookies!")
}))

return ts
}

func TestResponse(t *testing.T) {
server := testServer()

defer server.Close()

d := DestinationFromString(server.URL)

req := generateRequestForTesting()
req := generateRequestForTesting(true)

client, err := NewConnection(*d)
client := NewClient()
err := client.NewConnection(*d)

if err != nil {
t.Fatalf("Error! %s", err.Error())
}
client, err = client.Request(req)

response, err := client.Do(*req)

if err != nil {
t.Logf("Failed !")
}

response, err := client.Response()
if response.GetBodyAsString() != "Hello, client\n" {
t.Errorf("Error!")
}

}

func TestResponseWithCookies(t *testing.T) {
server := testServerWithCookies()

defer server.Close()

d := DestinationFromString(server.URL)

req := generateRequestForTesting(true)

client := NewClient()
err := client.NewConnection(*d)

if err != nil {
t.Fatalf("Error! %s", err.Error())
}

response, err := client.Do(*req)

if err != nil {
t.Logf("Failed !")
}

if response.GetBodyAsString() != "Hello, client\n" {
if response.GetBodyAsString() != "Setting Cookies!\n" {
t.Errorf("Error!")
}

cookiereq := generateRequestWithCookiesForTesting()

_, err = client.Do(*cookiereq)

if err != nil {
t.Logf("Failed !")
}
}
15 changes: 10 additions & 5 deletions http/types.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
package http

import (
"crypto/tls"
"net"
"net/http"
"time"
)

// Client is the top level abstraction in http
type Client struct {
Transport *Connection
Jar http.CookieJar
Timeout time.Duration
}

// Connection is the type used for sending/receiving data
type Connection struct {
netConn net.Conn
tlsConn *tls.Conn
protocol string
duration *RoundTripTime
connection net.Conn
protocol string
duration *RoundTripTime
}

// RoundTripTime abstracts the time a transaction takes
Expand Down
Loading

0 comments on commit a7ffb4f

Please sign in to comment.