Skip to content

Commit

Permalink
[otel] extension: small refactor to improve testability
Browse files Browse the repository at this point in the history
  • Loading branch information
truthbk committed Sep 5, 2024
1 parent ed3dee9 commit d27b8dc
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 24 deletions.
16 changes: 7 additions & 9 deletions comp/otelcol/ddflareextension/impl/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"

"go.opentelemetry.io/collector/component"
Expand All @@ -30,11 +29,10 @@ type ddExtension struct {

cfg *Config // Extension configuration.

telemetry component.TelemetrySettings
server *http.Server
tlsListener net.Listener
info component.BuildInfo
debug extensionDef.DebugSourceResponse
telemetry component.TelemetrySettings
server *server
info component.BuildInfo
debug extensionDef.DebugSourceResponse
}

var _ extension.Extension = (*ddExtension)(nil)
Expand All @@ -51,7 +49,7 @@ func NewExtension(_ context.Context, cfg *Config, telemetry component.TelemetryS
}

var err error
ext.server, ext.tlsListener, err = buildHTTPServer(cfg.HTTPConfig.Endpoint, ext)
ext.server, err = newServer(cfg.HTTPConfig.Endpoint, ext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -108,7 +106,7 @@ func (ext *ddExtension) Start(_ context.Context, host component.Host) error {
}

go func() {
if err := ext.server.Serve(ext.tlsListener); err != nil && err != http.ErrServerClosed {
if err := ext.server.start(); err != nil && err != http.ErrServerClosed {
ext.telemetry.ReportStatus(component.NewFatalErrorEvent(err))
ext.telemetry.Logger.Info("DD Extension HTTP could not start", zap.String("err", err.Error()))
}
Expand All @@ -123,7 +121,7 @@ func (ext *ddExtension) Shutdown(ctx context.Context) error {
ext.telemetry.Logger.Info("Shutting down HTTP server")

// Give the server a grace period to finish handling requests.
return ext.server.Shutdown(ctx)
return ext.server.shutdown(ctx)
}

// ServeHTTP the request handler for the extension.
Expand Down
4 changes: 3 additions & 1 deletion comp/otelcol/ddflareextension/impl/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,10 @@ func TestExtensionHTTPHandler(t *testing.T) {

ddExt.Start(context.TODO(), host)

handler := ddExt.server.srv.Handler

// Call the handler's ServeHTTP method
ddExt.ServeHTTP(rr, req)
handler.ServeHTTP(rr, req)

// Check the response status code
assert.Equalf(t, http.StatusOK, rr.Code,
Expand Down
42 changes: 28 additions & 14 deletions comp/otelcol/ddflareextension/impl/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package ddflareextensionimpl

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
Expand All @@ -23,6 +24,11 @@ import (
"github.com/gorilla/mux"
)

type server struct {
srv *http.Server
listener net.Listener
}

// validateToken - validates token for legacy API
func validateToken(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -33,18 +39,18 @@ func validateToken(next http.Handler) http.Handler {
})
}

func buildHTTPServer(endpoint string, handler http.Handler) (*http.Server, net.Listener, error) {
func newServer(endpoint string, handler http.Handler) (*server, error) {

// Generate a self-signed certificate
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
return nil, err
}

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, err
return nil, err
}

template := x509.Certificate{
Expand All @@ -70,12 +76,12 @@ func buildHTTPServer(endpoint string, handler http.Handler) (*http.Server, net.L

certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
return nil, nil, err
return nil, err
}
// parse the resulting certificate so we can use it again
_, err = x509.ParseCertificate(certDER)
if err != nil {
return nil, nil, err
return nil, err
}
// PEM encode the certificate (this is a standard TLS encoding)
b := pem.Block{Type: "CERTIFICATE", Bytes: certDER}
Expand All @@ -87,13 +93,13 @@ func buildHTTPServer(endpoint string, handler http.Handler) (*http.Server, net.L

pair, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, nil, fmt.Errorf("unable to generate TLS key pair: %v", err)
return nil, fmt.Errorf("unable to generate TLS key pair: %v", err)
}

tlsCertPool := x509.NewCertPool()
ok := tlsCertPool.AppendCertsFromPEM(certPEM)
if !ok {
return nil, nil, fmt.Errorf("unable to add new certificate to pool")
return nil, fmt.Errorf("unable to add new certificate to pool")
}

// Create TLS configuration
Expand All @@ -108,22 +114,30 @@ func buildHTTPServer(endpoint string, handler http.Handler) (*http.Server, net.L

r.Use(validateToken)

server := &http.Server{
s := &http.Server{
Addr: endpoint,
TLSConfig: tlsConfig,
Handler: r,
}

listener, err := net.Listen("tcp", endpoint)
if err != nil {
return nil, nil, err
return nil, err
}

tlsListener := tls.NewListener(listener, server.TLSConfig)
go func() {
_ = server.Serve(tlsListener)
}()
tlsListener := tls.NewListener(listener, s.TLSConfig)

return &server{
srv: s,
listener: tlsListener,
}, nil

return server, tlsListener, nil
}

func (s *server) start() error {
return s.srv.Serve(s.listener)
}

func (s *server) shutdown(ctx context.Context) error {
return s.srv.Shutdown(ctx)
}

0 comments on commit d27b8dc

Please sign in to comment.