Skip to content

Commit

Permalink
Make read/write buffer sizes for dtls configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
plorenz committed Sep 6, 2024
1 parent 7af36d0 commit 89aec36
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 12 deletions.
74 changes: 74 additions & 0 deletions address.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"io"
"math"
"net"
"strings"
"time"
)

Expand Down Expand Up @@ -109,6 +111,78 @@ func (self Configuration) GetHandshakeTimeout() (time.Duration, error) {
return 0, nil
}

func (self Configuration) GetUIntValue(first string, rest ...string) (uint, bool, error) {
val, err := self.GetValue(first, rest...)
if val == nil {
return 0, false, err
}

intVal, ok := val.(int)
if !ok {
key := strings.Join(self.toSlice(first, rest...), ":")
return 0, false, errors.Errorf("value for key %s should be int, not %v of type '%T'", key, val, val)
}
if intVal < 0 {
key := strings.Join(self.toSlice(first, rest...), ":")
return 0, false, errors.Errorf("value for key %s should be positive, not %v ", key, intVal)
}

if intVal > math.MaxInt {
key := strings.Join(self.toSlice(first, rest...), ":")
return 0, false, errors.Errorf("value for key %s should be less or equal to %d, not %v ", key, math.MaxInt, intVal)
}

return uint(intVal), true, nil
}

func (self Configuration) GetInt64Value(first string, rest ...string) (int64, bool, error) {
val, err := self.GetValue(first, rest...)
if val == nil {
return 0, false, err
}

intVal, ok := val.(int)
if !ok {
key := strings.Join(self.toSlice(first, rest...), ":")
return 0, false, errors.Errorf("value for key %s should be int, not %v of type '%T'", key, val, val)
}
return int64(intVal), true, nil
}

func (self Configuration) toSlice(first string, rest ...string) []string {
if len(rest) == 0 {
return []string{first}
}

key := make([]string, len(rest)+1)
key[0] = first
copy(key[1:], rest)
return key
}

func (self Configuration) GetValue(first string, rest ...string) (interface{}, error) {
return self.getValue(0, self.toSlice(first, rest...)...)
}

func (self Configuration) getValue(index int, key ...string) (interface{}, error) {
if index == len(key)-1 {
return self[key[index]], nil
}

val, ok := self[key[index]]
if !ok {
return nil, nil
}

subMap, ok := val.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("invalid transport configuration value. %s should be map, not %T",
strings.Join(key[:index+1], ":"), val)
}

return Configuration(subMap).getValue(index+1, key...)
}

type ProxyType string

const (
Expand Down
7 changes: 0 additions & 7 deletions dtls/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package dtls

import (
"github.com/michaelquigley/pfxlog"
"github.com/openziti/identity"
"github.com/openziti/transport/v2"
"github.com/pkg/errors"
Expand Down Expand Up @@ -126,16 +125,10 @@ func (ap AddressParser) Parse(s string) (transport.Address, error) {
}

func getMaxBytesPerSecond(tcfg transport.Configuration) (int64, bool) {
log := pfxlog.Logger()
log.Info("attempting to retrieve dtls maxBytesPerSecond value")
if m, ok := tcfg["dtls"]; ok {
log.Info("dtls submap found")
if subMap, ok := m.(map[interface{}]interface{}); ok {
log.Info("dtls submap correct format")
if v, ok := subMap["maxBytesPerSecond"]; ok {
log.Info("dtls maxBytesPerSecond found")
if bps, ok := v.(int); ok {
log.Info("dtls maxBytesPerSecond correct format")
return int64(bps), true
}
}
Expand Down
35 changes: 30 additions & 5 deletions dtls/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package dtls
import (
"context"
"crypto/tls"
"fmt"
"github.com/michaelquigley/pfxlog"
"github.com/openziti/identity"
"github.com/openziti/transport/v2"
Expand All @@ -30,6 +31,10 @@ import (
"time"
)

const (
DefaultBufferSize = 4 * 1024 * 1024
)

func Dial(addr *address, name string, i *identity.TokenId, timeout time.Duration, tcfg transport.Configuration) (transport.Conn, error) {
return DialWithLocalBinding(addr, name, "", i, timeout, tcfg)
}
Expand Down Expand Up @@ -62,12 +67,28 @@ func DialWithLocalBinding(addr *address, name, localBinding string, i *identity.
return nil, closeErr
}

if err := udpConn.SetWriteBuffer(4 * 1024 * 1024); err != nil {
panic(err)
writeBufferSize := DefaultBufferSize
bufferSize, found, err := tcfg.GetUIntValue("dtls", "writeBufferSize")
if err != nil {
return nil, err
}
if found {
writeBufferSize = int(bufferSize)
}
if err := udpConn.SetWriteBuffer(writeBufferSize); err != nil {
return nil, fmt.Errorf("unable to set udp write buffer size to %d (%w)", writeBufferSize, err)
}

if err := udpConn.SetReadBuffer(4 * 1024 * 1024); err != nil {
panic(err)
readBufferSize := DefaultBufferSize
bufferSize, found, err = tcfg.GetUIntValue("dtls", "readBufferSize")
if err != nil {
return nil, err
}
if found {
readBufferSize = int(bufferSize)
}
if err = udpConn.SetWriteBuffer(readBufferSize); err != nil {
return nil, fmt.Errorf("unable to set udp read buffer size to %d (%w)", readBufferSize, err)
}

conn, closeErr := dtls.Client(udpConn, &addr.UDPAddr, cfg)
Expand Down Expand Up @@ -100,7 +121,11 @@ func DialWithLocalBinding(addr *address, name, localBinding string, i *identity.
log.Debugf("server provided [%d] certificates", len(certs))

var w io.Writer = conn
if bps, ok := getMaxBytesPerSecond(tcfg); ok {
bps, found, err := tcfg.GetInt64Value("dtls", "maxBytesPerSecond")
if err != nil {
return nil, err
}
if found {
log.Infof("limiting DTLS writes to %dB/s", bps)
w = shaper.LimitWriter(conn, time.Second, bps)
}
Expand Down

0 comments on commit 89aec36

Please sign in to comment.