diff --git a/address.go b/address.go index 987aff1..5b17000 100644 --- a/address.go +++ b/address.go @@ -23,7 +23,9 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/net/proxy" "io" + "math" "net" + "strings" "time" ) @@ -109,6 +111,78 @@ func (self Configuration) GetHandshakeTimeout() (time.Duration, error) { return 0, nil } +func (self Configuration) GetUIntValue(first string, rest ...string) (uint, bool, error) { + val, err := self.GetValue(first, rest...) + if val == nil { + return 0, false, err + } + + intVal, ok := val.(int) + if !ok { + key := strings.Join(self.toSlice(first, rest...), ":") + return 0, false, errors.Errorf("value for key %s should be int, not %v of type '%T'", key, val, val) + } + if intVal < 0 { + key := strings.Join(self.toSlice(first, rest...), ":") + return 0, false, errors.Errorf("value for key %s should be positive, not %v ", key, intVal) + } + + if intVal > math.MaxInt { + key := strings.Join(self.toSlice(first, rest...), ":") + return 0, false, errors.Errorf("value for key %s should be less or equal to %d, not %v ", key, math.MaxInt, intVal) + } + + return uint(intVal), true, nil +} + +func (self Configuration) GetInt64Value(first string, rest ...string) (int64, bool, error) { + val, err := self.GetValue(first, rest...) + if val == nil { + return 0, false, err + } + + intVal, ok := val.(int) + if !ok { + key := strings.Join(self.toSlice(first, rest...), ":") + return 0, false, errors.Errorf("value for key %s should be int, not %v of type '%T'", key, val, val) + } + return int64(intVal), true, nil +} + +func (self Configuration) toSlice(first string, rest ...string) []string { + if len(rest) == 0 { + return []string{first} + } + + key := make([]string, len(rest)+1) + key[0] = first + copy(key[1:], rest) + return key +} + +func (self Configuration) GetValue(first string, rest ...string) (interface{}, error) { + return self.getValue(0, self.toSlice(first, rest...)...) +} + +func (self Configuration) getValue(index int, key ...string) (interface{}, error) { + if index == len(key)-1 { + return self[key[index]], nil + } + + val, ok := self[key[index]] + if !ok { + return nil, nil + } + + subMap, ok := val.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("invalid transport configuration value. %s should be map, not %T", + strings.Join(key[:index+1], ":"), val) + } + + return Configuration(subMap).getValue(index+1, key...) +} + type ProxyType string const ( diff --git a/dtls/address.go b/dtls/address.go index 2e3b946..f71868e 100644 --- a/dtls/address.go +++ b/dtls/address.go @@ -17,7 +17,6 @@ package dtls import ( - "github.com/michaelquigley/pfxlog" "github.com/openziti/identity" "github.com/openziti/transport/v2" "github.com/pkg/errors" @@ -126,16 +125,10 @@ func (ap AddressParser) Parse(s string) (transport.Address, error) { } func getMaxBytesPerSecond(tcfg transport.Configuration) (int64, bool) { - log := pfxlog.Logger() - log.Info("attempting to retrieve dtls maxBytesPerSecond value") if m, ok := tcfg["dtls"]; ok { - log.Info("dtls submap found") if subMap, ok := m.(map[interface{}]interface{}); ok { - log.Info("dtls submap correct format") if v, ok := subMap["maxBytesPerSecond"]; ok { - log.Info("dtls maxBytesPerSecond found") if bps, ok := v.(int); ok { - log.Info("dtls maxBytesPerSecond correct format") return int64(bps), true } } diff --git a/dtls/dialer.go b/dtls/dialer.go index b0609b4..bd37d26 100644 --- a/dtls/dialer.go +++ b/dtls/dialer.go @@ -19,6 +19,7 @@ package dtls import ( "context" "crypto/tls" + "fmt" "github.com/michaelquigley/pfxlog" "github.com/openziti/identity" "github.com/openziti/transport/v2" @@ -30,6 +31,10 @@ import ( "time" ) +const ( + DefaultBufferSize = 4 * 1024 * 1024 +) + func Dial(addr *address, name string, i *identity.TokenId, timeout time.Duration, tcfg transport.Configuration) (transport.Conn, error) { return DialWithLocalBinding(addr, name, "", i, timeout, tcfg) } @@ -62,12 +67,28 @@ func DialWithLocalBinding(addr *address, name, localBinding string, i *identity. return nil, closeErr } - if err := udpConn.SetWriteBuffer(4 * 1024 * 1024); err != nil { - panic(err) + writeBufferSize := DefaultBufferSize + bufferSize, found, err := tcfg.GetUIntValue("dtls", "writeBufferSize") + if err != nil { + return nil, err + } + if found { + writeBufferSize = int(bufferSize) + } + if err := udpConn.SetWriteBuffer(writeBufferSize); err != nil { + return nil, fmt.Errorf("unable to set udp write buffer size to %d (%w)", writeBufferSize, err) } - if err := udpConn.SetReadBuffer(4 * 1024 * 1024); err != nil { - panic(err) + readBufferSize := DefaultBufferSize + bufferSize, found, err = tcfg.GetUIntValue("dtls", "readBufferSize") + if err != nil { + return nil, err + } + if found { + readBufferSize = int(bufferSize) + } + if err = udpConn.SetWriteBuffer(readBufferSize); err != nil { + return nil, fmt.Errorf("unable to set udp read buffer size to %d (%w)", readBufferSize, err) } conn, closeErr := dtls.Client(udpConn, &addr.UDPAddr, cfg) @@ -100,7 +121,11 @@ func DialWithLocalBinding(addr *address, name, localBinding string, i *identity. log.Debugf("server provided [%d] certificates", len(certs)) var w io.Writer = conn - if bps, ok := getMaxBytesPerSecond(tcfg); ok { + bps, found, err := tcfg.GetInt64Value("dtls", "maxBytesPerSecond") + if err != nil { + return nil, err + } + if found { log.Infof("limiting DTLS writes to %dB/s", bps) w = shaper.LimitWriter(conn, time.Second, bps) }