diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 183dcf9f76..af81e1024b 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -36,6 +36,13 @@ func DNSLookupOptionLogger(value model.Logger) DNSLookupOption { } } +// DNSLookupOptionTags allows to set tags to tag observations. +func DNSLookupOptionTags(value ...string) DNSLookupOption { + return func(dis *DomainToResolve) { + dis.Tags = append(dis.Tags, value...) + } +} + // DNSLookupOptionZeroTime configures the measurement's zero time. // See DomainToResolve docs for more information. func DNSLookupOptionZeroTime(value time.Time) DNSLookupOption { @@ -52,6 +59,7 @@ func NewDomainToResolve(domain DomainName, options ...DNSLookupOption) *DomainTo Domain: string(domain), IDGenerator: &atomic.Int64{}, Logger: model.DiscardLogger, + Tags: []string{}, ZeroTime: time.Now(), } for _, option := range options { @@ -81,6 +89,9 @@ type DomainToResolve struct { // implemented by NewDomainToResolve uses model.DiscardLogger. Logger model.Logger + // Tags contains OPTIONAL tags to tag observations. + Tags []string + // ZeroTime is the MANDATORY zero time of the measurement. We will // use this field as the zero value to compute relative elapsed times // when generating measurements. The default construction by @@ -132,7 +143,7 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime) + trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) // start the operation logger ol := measurexlite.NewOperationLogger( @@ -195,7 +206,7 @@ func (f *dnsLookupUDPFunc) Apply( ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime) + trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) // start the operation logger ol := measurexlite.NewOperationLogger( diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index e65a61d6ce..2f3e292804 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -35,6 +36,7 @@ func TestNewDomainToResolve(t *testing.T) { DNSLookupOptionIDGenerator(idGen), DNSLookupOptionLogger(model.DiscardLogger), DNSLookupOptionZeroTime(zt), + DNSLookupOptionTags("antani"), ) if domainToResolve.Domain != "www.example.com" { t.Fatalf("unexpected domain") @@ -48,6 +50,9 @@ func TestNewDomainToResolve(t *testing.T) { if domainToResolve.ZeroTime != zt { t.Fatalf("unexpected zerotime") } + if diff := cmp.Diff([]string{"antani"}, domainToResolve.Tags); diff != "" { + t.Fatal(diff) + } }) }) } @@ -73,13 +78,14 @@ func TestGetaddrinfo(t *testing.T) { Domain: "example.com", Logger: model.DiscardLogger, IDGenerator: &atomic.Int64{}, + Tags: []string{"antani"}, ZeroTime: time.Time{}, } t.Run("with nil resolver", func(t *testing.T) { f := dnsLookupGetaddrinfoFunc{} ctx, cancel := context.WithCancel(context.Background()) - cancel() + cancel() // immediately cancel the lookup res := f.Apply(ctx, domain) if res.Observations == nil || len(res.Observations) <= 0 { t.Fatal("unexpected empty observations") @@ -130,6 +136,9 @@ func TestGetaddrinfo(t *testing.T) { if len(res.State.Addresses) != 1 || res.State.Addresses[0] != "93.184.216.34" { t.Fatal("unexpected addresses") } + if diff := cmp.Diff([]string{"antani"}, res.State.Trace.Tags()); diff != "" { + t.Fatal(diff) + } }) }) } @@ -155,6 +164,7 @@ func TestLookupUDP(t *testing.T) { Domain: "example.com", Logger: model.DiscardLogger, IDGenerator: &atomic.Int64{}, + Tags: []string{"antani"}, ZeroTime: time.Time{}, } @@ -214,6 +224,9 @@ func TestLookupUDP(t *testing.T) { if len(res.State.Addresses) != 1 || res.State.Addresses[0] != "93.184.216.34" { t.Fatal("unexpected addresses") } + if diff := cmp.Diff([]string{"antani"}, res.State.Trace.Tags()); diff != "" { + t.Fatal(diff) + } }) }) } diff --git a/internal/dslx/endpoint.go b/internal/dslx/endpoint.go index 51e5fa0f4b..cd17d68c88 100644 --- a/internal/dslx/endpoint.go +++ b/internal/dslx/endpoint.go @@ -38,6 +38,9 @@ type Endpoint struct { // Network is the MANDATORY endpoint network. Network string + // Tags contains OPTIONAL tags for tagging observations. + Tags []string + // ZeroTime is the MANDATORY zero time of the measurement. ZeroTime time.Time } @@ -66,6 +69,13 @@ func EndpointOptionLogger(value model.Logger) EndpointOption { } } +// EndpointOptionTags allows to set tags to tag observations. +func EndpointOptionTags(value ...string) EndpointOption { + return func(es *Endpoint) { + es.Tags = append(es.Tags, value...) + } +} + // EndpointOptionZeroTime allows to set the zero time. func EndpointOptionZeroTime(value time.Time) EndpointOption { return func(es *Endpoint) { @@ -92,6 +102,7 @@ func NewEndpoint( IDGenerator: &atomic.Int64{}, Logger: model.DiscardLogger, Network: string(network), + Tags: []string{}, ZeroTime: time.Now(), } for _, option := range options { diff --git a/internal/dslx/endpoint_test.go b/internal/dslx/endpoint_test.go index 134c639da3..61170f1fc0 100644 --- a/internal/dslx/endpoint_test.go +++ b/internal/dslx/endpoint_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -21,6 +22,7 @@ func TestEndpoint(t *testing.T) { EndpointOptionIDGenerator(idGen), EndpointOptionLogger(model.DiscardLogger), EndpointOptionZeroTime(zt), + EndpointOptionTags("antani"), ) if testEndpoint.Network != "network" { t.Fatalf("unexpected network") @@ -40,5 +42,8 @@ func TestEndpoint(t *testing.T) { if testEndpoint.ZeroTime != zt { t.Fatalf("unexpected zero time") } + if diff := cmp.Diff([]string{"antani"}, testEndpoint.Tags); diff != "" { + t.Fatal(diff) + } }) } diff --git a/internal/dslx/http_test.go b/internal/dslx/http_test.go index 73ed5b66fd..17d2419de9 100644 --- a/internal/dslx/http_test.go +++ b/internal/dslx/http_test.go @@ -2,6 +2,8 @@ package dslx import ( "context" + "errors" + "fmt" "io" "net/http" "strings" @@ -9,6 +11,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" @@ -89,7 +92,7 @@ func TestHTTPRequest(t *testing.T) { } idGen := &atomic.Int64{} zeroTime := time.Time{} - trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) + trace := measurexlite.NewTrace(idGen.Add(1), zeroTime, "antani") t.Run("with EOF", func(t *testing.T) { httpTransport := HTTPTransport{ @@ -159,6 +162,35 @@ func TestHTTPRequest(t *testing.T) { } }) + // makeSureObservationsContainTags ensures the observations you can extract from + // the given HTTPResponse contain the tags we configured when testing + makeSureObservationsContainTags := func(res *Maybe[*HTTPResponse]) error { + // exclude the case where there was an error + if res.Error != nil { + return fmt.Errorf("unexpected error: %w", res.Error) + } + + // obtain the observations + for _, obs := range ExtractObservations(res) { + + // check the network events + for _, ev := range obs.NetworkEvents { + if diff := cmp.Diff([]string{"antani"}, ev.Tags); diff != "" { + return errors.New(diff) + } + } + + // check the HTTP events + for _, ev := range obs.Requests { + if diff := cmp.Diff([]string{"antani"}, ev.Tags); diff != "" { + return errors.New(diff) + } + } + } + + return nil + } + t.Run("with success (https)", func(t *testing.T) { httpTransport := HTTPTransport{ Address: "1.2.3.4:443", @@ -178,6 +210,7 @@ func TestHTTPRequest(t *testing.T) { if res.State.HTTPResponse == nil || res.State.HTTPResponse.Status != "expected" { t.Fatal("unexpected request") } + makeSureObservationsContainTags(res) }) t.Run("with success (http)", func(t *testing.T) { @@ -199,6 +232,7 @@ func TestHTTPRequest(t *testing.T) { if res.State.HTTPResponse == nil || res.State.HTTPResponse.Status != "expected" { t.Fatal("unexpected request") } + makeSureObservationsContainTags(res) }) t.Run("with header options", func(t *testing.T) { @@ -239,7 +273,6 @@ func TestHTTPRequest(t *testing.T) { t.Fatal("unexpected URL path", res.State.HTTPRequest.URL.Path) } }) - }) } diff --git a/internal/dslx/httpcore.go b/internal/dslx/httpcore.go index d8fa1021e8..fa6f49e0a5 100644 --- a/internal/dslx/httpcore.go +++ b/internal/dslx/httpcore.go @@ -287,15 +287,19 @@ func (f *httpRequestFunc) do( ) (*http.Response, []byte, []*Observations, error) { const maxbody = 1 << 19 // TODO(bassosimone): allow to configure this value? started := input.Trace.TimeSince(input.Trace.ZeroTime) + + // manually create a single 1-length observations structure because + // the trace cannot automatically capture HTTP events observations := []*Observations{ NewObservations(), - } // one entry + } observations[0].NetworkEvents = append(observations[0].NetworkEvents, measurexlite.NewAnnotationArchivalNetworkEvent( input.Trace.Index, started, "http_transaction_start", + input.Trace.Tags()..., )) resp, err := input.Transport.RoundTrip(req) @@ -312,6 +316,7 @@ func (f *httpRequestFunc) do( input.Trace.Index, finished, "http_transaction_done", + input.Trace.Tags()..., )) observations[0].Requests = append(observations[0].Requests, @@ -328,6 +333,7 @@ func (f *httpRequestFunc) do( body, err, finished, + input.Trace.Tags()..., )) return resp, body, observations, err diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index d7a5fa1a33..be74af55f0 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -82,7 +82,7 @@ type quicHandshakeFunc struct { func (f *quicHandshakeFunc) Apply( ctx context.Context, input *Endpoint) *Maybe[*QUICConnection] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime) + trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) // use defaults or user-configured overrides serverName := f.serverName(input) diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index d20072cca4..44fec90bb0 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" @@ -67,37 +68,70 @@ func TestQUICHandshake(t *testing.T) { tests := map[string]struct { dialer model.QUICDialer sni string + tags []string expectConn quic.EarlyConnection expectErr error closed bool }{ - "with EOF": {expectConn: nil, expectErr: io.EOF, closed: false, dialer: eofDialer}, - "success": {expectConn: plainConn, expectErr: nil, closed: true, dialer: goodDialer}, - "with sni": {expectConn: plainConn, expectErr: nil, closed: true, dialer: goodDialer, sni: "sni.com"}, + "with EOF": { + tags: []string{}, + expectConn: nil, + expectErr: io.EOF, + closed: false, + dialer: eofDialer, + }, + "success": { + tags: []string{"antani"}, + expectConn: plainConn, + expectErr: nil, + closed: true, + dialer: goodDialer, + }, + "with sni": { + tags: []string{}, + expectConn: plainConn, + expectErr: nil, + closed: true, + dialer: goodDialer, + sni: "sni.com", + }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { pool := &ConnPool{} - quicHandshake := &quicHandshakeFunc{Pool: pool, dialer: tt.dialer, ServerName: tt.sni} + quicHandshake := &quicHandshakeFunc{ + Pool: pool, + dialer: tt.dialer, + ServerName: tt.sni, + } endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "udp", IDGenerator: &atomic.Int64{}, Logger: model.DiscardLogger, + Tags: tt.tags, ZeroTime: time.Time{}, } res := quicHandshake.Apply(context.Background(), endpoint) if res.Error != tt.expectErr { t.Fatalf("unexpected error: %s", res.Error) } - if res.State.QUICConn != tt.expectConn { - t.Fatalf("unexpected conn: %s", res.State.QUICConn) + if res.State == nil || res.State.QUICConn != tt.expectConn { + t.Fatal("unexpected conn") } pool.Close() if wasClosed != tt.closed { t.Fatalf("unexpected connection closed state: %v", wasClosed) } + if len(tt.tags) > 0 { + if res.State == nil { + t.Fatal("expected non-nil res.State") + } + if diff := cmp.Diff([]string{"antani"}, res.State.Trace.Tags()); diff != "" { + t.Fatal(diff) + } + } }) wasClosed = false } @@ -145,6 +179,7 @@ func TestServerNameQUIC(t *testing.T) { t.Fatalf("unexpected server name: %s", serverName) } }) + t.Run("With input domain", func(t *testing.T) { domain := "domain" endpoint := &Endpoint{ @@ -158,6 +193,7 @@ func TestServerNameQUIC(t *testing.T) { t.Fatalf("unexpected server name: %s", serverName) } }) + t.Run("With input host address", func(t *testing.T) { hostaddr := "example.com" endpoint := &Endpoint{ @@ -170,6 +206,7 @@ func TestServerNameQUIC(t *testing.T) { t.Fatalf("unexpected server name: %s", serverName) } }) + t.Run("With input IP address", func(t *testing.T) { ip := "1.1.1.1" endpoint := &Endpoint{ diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index 97c08096a7..d8e7b6652c 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -32,7 +32,7 @@ func (f *tcpConnectFunc) Apply( ctx context.Context, input *Endpoint) *Maybe[*TCPConnection] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime) + trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) // start the operation logger ol := measurexlite.NewOperationLogger( @@ -47,10 +47,8 @@ func (f *tcpConnectFunc) Apply( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - dialer := f.dialer - if dialer == nil { - dialer = trace.NewDialerWithoutResolver(input.Logger) - } + // obtain the dialer to use + dialer := f.dialerOrDefault(trace, input.Logger) // connect conn, err := dialer.DialContext(ctx, "tcp", input.Address) @@ -80,6 +78,15 @@ func (f *tcpConnectFunc) Apply( } } +// dialerOrDefault is the function used to obtain a dialer +func (f *tcpConnectFunc) dialerOrDefault(trace *measurexlite.Trace, logger model.Logger) model.Dialer { + dialer := f.dialer + if dialer == nil { + dialer = trace.NewDialerWithoutResolver(logger) + } + return dialer +} + // TCPConnection is an established TCP connection. If you initialize // manually, init at least the ones marked as MANDATORY. type TCPConnection struct { diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 088e04b2f7..6a94ea35c5 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -43,13 +45,26 @@ func TestTCPConnect(t *testing.T) { } tests := map[string]struct { + tags []string dialer model.Dialer expectConn net.Conn expectErr error closed bool }{ - "with EOF": {expectConn: nil, expectErr: io.EOF, closed: false, dialer: eofDialer}, - "success": {expectConn: plainConn, expectErr: nil, closed: true, dialer: goodDialer}, + "with EOF": { + tags: []string{}, + expectConn: nil, + expectErr: io.EOF, + closed: false, + dialer: eofDialer, + }, + "success": { + tags: []string{"antani"}, + expectConn: plainConn, + expectErr: nil, + closed: true, + dialer: goodDialer, + }, } for name, tt := range tests { @@ -61,21 +76,42 @@ func TestTCPConnect(t *testing.T) { Network: "tcp", IDGenerator: &atomic.Int64{}, Logger: model.DiscardLogger, + Tags: tt.tags, ZeroTime: time.Time{}, } res := tcpConnect.Apply(context.Background(), endpoint) if res.Error != tt.expectErr { t.Fatalf("unexpected error: %s", res.Error) } - if res.State.Conn != tt.expectConn { - t.Fatalf("unexpected conn: %s", res.State.Conn) + if res.State == nil || res.State.Conn != tt.expectConn { + t.Fatal("unexpected conn") } pool.Close() if wasClosed != tt.closed { t.Fatalf("unexpected connection closed state: %v", wasClosed) } + if len(tt.tags) > 0 { + if res.State == nil { + t.Fatal("expected non-nil res.State") + } + if diff := cmp.Diff([]string{"antani"}, res.State.Trace.Tags()); diff != "" { + t.Fatal(diff) + } + } }) wasClosed = false } }) } + +// Make sure we get a valid dialer if no mocked dialer is configured +func TestDialerOrDefault(t *testing.T) { + f := &tcpConnectFunc{ + p: &ConnPool{}, + dialer: nil, + } + dialer := f.dialerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) + if dialer == nil { + t.Fatal("expected non-nil dialer here") + } +} diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index 3f7bcb73a9..f74a51cad1 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -108,11 +108,10 @@ func (f *tlsHandshakeFunc) Apply( nextProto, ) + // obtain the handshaker for use + handshaker := f.handshakerOrDefault(trace, input.Logger) + // setup - handshaker := f.handshaker - if handshaker == nil { - handshaker = trace.NewTLSHandshakerStdlib(input.Logger) - } config := &tls.Config{ NextProtos: nextProto, InsecureSkipVerify: f.InsecureSkipVerify, @@ -157,6 +156,15 @@ func (f *tlsHandshakeFunc) Apply( } } +// handshakerOrDefault is the function used to obtain an handshaker +func (f *tlsHandshakeFunc) handshakerOrDefault(trace *measurexlite.Trace, logger model.Logger) model.TLSHandshaker { + handshaker := f.handshaker + if handshaker == nil { + handshaker = trace.NewTLSHandshakerStdlib(logger) + } + return handshaker +} + func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { if f.ServerName != "" { return f.ServerName diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index e8282ec542..652b688e0b 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -235,3 +235,19 @@ func TestServerNameTLS(t *testing.T) { } }) } + +// Make sure we get a valid handshaker if no mocked handshaker is configured +func TestHandshakerOrDefault(t *testing.T) { + f := &tlsHandshakeFunc{ + InsecureSkipVerify: false, + NextProto: []string{}, + Pool: &ConnPool{}, + RootCAs: &x509.CertPool{}, + ServerName: "", + handshaker: nil, + } + handshaker := f.handshakerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) + if handshaker == nil { + t.Fatal("expected non-nil handshaker here") + } +} diff --git a/internal/measurexlite/conn.go b/internal/measurexlite/conn.go index 066086f335..57ec71caab 100644 --- a/internal/measurexlite/conn.go +++ b/internal/measurexlite/conn.go @@ -44,13 +44,17 @@ func (c *connTrace) Read(b []byte) (int, error) { network := c.RemoteAddr().Network() addr := c.RemoteAddr().String() started := c.tx.TimeSince(c.tx.ZeroTime) + count, err := c.Conn.Read(b) + finished := c.tx.TimeSince(c.tx.ZeroTime) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.ReadOperation, network, addr, count, err, finished): + c.tx.Index, started, netxlite.ReadOperation, network, addr, count, + err, finished, c.tx.tags...): default: // buffer is full } + return count, err } @@ -59,13 +63,17 @@ func (c *connTrace) Write(b []byte) (int, error) { network := c.RemoteAddr().Network() addr := c.RemoteAddr().String() started := c.tx.TimeSince(c.tx.ZeroTime) + count, err := c.Conn.Write(b) + finished := c.tx.TimeSince(c.tx.ZeroTime) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.WriteOperation, network, addr, count, err, finished): + c.tx.Index, started, netxlite.WriteOperation, network, addr, count, + err, finished, c.tx.tags...): default: // buffer is full } + return count, err } @@ -96,14 +104,18 @@ type udpLikeConnTrace struct { // Read implements model.UDPLikeConn.ReadFrom and saves network events. func (c *udpLikeConnTrace) ReadFrom(b []byte) (int, net.Addr, error) { started := c.tx.TimeSince(c.tx.ZeroTime) + count, addr, err := c.UDPLikeConn.ReadFrom(b) + finished := c.tx.TimeSince(c.tx.ZeroTime) address := addrStringIfNotNil(addr) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.ReadFromOperation, "udp", address, count, err, finished): + c.tx.Index, started, netxlite.ReadFromOperation, "udp", address, count, + err, finished, c.tx.tags...): default: // buffer is full } + return count, addr, err } @@ -111,13 +123,17 @@ func (c *udpLikeConnTrace) ReadFrom(b []byte) (int, net.Addr, error) { func (c *udpLikeConnTrace) WriteTo(b []byte, addr net.Addr) (int, error) { started := c.tx.TimeSince(c.tx.ZeroTime) address := addr.String() + count, err := c.UDPLikeConn.WriteTo(b, addr) + finished := c.tx.TimeSince(c.tx.ZeroTime) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.WriteToOperation, "udp", address, count, err, finished): + c.tx.Index, started, netxlite.WriteToOperation, "udp", address, count, + err, finished, c.tx.tags...): default: // buffer is full } + return count, err } @@ -131,8 +147,9 @@ func addrStringIfNotNil(addr net.Addr) (out string) { } // NewArchivalNetworkEvent creates a new model.ArchivalNetworkEvent. -func NewArchivalNetworkEvent(index int64, started time.Duration, operation string, network string, - address string, count int, err error, finished time.Duration) *model.ArchivalNetworkEvent { +func NewArchivalNetworkEvent(index int64, started time.Duration, operation string, + network string, address string, count int, err error, finished time.Duration, + tags ...string) *model.ArchivalNetworkEvent { return &model.ArchivalNetworkEvent{ Address: address, Failure: tracex.NewFailure(err), @@ -142,15 +159,15 @@ func NewArchivalNetworkEvent(index int64, started time.Duration, operation strin T0: started.Seconds(), T: finished.Seconds(), TransactionID: index, - Tags: []string{}, + Tags: copyAndNormalizeTags(tags), } } // NewAnnotationArchivalNetworkEvent is a simplified NewArchivalNetworkEvent // where we create a simple annotation without attached I/O info. func NewAnnotationArchivalNetworkEvent( - index int64, time time.Duration, operation string) *model.ArchivalNetworkEvent { - return NewArchivalNetworkEvent(index, time, operation, "", "", 0, nil, time) + index int64, time time.Duration, operation string, tags ...string) *model.ArchivalNetworkEvent { + return NewArchivalNetworkEvent(index, time, operation, "", "", 0, nil, time, tags...) } // NetworkEvents drains the network events buffered inside the NetworkEvent channel. @@ -174,3 +191,12 @@ func (tx *Trace) FirstNetworkEventOrNil() *model.ArchivalNetworkEvent { } return ev[0] } + +// copyAndNormalizeTags ensures that we map nil tags to []string +// and that we return a copy of the tags. +func copyAndNormalizeTags(tags []string) []string { + if len(tags) <= 0 { + tags = []string{} + } + return append([]string{}, tags...) +} diff --git a/internal/measurexlite/conn_test.go b/internal/measurexlite/conn_test.go index 56dc01d504..ae6e195eb3 100644 --- a/internal/measurexlite/conn_test.go +++ b/internal/measurexlite/conn_test.go @@ -35,6 +35,29 @@ func TestMaybeClose(t *testing.T) { }) } +func TestMaybeCloseUDPLikeConn(t *testing.T) { + t.Run("with nil conn", func(t *testing.T) { + var conn model.UDPLikeConn = nil + MaybeCloseUDPLikeConn(conn) + }) + + t.Run("with nonnil conn", func(t *testing.T) { + var called bool + conn := &mocks.UDPLikeConn{ + MockClose: func() error { + called = true + return nil + }, + } + if err := MaybeCloseUDPLikeConn(conn); err != nil { + t.Fatal(err) + } + if !called { + t.Fatal("not called") + } + }) +} + func TestWrapNetConn(t *testing.T) { t.Run("WrapNetConn wraps the conn", func(t *testing.T) { underlying := &mocks.Conn{} @@ -68,8 +91,8 @@ func TestWrapNetConn(t *testing.T) { } zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now // deterministic time counting + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now // deterministic time counting conn := trace.MaybeWrapNetConn(underlying) const bufsiz = 128 buffer := make([]byte, bufsiz) @@ -91,7 +114,7 @@ func TestWrapNetConn(t *testing.T) { Operation: netxlite.ReadOperation, Proto: "tcp", T: 1.0, - Tags: []string{}, + Tags: []string{"antani"}, } got := events[0] if diff := cmp.Diff(expect, got); diff != "" { @@ -152,8 +175,8 @@ func TestWrapNetConn(t *testing.T) { } zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now // deterministic time tracking + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now // deterministic time tracking conn := trace.MaybeWrapNetConn(underlying) const bufsiz = 128 buffer := make([]byte, bufsiz) @@ -175,7 +198,7 @@ func TestWrapNetConn(t *testing.T) { Operation: netxlite.WriteOperation, Proto: "tcp", T: 1.0, - Tags: []string{}, + Tags: []string{"antani"}, } got := events[0] if diff := cmp.Diff(expect, got); diff != "" { @@ -246,8 +269,8 @@ func TestWrapUDPLikeConn(t *testing.T) { } zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now // deterministic time counting + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now // deterministic time counting conn := trace.MaybeWrapUDPLikeConn(underlying) const bufsiz = 128 buffer := make([]byte, bufsiz) @@ -272,7 +295,7 @@ func TestWrapUDPLikeConn(t *testing.T) { Operation: "read_from", Proto: "udp", T: 1.0, - Tags: []string{}, + Tags: []string{"antani"}, } got := events[0] if diff := cmp.Diff(expect, got); diff != "" { @@ -320,8 +343,8 @@ func TestWrapUDPLikeConn(t *testing.T) { } zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now // deterministic time tracking + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now // deterministic time tracking conn := trace.MaybeWrapUDPLikeConn(underlying) const bufsiz = 128 buffer := make([]byte, bufsiz) @@ -348,7 +371,7 @@ func TestWrapUDPLikeConn(t *testing.T) { Operation: "write_to", Proto: "udp", T: 1.0, - Tags: []string{}, + Tags: []string{"antani"}, } got := events[0] if diff := cmp.Diff(expect, got); diff != "" { @@ -445,10 +468,10 @@ func TestNewAnnotationArchivalNetworkEvent(t *testing.T) { T0: duration.Seconds(), T: duration.Seconds(), TransactionID: index, - Tags: []string{}, + Tags: []string{"antani"}, } got := NewAnnotationArchivalNetworkEvent( - index, duration, operation, + index, duration, operation, "antani", ) if diff := cmp.Diff(expect, got); diff != "" { t.Fatal(diff) diff --git a/internal/measurexlite/dialer.go b/internal/measurexlite/dialer.go index bfd67e17e4..0a785fb797 100644 --- a/internal/measurexlite/dialer.go +++ b/internal/measurexlite/dialer.go @@ -22,7 +22,7 @@ import ( // // Note: unlike code in netx or measurex, this factory DOES NOT return you a // dialer that also performs wrapping of a net.Conn in case of success. If you -// want to wrap the conn, you need to wrap it explicitly using WrapNetConn. +// want to wrap the conn, you need to wrap it explicitly using model.Trace.WrapNetConn. func (tx *Trace) NewDialerWithoutResolver(dl model.DebugLogger) model.Dialer { return &dialerTrace{ d: tx.newDialerWithoutResolver(dl), @@ -62,6 +62,7 @@ func (tx *Trace) OnConnectDone( remoteAddr, err, finished.Sub(tx.ZeroTime), + tx.tags..., ): default: // buffer is full } @@ -78,6 +79,7 @@ func (tx *Trace) OnConnectDone( 0, err, finished.Sub(tx.ZeroTime), + tx.tags..., ): default: // buffer is full } @@ -91,7 +93,7 @@ func (tx *Trace) OnConnectDone( // NewArchivalTCPConnectResult generates a model.ArchivalTCPConnectResult // from the available information right after connect returns. func NewArchivalTCPConnectResult(index int64, started time.Duration, address string, - err error, finished time.Duration) *model.ArchivalTCPConnectResult { + err error, finished time.Duration, tags ...string) *model.ArchivalTCPConnectResult { ip, port := archivalSplitHostPort(address) return &model.ArchivalTCPConnectResult{ IP: ip, @@ -103,6 +105,7 @@ func NewArchivalTCPConnectResult(index int64, started time.Duration, address str }, T0: started.Seconds(), T: finished.Seconds(), + Tags: copyAndNormalizeTags(tags), TransactionID: index, } } diff --git a/internal/measurexlite/dialer_test.go b/internal/measurexlite/dialer_test.go index 1e5b40a776..d1f108447a 100644 --- a/internal/measurexlite/dialer_test.go +++ b/internal/measurexlite/dialer_test.go @@ -21,7 +21,7 @@ func TestNewDialerWithoutResolver(t *testing.T) { underlying := &mocks.Dialer{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { + trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { return underlying } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) @@ -46,7 +46,7 @@ func TestNewDialerWithoutResolver(t *testing.T) { return nil, expectedErr }, } - trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { + trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { return underlying } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) @@ -72,7 +72,7 @@ func TestNewDialerWithoutResolver(t *testing.T) { called = true }, } - trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { + trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { return underlying } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) @@ -85,8 +85,8 @@ func TestNewDialerWithoutResolver(t *testing.T) { t.Run("DialContext saves into the trace", func(t *testing.T) { zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now // deterministic time tracking + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now // deterministic time tracking dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) ctx, cancel := context.WithCancel(context.Background()) cancel() // we cancel immediately so connect is ~instantaneous @@ -113,7 +113,8 @@ func TestNewDialerWithoutResolver(t *testing.T) { Failure: &expectedFailure, Success: false, }, - T: time.Second.Seconds(), + T: time.Second.Seconds(), + Tags: []string{"antani"}, } got := events[0] if diff := cmp.Diff(expect, got); diff != "" { @@ -136,7 +137,7 @@ func TestNewDialerWithoutResolver(t *testing.T) { T0: 0, T: time.Second.Seconds(), TransactionID: 0, - Tags: []string{}, + Tags: []string{"antani"}, } got := events[0] if diff := cmp.Diff(expect, got); diff != "" { @@ -209,8 +210,10 @@ func TestNewDialerWithoutResolver(t *testing.T) { zeroTime := time.Now() trace := NewTrace(0, zeroTime) dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) - cancel() // we cancel immediately so connect is ~instantaneous + cancel() // we cancel immediately so connect is ~instantaneous + conn, err := dialer.DialContext(ctx, "udp", "dns.google:443") // domain if !errors.Is(err, netxlite.ErrNoResolver) { t.Fatal("unexpected err", err) diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index 5d374859de..ddecb40cc3 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -54,6 +54,7 @@ func (r *resolverTrace) emitResolveStart() { select { case r.tx.networkEvent <- NewAnnotationArchivalNetworkEvent( r.tx.Index, r.tx.TimeSince(r.tx.ZeroTime), "resolve_start", + r.tx.tags..., ): default: // buffer is full } @@ -64,6 +65,7 @@ func (r *resolverTrace) emiteResolveDone() { select { case r.tx.networkEvent <- NewAnnotationArchivalNetworkEvent( r.tx.Index, r.tx.TimeSince(r.tx.ZeroTime), "resolve_done", + r.tx.tags..., ): default: // buffer is full } @@ -109,6 +111,7 @@ func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) { t := finished.Sub(tx.ZeroTime) + select { case tx.dnsLookup <- NewArchivalDNSLookupResultFromRoundTrip( tx.Index, @@ -119,8 +122,11 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol addrs, err, t, + tx.tags..., ): + default: + // buffer is full } } @@ -137,8 +143,9 @@ type DNSNetworkAddresser interface { // NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip // from the available information right after the DNS RoundTrip -func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso DNSNetworkAddresser, query model.DNSQuery, - response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult { +func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, + reso DNSNetworkAddresser, query model.DNSQuery, response model.DNSResponse, + addrs []string, err error, finished time.Duration, tags ...string) *model.ArchivalDNSLookupResult { return &model.ArchivalDNSLookupResult{ Answers: newArchivalDNSAnswers(addrs, response), Engine: reso.Network(), @@ -153,6 +160,7 @@ func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, ResolverAddress: reso.Address(), T0: started.Seconds(), T: finished.Seconds(), + Tags: copyAndNormalizeTags(tags), TransactionID: index, } } @@ -189,8 +197,9 @@ func newArchivalDNSAnswers(addrs []string, resp model.DNSResponse) (out []model. log.Printf("BUG: NewArchivalDNSLookupResult: invalid IP address: %s", addr) continue } - asn, org, _ := geoipx.LookupASN(addr) + asn, org, _ := geoipx.LookupASN(addr) // error if not in the DB; returns sensible values on error switch ipv6 { + case false: out = append(out, model.ArchivalDNSAnswer{ ASN: int64(asn), @@ -201,6 +210,7 @@ func newArchivalDNSAnswers(addrs []string, resp model.DNSResponse) (out []model. IPv6: "", TTL: nil, }) + case true: out = append(out, model.ArchivalDNSAnswer{ ASN: int64(asn), @@ -265,6 +275,7 @@ var ErrDelayedDNSResponseBufferFull = errors.New("buffer full") func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) error { t := finished.Sub(tx.ZeroTime) + select { case tx.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip( tx.Index, @@ -275,8 +286,10 @@ func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, addrs, err, t, + tx.tags..., ): return nil + default: return ErrDelayedDNSResponseBufferFull } @@ -291,8 +304,10 @@ func (tx *Trace) DelayedDNSResponseWithTimeout(ctx context.Context, timeout time.Duration) (out []*model.ArchivalDNSLookupResult) { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() + for { select { + case <-ctx.Done(): for { // once the context is done enter in channel draining mode select { @@ -302,6 +317,7 @@ func (tx *Trace) DelayedDNSResponseWithTimeout(ctx context.Context, return } } + case ev := <-tx.delayedDNSResponse: out = append(out, ev) } diff --git a/internal/measurexlite/dns_test.go b/internal/measurexlite/dns_test.go index 5ce2c2078f..aba71d1489 100644 --- a/internal/measurexlite/dns_test.go +++ b/internal/measurexlite/dns_test.go @@ -124,8 +124,8 @@ func TestNewResolver(t *testing.T) { t.Run("LookupHost saves into trace", func(t *testing.T) { zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now txp := &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { response := &mocks.DNSResponse{ @@ -198,6 +198,9 @@ func TestNewResolver(t *testing.T) { if ev.Answers[1].AnswerType != "CNAME " && ev.Answers[1].Hostname != "dns.google." { t.Fatal("unexpected second answer (expected CNAME)", ev.Answers[1]) } + if diff := cmp.Diff([]string{"antani"}, ev.Tags); diff != "" { + t.Fatal(diff) + } } }) @@ -209,6 +212,9 @@ func TestNewResolver(t *testing.T) { foundNames := map[string]int{} for _, ev := range events { foundNames[ev.Operation]++ + if diff := cmp.Diff([]string{"antani"}, ev.Tags); diff != "" { + t.Fatal(diff) + } } if foundNames["resolve_start"] != 1 { t.Fatal("missing resolve_start") @@ -225,7 +231,7 @@ func TestNewResolver(t *testing.T) { trace := NewTrace(0, zeroTime) trace.dnsLookup = make(chan *model.ArchivalDNSLookupResult) // no buffer trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // ditto - trace.TimeNowFn = td.Now + trace.timeNowFn = td.Now txp := &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { response := &mocks.DNSResponse{ @@ -378,8 +384,8 @@ func TestDelayedDNSResponseWithTimeout(t *testing.T) { t.Run("when buffer is not full", func(t *testing.T) { zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now txp := &mocks.DNSTransport{ MockNetwork: func() string { return "udp" @@ -420,13 +426,16 @@ func TestDelayedDNSResponseWithTimeout(t *testing.T) { if len(got) != 1 { t.Fatal("unexpected output from trace") } + if diff := cmp.Diff([]string{"antani"}, got[0].Tags); diff != "" { + t.Fatal(diff) + } }) t.Run("when buffer is full", func(t *testing.T) { zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now + trace.timeNowFn = td.Now trace.delayedDNSResponse = make(chan *model.ArchivalDNSLookupResult) // no buffer txp := &mocks.DNSTransport{ MockNetwork: func() string { @@ -475,8 +484,8 @@ func TestDelayedDNSResponseWithTimeout(t *testing.T) { t.Run("context is already cancelled and we still drain the trace", func(t *testing.T) { zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now txp := &mocks.DNSTransport{ MockNetwork: func() string { return "udp" @@ -525,8 +534,8 @@ func TestDelayedDNSResponseWithTimeout(t *testing.T) { t.Run("normal case where the context times out after we start draining", func(t *testing.T) { zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now txp := &mocks.DNSTransport{ MockNetwork: func() string { return "udp" diff --git a/internal/measurexlite/failure.go b/internal/measurexlite/failure.go index 3e750b5d9c..a2ddd55d8b 100644 --- a/internal/measurexlite/failure.go +++ b/internal/measurexlite/failure.go @@ -16,9 +16,11 @@ import ( // See https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md // for more information about OONI failures. func NewFailure(err error) *string { + // make sure we behave when passed a nil input (as documented) if err == nil { return nil } + // The following code guarantees that the error is always wrapped even // when we could not actually hit our code that does the wrapping. A case // in which this could happen is with context deadline for HTTP when you @@ -29,6 +31,8 @@ func NewFailure(err error) *string { couldConvert := errors.As(err, &errWrapper) runtimex.Assert(couldConvert, "we should have an ErrWrapper here") } + + // handle the case where there's no actual failure (this would be a BUG) s := errWrapper.Failure if s == "" { s = "unknown_failure: errWrapper.Failure is empty" diff --git a/internal/measurexlite/http.go b/internal/measurexlite/http.go index 27a6173494..e82800ff0d 100644 --- a/internal/measurexlite/http.go +++ b/internal/measurexlite/http.go @@ -41,10 +41,12 @@ import ( // // - err is the possibly-nil error that occurred during the transaction; // -// - finished is when we finished reading the response's body. +// - finished is when we finished reading the response's body; +// +// - tags contains optional tags to fill the .Tags field in the archival format. func NewArchivalHTTPRequestResult(index int64, started time.Duration, network, address, alpn string, - transport string, req *http.Request, resp *http.Response, maxRespBodySize int64, body []byte, err error, - finished time.Duration) *model.ArchivalHTTPRequestResult { + transport string, req *http.Request, resp *http.Response, maxRespBodySize int64, body []byte, + err error, finished time.Duration, tags ...string) *model.ArchivalHTTPRequestResult { return &model.ArchivalHTTPRequestResult{ Network: network, Address: address, @@ -70,6 +72,7 @@ func NewArchivalHTTPRequestResult(index int64, started time.Duration, network, a }, T0: started.Seconds(), T: finished.Seconds(), + Tags: copyAndNormalizeTags(tags), TransactionID: index, } } @@ -173,10 +176,12 @@ func newHTTPHeaderList(header http.Header) (out []model.ArchivalHTTPHeader) { for key := range header { keys = append(keys, key) } - // ensure the output is consistent, which helps with testing + + // ensure the output is consistent, which helps with testing; // for an example of why we need to sort headers, see // https://github.com/ooni/probe-engine/pull/751/checks?check_run_id=853562310 sort.Strings(keys) + for _, key := range keys { for _, value := range header[key] { out = append(out, model.ArchivalHTTPHeader{ diff --git a/internal/measurexlite/http_test.go b/internal/measurexlite/http_test.go index da36b44fb7..3f0ad0892e 100644 --- a/internal/measurexlite/http_test.go +++ b/internal/measurexlite/http_test.go @@ -26,6 +26,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { body []byte err error finished time.Duration + tags []string } type config struct { @@ -49,6 +50,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { body: nil, err: nil, finished: 0, + tags: nil, }, expect: &model.ArchivalHTTPRequestResult{ Network: "", @@ -75,6 +77,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { }, T0: 0, T: 0, + Tags: []string{}, TransactionID: 0, }, }, { @@ -103,6 +106,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { body: nil, err: netxlite.NewTopLevelGenericErrWrapper(netxlite.ECONNRESET), finished: 750 * time.Millisecond, + tags: []string{"antani"}, }, expect: &model.ArchivalHTTPRequestResult{ Network: "tcp", @@ -145,6 +149,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { }, T0: 0.25, T: 0.75, + Tags: []string{"antani"}, TransactionID: 1, }, }, { @@ -179,6 +184,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { body: filtering.HTTPBlockpage451, err: nil, finished: 1500 * time.Millisecond, + tags: []string{"antani"}, }, expect: &model.ArchivalHTTPRequestResult{ Network: "udp", @@ -233,6 +239,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { }, T0: 1.4, T: 1.5, + Tags: []string{"antani"}, TransactionID: 44, }, }, { @@ -275,6 +282,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { body: nil, err: nil, finished: 1500 * time.Millisecond, + tags: []string{"antani"}, }, expect: &model.ArchivalHTTPRequestResult{ Network: "udp", @@ -337,6 +345,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { }, T0: 1.4, T: 1.5, + Tags: []string{"antani"}, TransactionID: 47, }, }} @@ -356,6 +365,7 @@ func TestNewArchivalHTTPRequestResult(t *testing.T) { cnf.args.body, cnf.args.err, cnf.args.finished, + cnf.args.tags..., ) if diff := cmp.Diff(cnf.expect, out); diff != "" { t.Fatal(diff) diff --git a/internal/measurexlite/logger.go b/internal/measurexlite/logger.go index aac482abb9..4e7a0d7f90 100644 --- a/internal/measurexlite/logger.go +++ b/internal/measurexlite/logger.go @@ -12,6 +12,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) +// TODO(bassosimone): consider moving inside the logx package? + // NewOperationLogger creates a new logger that logs // about an in-progress operation. If it takes too much // time to emit the result of the operation, the code diff --git a/internal/measurexlite/quic.go b/internal/measurexlite/quic.go index cfdb15125e..d2f4f20029 100644 --- a/internal/measurexlite/quic.go +++ b/internal/measurexlite/quic.go @@ -47,7 +47,8 @@ func (qdx *quicDialerTrace) CloseIdleConnections() { func (tx *Trace) OnQUICHandshakeStart(now time.Time, remoteAddr string, config *quic.Config) { t := now.Sub(tx.ZeroTime) select { - case tx.networkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "quic_handshake_start"): + case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( + tx.Index, t, "quic_handshake_start", tx.tags...): default: } } @@ -56,10 +57,12 @@ func (tx *Trace) OnQUICHandshakeStart(now time.Time, remoteAddr string, config * func (tx *Trace) OnQUICHandshakeDone(started time.Time, remoteAddr string, qconn quic.EarlyConnection, config *tls.Config, err error, finished time.Time) { t := finished.Sub(tx.ZeroTime) + state := tls.ConnectionState{} if qconn != nil { state = qconn.ConnectionState().TLS.ConnectionState } + select { case tx.quicHandshake <- NewArchivalTLSOrQUICHandshakeResult( tx.Index, @@ -70,11 +73,14 @@ func (tx *Trace) OnQUICHandshakeDone(started time.Time, remoteAddr string, qconn state, err, t, + tx.tags..., ): default: // buffer is full } + select { - case tx.networkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "quic_handshake_done"): + case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( + tx.Index, t, "quic_handshake_done", tx.tags...): default: // buffer is full } } diff --git a/internal/measurexlite/quic_test.go b/internal/measurexlite/quic_test.go index 278c9574ca..c28fdf1abd 100644 --- a/internal/measurexlite/quic_test.go +++ b/internal/measurexlite/quic_test.go @@ -14,6 +14,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/quictesting" "github.com/ooni/probe-cli/v3/internal/testingx" ) @@ -22,7 +23,7 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { underlying := &mocks.QUICDialer{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.NewQUICDialerWithoutResolverFn = func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { + trace.newQUICDialerWithoutResolverFn = func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { return underlying } listener := &mocks.QUICListener{} @@ -49,7 +50,7 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { return nil, expectedErr }, } - trace.NewQUICDialerWithoutResolverFn = func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { + trace.newQUICDialerWithoutResolverFn = func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { return underlying } listener := &mocks.QUICListener{} @@ -76,7 +77,7 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { called = true }, } - trace.NewQUICDialerWithoutResolverFn = func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { + trace.newQUICDialerWithoutResolverFn = func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { return underlying } listener := &mocks.QUICListener{} @@ -91,8 +92,8 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { mockedErr := errors.New("mocked") zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now // deterministic time tracking + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now // deterministic time tracking pconn := &mocks.UDPLikeConn{ MockLocalAddr: func() net.Addr { return &net.UDPAddr{ @@ -146,7 +147,7 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { PeerCertificates: []model.ArchivalMaybeBinaryData{}, ServerName: "dns.cloudflare.com", T: time.Second.Seconds(), - Tags: []string{}, + Tags: []string{"antani"}, TLSVersion: "", } got := events[0] @@ -169,7 +170,7 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { Operation: "quic_handshake_start", Proto: "", T: 0, - Tags: []string{}, + Tags: []string{"antani"}, } got := events[0] if diff := cmp.Diff(expect, got); diff != "" { @@ -186,7 +187,7 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { Proto: "", T0: time.Second.Seconds(), T: time.Second.Seconds(), - Tags: []string{}, + Tags: []string{"antani"}, } got := events[1] if diff := cmp.Diff(expect, got); diff != "" { @@ -256,6 +257,42 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { }) } +func TestOnQUICHandshakeDoneExtractsTheConnectionState(t *testing.T) { + // create a trace + trace := NewTrace(0, time.Now()) + + // create a QUIC dialer + quicListener := netxlite.NewQUICListener() + quicDialer := trace.NewQUICDialerWithoutResolver(quicListener, model.DiscardLogger) + + // dial with the endpoint we use for testing + quicConn, err := quicDialer.DialContext( + context.Background(), + quictesting.Endpoint("443"), + &tls.Config{ + InsecureSkipVerify: true, + }, + &quic.Config{}, + ) + defer MaybeCloseQUICConn(quicConn) + + // we do not expect to see an error here + if err != nil { + t.Fatal(err) + } + + // extract the QUIC handshake event + event := trace.FirstQUICHandshakeOrNil() + if event == nil { + t.Fatal("expected non-nil event") + } + + // make sure we have parsed the QUIC connection state + if event.NegotiatedProtocol != "h3" { + t.Fatal("it seems we did not parse the QUIC connection state") + } +} + func TestFirstQUICHandshake(t *testing.T) { t.Run("returns nil when buffer is empty", func(t *testing.T) { zeroTime := time.Now() diff --git a/internal/measurexlite/tls.go b/internal/measurexlite/tls.go index 02eeced88c..e7c80c6a40 100644 --- a/internal/measurexlite/tls.go +++ b/internal/measurexlite/tls.go @@ -44,7 +44,8 @@ func (thx *tlsHandshakerTrace) Handshake( func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) { t := now.Sub(tx.ZeroTime) select { - case tx.networkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_start"): + case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( + tx.Index, t, "tls_handshake_start", tx.tags...): default: // buffer is full } } @@ -53,6 +54,7 @@ func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *t func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) { t := finished.Sub(tx.ZeroTime) + select { case tx.tlsHandshake <- NewArchivalTLSOrQUICHandshakeResult( tx.Index, @@ -63,11 +65,14 @@ func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config state, err, t, + tx.tags..., ): default: // buffer is full } + select { - case tx.networkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_done"): + case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( + tx.Index, t, "tls_handshake_done", tx.tags...): default: // buffer is full } } @@ -76,7 +81,8 @@ func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config // from the available information right after the TLS handshake returns. func NewArchivalTLSOrQUICHandshakeResult( index int64, started time.Duration, network string, address string, config *tls.Config, - state tls.ConnectionState, err error, finished time.Duration) *model.ArchivalTLSOrQUICHandshakeResult { + state tls.ConnectionState, err error, finished time.Duration, + tags ...string) *model.ArchivalTLSOrQUICHandshakeResult { return &model.ArchivalTLSOrQUICHandshakeResult{ Network: network, Address: address, @@ -88,7 +94,7 @@ func NewArchivalTLSOrQUICHandshakeResult( ServerName: config.ServerName, T0: started.Seconds(), T: finished.Seconds(), - Tags: []string{}, + Tags: copyAndNormalizeTags(tags), TLSVersion: netxlite.TLSVersionString(state.Version), TransactionID: index, } @@ -110,12 +116,14 @@ func newArchivalBinaryData(data []byte) model.ArchivalMaybeBinaryData { func TLSPeerCerts( state tls.ConnectionState, err error) (out []model.ArchivalMaybeBinaryData) { out = []model.ArchivalMaybeBinaryData{} + var x509HostnameError x509.HostnameError if errors.As(err, &x509HostnameError) { // Test case: https://wrong.host.badssl.com/ out = append(out, newArchivalBinaryData(x509HostnameError.Certificate.Raw)) return } + var x509UnknownAuthorityError x509.UnknownAuthorityError if errors.As(err, &x509UnknownAuthorityError) { // Test case: https://self-signed.badssl.com/. This error has @@ -123,12 +131,14 @@ func TLSPeerCerts( out = append(out, newArchivalBinaryData(x509UnknownAuthorityError.Cert.Raw)) return } + var x509CertificateInvalidError x509.CertificateInvalidError if errors.As(err, &x509CertificateInvalidError) { // Test case: https://expired.badssl.com/ out = append(out, newArchivalBinaryData(x509CertificateInvalidError.Cert.Raw)) return } + for _, cert := range state.PeerCertificates { out = append(out, newArchivalBinaryData(cert.Raw)) } diff --git a/internal/measurexlite/tls_test.go b/internal/measurexlite/tls_test.go index 6eb33dba25..120b4d4eb8 100644 --- a/internal/measurexlite/tls_test.go +++ b/internal/measurexlite/tls_test.go @@ -24,7 +24,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { underlying := &mocks.TLSHandshaker{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { + trace.newTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { return underlying } thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) @@ -49,7 +49,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { return nil, tls.ConnectionState{}, expectedErr }, } - trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { + trace.newTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { return underlying } thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) @@ -73,8 +73,8 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { mockedErr := errors.New("mocked") zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) - trace := NewTrace(0, zeroTime) - trace.TimeNowFn = td.Now // deterministic timing + trace := NewTrace(0, zeroTime, "antani") + trace.timeNowFn = td.Now // deterministic timing thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) ctx := context.Background() tcpConn := &mocks.Conn{ @@ -129,7 +129,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { PeerCertificates: []model.ArchivalMaybeBinaryData{}, ServerName: "dns.cloudflare.com", T: time.Second.Seconds(), - Tags: []string{}, + Tags: []string{"antani"}, TLSVersion: "", } got := events[0] @@ -152,7 +152,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { Operation: "tls_handshake_start", Proto: "", T: 0, - Tags: []string{}, + Tags: []string{"antani"}, } got := events[0] if diff := cmp.Diff(expect, got); diff != "" { @@ -169,7 +169,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { Proto: "", T0: time.Second.Seconds(), T: time.Second.Seconds(), - Tags: []string{}, + Tags: []string{"antani"}, } got := events[1] if diff := cmp.Diff(expect, got); diff != "" { @@ -250,7 +250,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { zeroTime := time.Now() dt := testingx.NewTimeDeterministic(zeroTime) trace := NewTrace(0, zeroTime) - trace.TimeNowFn = dt.Now // deterministic timing + trace.timeNowFn = dt.Now // deterministic timing thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) tlsConfig := &tls.Config{ RootCAs: server.CertPool(), diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index 0404850ebf..8ab0bddbbe 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -14,57 +14,51 @@ import ( // Trace implements model.Trace. // -// The zero-value of this struct is invalid. To construct you should either -// fill all the fields marked as MANDATORY or use NewTrace. +// The zero-value of this struct is invalid. To construct use NewTrace. // // # Buffered channels // // NewTrace uses reasonable buffer sizes for the channels used for collecting // events. You should drain the channels used by this implementation after // each operation you perform (i.e., we expect you to peform step-by-step -// measurements). If you want larger (or smaller) buffers, then you should -// construct this data type manually with the desired buffer sizes. -// -// We have convenience methods for extracting events from the buffered -// channels. Otherwise, you could read the channels directly. (In which -// case, remember to issue nonblocking channel reads because channels are -// never closed and they're just written when new events occur.) +// measurements). We have convenience methods for extracting events from the +// buffered channels. type Trace struct { - // Index is the MANDATORY unique index of this trace within the - // current measurement. If you don't care about uniquely identifying - // traces, you can use zero to indicate the "default" trace. + // Index is the unique index of this trace within the + // current measurement. Note that this field MUST be read-only. Writing it + // once you have constructed a trace MAY lead to data races. Index int64 // networkEvent is MANDATORY and buffers network events. networkEvent chan *model.ArchivalNetworkEvent - // NewStdlibResolverFn is OPTIONAL and can be used to overide + // newStdlibResolverFn is OPTIONAL and can be used to overide // calls to the netxlite.NewStdlibResolver factory. - NewStdlibResolverFn func(logger model.Logger) model.Resolver + newStdlibResolverFn func(logger model.Logger) model.Resolver - // NewParallelUDPResolverFn is OPTIONAL and can be used to overide + // newParallelUDPResolverFn is OPTIONAL and can be used to overide // calls to the netxlite.NewParallelUDPResolver factory. - NewParallelUDPResolverFn func(logger model.Logger, dialer model.Dialer, address string) model.Resolver + newParallelUDPResolverFn func(logger model.Logger, dialer model.Dialer, address string) model.Resolver - // NewParallelDNSOverHTTPSResolverFn is OPTIONAL and can be used to overide + // newParallelDNSOverHTTPSResolverFn is OPTIONAL and can be used to overide // calls to the netxlite.NewParallelDNSOverHTTPSUDPResolver factory. - NewParallelDNSOverHTTPSResolverFn func(logger model.Logger, URL string) model.Resolver + newParallelDNSOverHTTPSResolverFn func(logger model.Logger, URL string) model.Resolver - // NewDialerWithoutResolverFn is OPTIONAL and can be used to override + // newDialerWithoutResolverFn is OPTIONAL and can be used to override // calls to the netxlite.NewDialerWithoutResolver factory. - NewDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer + newDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer - // NewTLSHandshakerStdlibFn is OPTIONAL and can be used to overide + // newTLSHandshakerStdlibFn is OPTIONAL and can be used to overide // calls to the netxlite.NewTLSHandshakerStdlib factory. - NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker + newTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker - // NewTLSHandshakerUTLSFn is OPTIONAL and can be used to overide + // newTLSHandshakerUTLSFn is OPTIONAL and can be used to overide // calls to the netxlite.NewTLSHandshakerUTLS factory. - NewTLSHandshakerUTLSFn func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker + newTLSHandshakerUTLSFn func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker // NewDialerWithoutResolverFn is OPTIONAL and can be used to override // calls to the netxlite.NewQUICDialerWithoutResolver factory. - NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer + newQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer // dnsLookup is MANDATORY and buffers DNS Lookup observations. dnsLookup chan *model.ArchivalDNSLookupResult @@ -81,37 +75,42 @@ type Trace struct { // quicHandshake is MANDATORY and buffers QUIC handshake observations. quicHandshake chan *model.ArchivalTLSOrQUICHandshakeResult - // TimeNowFn is OPTIONAL and can be used to override calls to time.Now + // tags contains OPTIONAL tags to tag measurements. + tags []string + + // timeNowFn is OPTIONAL and can be used to override calls to time.Now // to produce deterministic timing when testing. - TimeNowFn func() time.Time + timeNowFn func() time.Time - // ZeroTime is the MANDATORY time when we started the current measurement. + // ZeroTime is the time when we started the current measurement. This field + // MUST be read-only. Writing it once you have constructed the trace will + // likely read to data races. ZeroTime time.Time } const ( // NetworkEventBufferSize is the buffer size for constructing - // the Trace's networkEvent buffered channel. + // the internal Trace's networkEvent buffered channel. NetworkEventBufferSize = 64 // DNSLookupBufferSize is the buffer size for constructing - // the Trace's dnsLookup buffered channel. + // the internal Trace's dnsLookup buffered channel. DNSLookupBufferSize = 8 // DNSResponseBufferSize is the buffer size for constructing - // the Trace's dnsDelayedResponse buffered channel. + // the internal Trace's dnsDelayedResponse buffered channel. DelayedDNSResponseBufferSize = 8 // TCPConnectBufferSize is the buffer size for constructing - // the Trace's tcpConnect buffered channel. + // the internal Trace's tcpConnect buffered channel. TCPConnectBufferSize = 8 // TLSHandshakeBufferSize is the buffer for construcing - // the Trace's tlsHandshake buffered channel. + // the internal Trace's tlsHandshake buffered channel. TLSHandshakeBufferSize = 8 // QUICHandshakeBufferSize is the buffer for constructing - // the Trace's quicHandshake buffered channel. + // the internal Trace's quicHandshake buffered channel. QUICHandshakeBufferSize = 8 ) @@ -125,16 +124,24 @@ const ( // - index is the unique index of this trace within the current measurement (use // zero if you don't care about giving this trace a unique ID); // -// - zeroTime is the time when we started the current measurement. -func NewTrace(index int64, zeroTime time.Time) *Trace { +// - zeroTime is the time when we started the current measurement; +// +// - tags contains optional tags to mark the archival data formats specially (e.g., +// to identify that some traces belong to some submeasurements). +func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { return &Trace{ Index: index, networkEvent: make( chan *model.ArchivalNetworkEvent, NetworkEventBufferSize, ), - NewDialerWithoutResolverFn: nil, // use default - NewTLSHandshakerStdlibFn: nil, // use default + newStdlibResolverFn: nil, // use default + newParallelUDPResolverFn: nil, // use default + newParallelDNSOverHTTPSResolverFn: nil, // use default + newDialerWithoutResolverFn: nil, // use default + newTLSHandshakerStdlibFn: nil, // use default + newTLSHandshakerUTLSFn: nil, // use default + newQUICDialerWithoutResolverFn: nil, // use default dnsLookup: make( chan *model.ArchivalDNSLookupResult, DNSLookupBufferSize, @@ -155,7 +162,8 @@ func NewTrace(index int64, zeroTime time.Time) *Trace { chan *model.ArchivalTLSOrQUICHandshakeResult, QUICHandshakeBufferSize, ), - TimeNowFn: nil, // use default + tags: tags, + timeNowFn: nil, // use default ZeroTime: zeroTime, } } @@ -163,8 +171,8 @@ func NewTrace(index int64, zeroTime time.Time) *Trace { // newStdlibResolver indirectly calls the passed netxlite.NewStdlibResolver // thus allowing us to mock this function for testing func (tx *Trace) newStdlibResolver(logger model.Logger) model.Resolver { - if tx.NewStdlibResolverFn != nil { - return tx.NewStdlibResolverFn(logger) + if tx.newStdlibResolverFn != nil { + return tx.newStdlibResolverFn(logger) } return netxlite.NewStdlibResolver(logger) } @@ -172,8 +180,8 @@ func (tx *Trace) newStdlibResolver(logger model.Logger) model.Resolver { // newParallelUDPResolver indirectly calls the passed netxlite.NewParallerUDPResolver // thus allowing us to mock this function for testing func (tx *Trace) newParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { - if tx.NewParallelUDPResolverFn != nil { - return tx.NewParallelUDPResolverFn(logger, dialer, address) + if tx.newParallelUDPResolverFn != nil { + return tx.newParallelUDPResolverFn(logger, dialer, address) } return netxlite.NewParallelUDPResolver(logger, dialer, address) } @@ -181,8 +189,8 @@ func (tx *Trace) newParallelUDPResolver(logger model.Logger, dialer model.Dialer // newParallelDNSOverHTTPSResolver indirectly calls the passed netxlite.NewParallerDNSOverHTTPSResolver // thus allowing us to mock this function for testing func (tx *Trace) newParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver { - if tx.NewParallelDNSOverHTTPSResolverFn != nil { - return tx.NewParallelDNSOverHTTPSResolverFn(logger, URL) + if tx.newParallelDNSOverHTTPSResolverFn != nil { + return tx.newParallelDNSOverHTTPSResolverFn(logger, URL) } return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL) } @@ -190,8 +198,8 @@ func (tx *Trace) newParallelDNSOverHTTPSResolver(logger model.Logger, URL string // newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver // thus allowing us to mock this func for testing. func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer { - if tx.NewDialerWithoutResolverFn != nil { - return tx.NewDialerWithoutResolverFn(dl) + if tx.newDialerWithoutResolverFn != nil { + return tx.newDialerWithoutResolverFn(dl) } return netxlite.NewDialerWithoutResolver(dl) } @@ -199,8 +207,8 @@ func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer { // newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib // thus allowing us to mock this func for testing. func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { - if tx.NewTLSHandshakerStdlibFn != nil { - return tx.NewTLSHandshakerStdlibFn(dl) + if tx.newTLSHandshakerStdlibFn != nil { + return tx.newTLSHandshakerStdlibFn(dl) } return netxlite.NewTLSHandshakerStdlib(dl) } @@ -208,8 +216,8 @@ func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshake // newTLSHandshakerUTLS indirectly calls netxlite.NewTLSHandshakerUTLS // thus allowing us to mock this func for testing. func (tx *Trace) newTLSHandshakerUTLS(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { - if tx.NewTLSHandshakerUTLSFn != nil { - return tx.NewTLSHandshakerUTLSFn(dl, id) + if tx.newTLSHandshakerUTLSFn != nil { + return tx.newTLSHandshakerUTLSFn(dl, id) } return netxlite.NewTLSHandshakerUTLS(dl, id) } @@ -217,16 +225,16 @@ func (tx *Trace) newTLSHandshakerUTLS(dl model.DebugLogger, id *utls.ClientHello // newQUICDialerWithoutResolver indirectly calls netxlite.NewQUICDialerWithoutResolver // thus allowing us to mock this func for testing. func (tx *Trace) newQUICDialerWithoutResolver(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { - if tx.NewQUICDialerWithoutResolverFn != nil { - return tx.NewQUICDialerWithoutResolverFn(listener, dl) + if tx.newQUICDialerWithoutResolverFn != nil { + return tx.newQUICDialerWithoutResolverFn(listener, dl) } return netxlite.NewQUICDialerWithoutResolver(listener, dl) } // TimeNow implements model.Trace.TimeNow. func (tx *Trace) TimeNow() time.Time { - if tx.TimeNowFn != nil { - return tx.TimeNowFn() + if tx.timeNowFn != nil { + return tx.timeNowFn() } return time.Now() } @@ -236,4 +244,9 @@ func (tx *Trace) TimeSince(t0 time.Time) time.Duration { return tx.TimeNow().Sub(t0) } +// Tags returns a copy of the tags configured for this trace. +func (tx *Trace) Tags() []string { + return copyAndNormalizeTags(tx.tags) +} + var _ model.Trace = &Trace{} diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index fe578c9118..0eb033cf6d 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" @@ -50,43 +51,43 @@ func TestNewTrace(t *testing.T) { }) t.Run("NewStdlibResolverFn is nil", func(t *testing.T) { - if trace.NewStdlibResolverFn != nil { + if trace.newStdlibResolverFn != nil { t.Fatal("expected nil NewStdlibResolverFn") } }) t.Run("NewParallelUDPResolverFn is nil", func(t *testing.T) { - if trace.NewParallelUDPResolverFn != nil { + if trace.newParallelUDPResolverFn != nil { t.Fatal("expected nil NewParallelUDPResolverFn") } }) t.Run("NewParallelDNSOverHTTPSResolverFn is nil", func(t *testing.T) { - if trace.NewParallelDNSOverHTTPSResolverFn != nil { + if trace.newParallelDNSOverHTTPSResolverFn != nil { t.Fatal("expected nil NewParallelDNSOverHTTPSResolverFn") } }) t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) { - if trace.NewDialerWithoutResolverFn != nil { + if trace.newDialerWithoutResolverFn != nil { t.Fatal("expected nil NewDialerWithoutResolverFn") } }) t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) { - if trace.NewTLSHandshakerStdlibFn != nil { + if trace.newTLSHandshakerStdlibFn != nil { t.Fatal("expected nil NewTLSHandshakerStdlibFn") } }) t.Run("newTLShandshakerUTLSFn is nil", func(t *testing.T) { - if trace.NewTLSHandshakerUTLSFn != nil { + if trace.newTLSHandshakerUTLSFn != nil { t.Fatal("expected nil NewTLSHandshakerUTLSfn") } }) t.Run("NewQUICDialerWithoutResolverFn is nil", func(t *testing.T) { - if trace.NewQUICDialerWithoutResolverFn != nil { + if trace.newQUICDialerWithoutResolverFn != nil { t.Fatal("expected nil NewQUICDialerQithoutResolverFn") } }) @@ -187,7 +188,7 @@ func TestNewTrace(t *testing.T) { }) t.Run("TimeNowFn is nil", func(t *testing.T) { - if trace.TimeNowFn != nil { + if trace.timeNowFn != nil { t.Fatal("expected nil TimeNowFn") } }) @@ -205,7 +206,7 @@ func TestTrace(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewStdlibResolverFn: func(logger model.Logger) model.Resolver { + newStdlibResolverFn: func(logger model.Logger) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{}, mockedErr @@ -226,7 +227,7 @@ func TestTrace(t *testing.T) { t.Run("when nil", func(t *testing.T) { tx := &Trace{ - NewParallelUDPResolverFn: nil, + newParallelUDPResolverFn: nil, } resolver := tx.newStdlibResolver(model.DiscardLogger) ctx, cancel := context.WithCancel(context.Background()) @@ -245,7 +246,7 @@ func TestTrace(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewParallelUDPResolverFn: func(logger model.Logger, dialer model.Dialer, address string) model.Resolver { + newParallelUDPResolverFn: func(logger model.Logger, dialer model.Dialer, address string) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{}, mockedErr @@ -267,7 +268,7 @@ func TestTrace(t *testing.T) { t.Run("when nil", func(t *testing.T) { tx := &Trace{ - NewParallelUDPResolverFn: nil, + newParallelUDPResolverFn: nil, } dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger) resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") @@ -287,7 +288,7 @@ func TestTrace(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewParallelDNSOverHTTPSResolverFn: func(logger model.Logger, URL string) model.Resolver { + newParallelDNSOverHTTPSResolverFn: func(logger model.Logger, URL string) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{}, mockedErr @@ -308,7 +309,7 @@ func TestTrace(t *testing.T) { t.Run("when nil", func(t *testing.T) { tx := &Trace{ - NewParallelDNSOverHTTPSResolverFn: nil, + newParallelDNSOverHTTPSResolverFn: nil, } resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") ctx, cancel := context.WithCancel(context.Background()) @@ -327,7 +328,7 @@ func TestTrace(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer { + newDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer { return &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, mockedErr @@ -348,7 +349,7 @@ func TestTrace(t *testing.T) { t.Run("when nil", func(t *testing.T) { tx := &Trace{ - NewDialerWithoutResolverFn: nil, + newDialerWithoutResolverFn: nil, } dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) ctx, cancel := context.WithCancel(context.Background()) @@ -367,7 +368,7 @@ func TestTrace(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker { + newTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker { return &mocks.TLSHandshaker{ MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { return nil, tls.ConnectionState{}, mockedErr @@ -392,7 +393,7 @@ func TestTrace(t *testing.T) { t.Run("when nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewTLSHandshakerStdlibFn: nil, + newTLSHandshakerStdlibFn: nil, } thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) tcpConn := &mocks.Conn{ @@ -437,7 +438,7 @@ func TestTrace(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewTLSHandshakerUTLSFn: func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { + newTLSHandshakerUTLSFn: func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { return &mocks.TLSHandshaker{ MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { return nil, tls.ConnectionState{}, mockedErr @@ -462,7 +463,7 @@ func TestTrace(t *testing.T) { t.Run("when nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewTLSHandshakerStdlibFn: nil, + newTLSHandshakerStdlibFn: nil, } thx := tx.newTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) tcpConn := &mocks.Conn{ @@ -507,7 +508,7 @@ func TestTrace(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewQUICDialerWithoutResolverFn: func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { + newQUICDialerWithoutResolverFn: func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { return &mocks.QUICDialer{ MockDialContext: func(ctx context.Context, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { @@ -530,7 +531,7 @@ func TestTrace(t *testing.T) { t.Run("when nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewQUICDialerWithoutResolverFn: nil, // explicit + newQUICDialerWithoutResolverFn: nil, // explicit } pconn := &mocks.UDPLikeConn{ MockLocalAddr: func() net.Addr { @@ -573,7 +574,7 @@ func TestTrace(t *testing.T) { t.Run("TimeNowFn works as intended", func(t *testing.T) { fixedTime := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC) tx := &Trace{ - TimeNowFn: func() time.Time { + timeNowFn: func() time.Time { return fixedTime }, } @@ -586,7 +587,7 @@ func TestTrace(t *testing.T) { t0 := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC) t1 := t0.Add(10 * time.Second) tx := &Trace{ - TimeNowFn: func() time.Time { + timeNowFn: func() time.Time { return t1 }, } @@ -595,3 +596,11 @@ func TestTrace(t *testing.T) { } }) } + +func TestTags(t *testing.T) { + trace := NewTrace(0, time.Now(), "antani") + got := trace.Tags() + if diff := cmp.Diff([]string{"antani"}, got); diff != "" { + t.Fatal(diff) + } +} diff --git a/internal/measurexlite/utls_test.go b/internal/measurexlite/utls_test.go index b0579c367b..2dccd89f0a 100644 --- a/internal/measurexlite/utls_test.go +++ b/internal/measurexlite/utls_test.go @@ -14,7 +14,7 @@ func TestNewTLSHandshakerUTLS(t *testing.T) { underlying := &mocks.TLSHandshaker{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.NewTLSHandshakerUTLSFn = func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { + trace.newTLSHandshakerUTLSFn = func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { return underlying } thx := trace.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) diff --git a/internal/model/archival.go b/internal/model/archival.go index 640e9b9002..9dfb532ae1 100644 --- a/internal/model/archival.go +++ b/internal/model/archival.go @@ -125,6 +125,7 @@ type ArchivalDNSLookupResult struct { ResolverAddress string `json:"resolver_address"` T0 float64 `json:"t0,omitempty"` T float64 `json:"t"` + Tags []string `json:"tags"` TransactionID int64 `json:"transaction_id,omitempty"` } @@ -152,6 +153,7 @@ type ArchivalTCPConnectResult struct { Status ArchivalTCPConnectStatus `json:"status"` T0 float64 `json:"t0,omitempty"` T float64 `json:"t"` + Tags []string `json:"tags"` TransactionID int64 `json:"transaction_id,omitempty"` } @@ -202,6 +204,7 @@ type ArchivalHTTPRequestResult struct { Response ArchivalHTTPResponse `json:"response"` T0 float64 `json:"t0,omitempty"` T float64 `json:"t"` + Tags []string `json:"tags"` TransactionID int64 `json:"transaction_id,omitempty"` }