Skip to content

Commit

Permalink
Add PSK support
Browse files Browse the repository at this point in the history
  • Loading branch information
janvrska committed May 13, 2024
1 parent 677ce13 commit 92a164a
Show file tree
Hide file tree
Showing 10 changed files with 507 additions and 148 deletions.
115 changes: 98 additions & 17 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// Use ClientConfig struct to set various client's options and features.
//
// Example:
//
// import (
// "fmt"
// "time"
Expand Down Expand Up @@ -58,12 +59,16 @@ import (
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"

"github.com/patrickmn/go-cache"
"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"golang.org/x/sync/errgroup"
Expand All @@ -78,10 +83,17 @@ import (
type ClientConfig struct {
// UseDTLS controls whether DTLS should be used to secure the connection
// to the MQTT-SN gateway.
UseDTLS bool
Certificate *tls.Certificate
PrivateKey crypto.PrivateKey
CACertificates []*x509.Certificate
UseDTLS bool
Certificate *tls.Certificate
PrivateKey crypto.PrivateKey
UsePSK bool
PSK *cache.Cache
PSKCacheExpiration time.Duration
PSKIdentity string
PSKApiBasicAuthUsername string
PSKApiBasicAuthPassword string
PSKApiEndpoint string
CACertificates []*x509.Certificate
// SelfSigned controls whether the client should use a self-signed
// certificate and key. If SelfSigned is false and UseDTLS is true, you
// must provide CertFile and KeyFile.
Expand Down Expand Up @@ -147,21 +159,24 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err
var certificate *tls.Certificate
var err error

if c.cfg.SelfSigned {
var cert tls.Certificate
cert, err = selfsign.GenerateSelfSigned()
certificate = &cert
} else {
privateKey := c.cfg.PrivateKey
if privateKey == nil {
err = errors.New("private key is missing")
}
if certificate = c.cfg.Certificate; certificate != nil {
certificate.PrivateKey = privateKey
if !c.cfg.UsePSK && c.cfg.UseDTLS {
if c.cfg.SelfSigned {
var cert tls.Certificate
cert, err = selfsign.GenerateSelfSigned()
certificate = &cert
} else {
err = errors.New("TLS certificate is missing")
privateKey := c.cfg.PrivateKey
if privateKey == nil {
err = errors.New("private key is missing")
}
if certificate = c.cfg.Certificate; certificate != nil {
certificate.PrivateKey = privateKey
} else {
err = errors.New("TLS certificate is missing")
}
}
}

if err != nil {
return nil, err
}
Expand All @@ -182,12 +197,34 @@ func (c *Client) connectDTLS(ctx context.Context, address string) (net.Conn, err

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
Certificates: []tls.Certificate{*certificate},
InsecureSkipVerify: c.cfg.Insecure,
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
RootCAs: certPool,
}

if !c.cfg.UsePSK && c.cfg.UseDTLS && certificate != nil {
config.Certificates = []tls.Certificate{*certificate}
}

if c.cfg.UsePSK && c.cfg.UseDTLS {
config.CipherSuites = []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_GCM_SHA256}
config.PSK = func(hint []byte) ([]byte, error) {
psk, ok := c.cfg.PSK.Get(string(hint))
if !ok {
psk, ok = getPSK(string(hint), c.cfg, c.log)
if ok {
c.cfg.PSK.Set(string(hint), psk, c.cfg.PSKCacheExpiration)
return psk.([]byte), nil
} else {
return nil, errors.New("PSK not found")
}
}

return psk.([]byte), nil
}
config.PSKIdentityHint = []byte(c.cfg.PSKIdentity)
}

// Connect to a DTLS server
ctx2, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
Expand Down Expand Up @@ -556,3 +593,47 @@ func (c *Client) Disconnect() error {
return c.group.Wait()
}
}

func getPSK(hint string, clientConfig *ClientConfig, logger util.Logger) ([]byte, bool) {
req, err := http.NewRequest("GET", fmt.Sprintf(clientConfig.PSKApiEndpoint+"/%s", hint), nil)
if err != nil {
logger.Error("Error in creating request: %s", err)
return nil, false
}

req.SetBasicAuth(clientConfig.PSKApiBasicAuthUsername, clientConfig.PSKApiBasicAuthPassword)
client := &http.Client{}
resp, err := client.Do(req)

if err != nil {
logger.Error("Error in sending request: %s", err)
return nil, false
}

defer resp.Body.Close()

if resp.StatusCode == http.StatusNotFound {
logger.Debug("ID not found")
return nil, false
}

body, err := io.ReadAll(resp.Body)
if err != nil {
logger.Error("Error in reading response body: %s", err)
return nil, false
}

type Response struct {
Data map[string][]byte `json:"data"`
}

var response Response

err = json.Unmarshal(body, &response)
if err != nil {
logger.Error("Error in unmarshalling response body: %s", err)
return nil, false
}

return response.Data[hint], true
}
47 changes: 31 additions & 16 deletions cmd/bisquitt-pub/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"syscall"
"time"

"github.com/patrickmn/go-cache"
"github.com/urfave/cli/v2"

snClient "github.com/energomonitor/bisquitt/client"
Expand All @@ -33,13 +34,20 @@ func handleAction() cli.ActionFunc {

useDTLS := c.Bool(DtlsFlag)
useSelfSigned := c.Bool(SelfSignedFlag)
usePSK := c.Bool(PskFlag)
pskCacheExpiration := c.Duration(PskCacheExpirationFlag)

pskIdentity := c.String(PskIdentityFlag)
pskApiBasicAuthUsername := c.String(PskApiBasicAuthUsernameFlag)
pskApiBasicAuthPassword := c.String(PskApiBasicAuthPasswordFlag)
pskApiEndpoint := c.String(PskApiEndpointFlag)
certFile := c.Path(CertFlag)
keyFile := c.Path(KeyFlag)
caFile := c.Path(CAFileFlag)
caPath := c.Path(CAPathFlag)
debug := c.Bool(DebugFlag)

if useDTLS && (certFile == "" || keyFile == "") && !useSelfSigned {
if useDTLS && ((certFile == "" || keyFile == "") && !useSelfSigned) && !usePSK {
return fmt.Errorf(`options "--%s" and "--%s" are mandatory when using DTLS. Use "--%s" to generate self-signed certificate.`,
CertFlag, KeyFlag, SelfSignedFlag)
}
Expand Down Expand Up @@ -132,21 +140,28 @@ func handleAction() cli.ActionFunc {
password := []byte(c.String(PasswordFlag))

clientCfg := &snClient.ClientConfig{
ClientID: clientID,
UseDTLS: useDTLS,
SelfSigned: useSelfSigned,
Insecure: insecure,
Certificate: certificate,
PrivateKey: privateKey,
CACertificates: caCertificates,
RetryDelay: 10 * time.Second,
RetryCount: 4,
ConnectTimeout: 20 * time.Second,
KeepAlive: 60 * time.Second,
CleanSession: true,
PredefinedTopics: predefinedTopics,
User: user,
Password: password,
ClientID: clientID,
UseDTLS: useDTLS,
UsePSK: usePSK,
PSK: cache.New(pskCacheExpiration, 5*time.Minute),
PSKCacheExpiration: pskCacheExpiration,
PSKIdentity: pskIdentity,
PSKApiBasicAuthUsername: pskApiBasicAuthUsername,
PSKApiBasicAuthPassword: pskApiBasicAuthPassword,
PSKApiEndpoint: pskApiEndpoint,
SelfSigned: useSelfSigned,
Insecure: insecure,
Certificate: certificate,
PrivateKey: privateKey,
CACertificates: caCertificates,
RetryDelay: 10 * time.Second,
RetryCount: 4,
ConnectTimeout: 20 * time.Second,
KeepAlive: 60 * time.Second,
CleanSession: true,
PredefinedTopics: predefinedTopics,
User: user,
Password: password,
}

var logger util.Logger
Expand Down
88 changes: 69 additions & 19 deletions cmd/bisquitt-pub/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,39 @@ package main

import (
"fmt"
"time"

"github.com/urfave/cli/v2"

"github.com/energomonitor/bisquitt"
)

const (
HostFlag = "host"
PortFlag = "port"
DtlsFlag = "dtls"
SelfSignedFlag = "self-signed"
CertFlag = "cert"
KeyFlag = "key"
CAFileFlag = "cafile"
CAPathFlag = "capath"
InsecureFlag = "insecure"
DebugFlag = "debug"
TopicFlag = "topic"
MessageFlag = "message"
RetainFlag = "retain"
PredefinedTopicFlag = "predefined-topic"
PredefinedTopicsFileFlag = "predefined-topics-file"
QOSFlag = "qos"
ClientIDFlag = "client-id"
UserFlag = "user"
PasswordFlag = "password"
HostFlag = "host"
PortFlag = "port"
DtlsFlag = "dtls"
SelfSignedFlag = "self-signed"
PskFlag = "psk"
PskCacheExpirationFlag = "psk-cache-expiration"
PskIdentityFlag = "psk-identity"
PskApiBasicAuthUsernameFlag = "psk-api-basic-auth-username"
PskApiBasicAuthPasswordFlag = "psk-api-basic-auth-password"
PskApiEndpointFlag = "psk-api-endpoint"
CertFlag = "cert"
KeyFlag = "key"
CAFileFlag = "cafile"
CAPathFlag = "capath"
InsecureFlag = "insecure"
DebugFlag = "debug"
TopicFlag = "topic"
MessageFlag = "message"
RetainFlag = "retain"
PredefinedTopicFlag = "predefined-topic"
PredefinedTopicsFileFlag = "predefined-topics-file"
QOSFlag = "qos"
ClientIDFlag = "client-id"
UserFlag = "user"
PasswordFlag = "password"
)

func init() {
Expand Down Expand Up @@ -76,6 +83,49 @@ var Application = cli.App{
"SELF_SIGNED",
},
},
&cli.BoolFlag{
Name: PskFlag,
Usage: "use PSK",
EnvVars: []string{
"PSK_ENABLED",
},
},
&cli.DurationFlag{
Name: PskCacheExpirationFlag,
Usage: "PSK cache expiration",
Value: 5 * time.Minute,
EnvVars: []string{
"PSK_CACHE_EXPIRATION",
},
},
&cli.StringFlag{
Name: PskIdentityFlag,
Usage: "PSK identity",
EnvVars: []string{
"PSK_IDENTITY",
},
},
&cli.StringFlag{
Name: PskApiBasicAuthUsernameFlag,
Usage: "PSK API basic auth username",
EnvVars: []string{
"PSK_API_BASIC_AUTH_USERNAME",
},
},
&cli.StringFlag{
Name: PskApiBasicAuthPasswordFlag,
Usage: "PSK API basic auth password",
EnvVars: []string{
"PSK_API_BASIC_AUTH_PASSWORD",
},
},
&cli.StringFlag{
Name: PskApiEndpointFlag,
Usage: "PSK API endpoint",
EnvVars: []string{
"PSK_API_ENDPOINT",
},
},
&cli.PathFlag{
Name: CertFlag,
Usage: "DTLS certificate file",
Expand Down
Loading

0 comments on commit 92a164a

Please sign in to comment.