Skip to content

Commit

Permalink
Fix race conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
DerAndereAndi committed Dec 31, 2023
1 parent cc5a74e commit 50a4423
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 67 deletions.
4 changes: 2 additions & 2 deletions cmd/evse/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ func (h *evse) VisibleRemoteServicesUpdated(service *service.EEBUSService, entri

func (h *evse) ServiceShipIDUpdate(ski string, shipdID string) {}

func (h *evse) ServicePairingDetailUpdate(ski string, detail service.ConnectionStateDetail) {
if ski == remoteSki && detail.State == service.ConnectionStateRemoteDeniedTrust {
func (h *evse) ServicePairingDetailUpdate(ski string, detail *service.ConnectionStateDetail) {
if ski == remoteSki && detail.State() == service.ConnectionStateRemoteDeniedTrust {
fmt.Println("The remote service denied trust. Exiting.")
h.myService.RegisterRemoteSKI(ski, false)
h.myService.CancelPairingWithSKI(ski)
Expand Down
4 changes: 2 additions & 2 deletions cmd/hems/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ func (h *hems) VisibleRemoteServicesUpdated(service *service.EEBUSService, entri

func (h *hems) ServiceShipIDUpdate(ski string, shipdID string) {}

func (h *hems) ServicePairingDetailUpdate(ski string, detail service.ConnectionStateDetail) {
if ski == remoteSki && detail.State == service.ConnectionStateRemoteDeniedTrust {
func (h *hems) ServicePairingDetailUpdate(ski string, detail *service.ConnectionStateDetail) {
if ski == remoteSki && detail.State() == service.ConnectionStateRemoteDeniedTrust {
fmt.Println("The remote service denied trust. Exiting.")
h.myService.RegisterRemoteSKI(ski, false)
h.myService.CancelPairingWithSKI(ski)
Expand Down
64 changes: 29 additions & 35 deletions service/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var connectionInitiationDelayTimeRanges = []connectionInitiationDelayTimeRange{
// interface for reporting data from connectionsHub to the EEBUSService
type ServiceProvider interface {
// report a newly discovered remote EEBUS service
VisibleMDNSRecordsUpdated(entries []MdnsEntry)
VisibleMDNSRecordsUpdated(entries []*MdnsEntry)

// report a connection to a SKI
RemoteSKIConnected(ski string)
Expand All @@ -58,7 +58,7 @@ type ServiceProvider interface {
ServiceShipIDUpdate(ski string, shipID string)

// provides the current handshake state for a given SKI
ServicePairingDetailUpdate(ski string, detail ConnectionStateDetail)
ServicePairingDetailUpdate(ski string, detail *ConnectionStateDetail)

// return if the user is still able to trust the connection
AllowWaitingForTrust(ski string) bool
Expand Down Expand Up @@ -87,7 +87,7 @@ type connectionsHub struct {
mdns MdnsService

// list of currently known/reported mDNS entries
knownMdnsEntries []MdnsEntry
knownMdnsEntries []*MdnsEntry

// the SPINE local device
spineLocalDevice *spine.DeviceLocalImpl
Expand All @@ -104,7 +104,7 @@ func newConnectionsHub(serviceProvider ServiceProvider, mdns MdnsService, spineL
connectionAttemptCounter: make(map[string]int),
connectionAttemptRunning: make(map[string]bool),
remoteServices: make(map[string]*ServiceDetails),
knownMdnsEntries: make([]MdnsEntry, 0),
knownMdnsEntries: make([]*MdnsEntry, 0),
serviceProvider: serviceProvider,
spineLocalDevice: spineLocalDevice,
configuration: configuration,
Expand Down Expand Up @@ -230,8 +230,6 @@ func (h *connectionsHub) AllowWaitingForTrust(ski string) bool {

// Provides the current ship message exchange state for a given SKI and the corresponding error if state is error
func (h *connectionsHub) HandleShipHandshakeStateUpdate(ski string, state ship.ShipState) {
service := h.serviceForSKI(ski)

// overwrite service Paired value
if state.State == ship.SmeHelloStateOk {
h.RegisterRemoteSKI(ski, true)
Expand All @@ -242,13 +240,12 @@ func (h *connectionsHub) HandleShipHandshakeStateUpdate(ski string, state ship.S
pairingState = ConnectionStateError
}

pairingDetail := ConnectionStateDetail{
State: pairingState,
Error: state.Error,
}
pairingDetail := NewConnectionStateDetail(pairingState, state.Error)

service := h.serviceForSKI(ski)

existingDetails := service.ConnectionStateDetail
if existingDetails.State != pairingState || existingDetails.Error != state.Error {
if existingDetails.State() != pairingState || existingDetails.Error() != state.Error {
service.ConnectionStateDetail = pairingDetail

h.serviceProvider.ServicePairingDetailUpdate(ski, pairingDetail)
Expand All @@ -261,16 +258,13 @@ func (h *connectionsHub) HandleShipHandshakeStateUpdate(ski string, state ship.S
//
// ErrNotPaired if the SKI is not in the (to be) paired list
// ErrNoConnectionFound if no connection for the SKI was found
func (h *connectionsHub) PairingDetailForSki(ski string) ConnectionStateDetail {
func (h *connectionsHub) PairingDetailForSki(ski string) *ConnectionStateDetail {
service := h.serviceForSKI(ski)

if conn := h.connectionForSKI(ski); conn != nil {
shipState, shipError := conn.ShipHandshakeState()
state := h.mapShipMessageExchangeState(shipState, ski)
return ConnectionStateDetail{
State: state,
Error: shipError,
}
return NewConnectionStateDetail(state, shipError)
}

return service.ConnectionStateDetail
Expand Down Expand Up @@ -457,8 +451,8 @@ func (h *connectionsHub) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Check if the remote service is paired
service := h.serviceForSKI(remoteService.SKI)
if service.ConnectionStateDetail.State == ConnectionStateQueued {
service.ConnectionStateDetail.State = ConnectionStateReceivedPairingRequest
if service.ConnectionStateDetail.State() == ConnectionStateQueued {
service.ConnectionStateDetail.SetState(ConnectionStateReceivedPairingRequest)
h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail)
}

Expand Down Expand Up @@ -601,7 +595,7 @@ func (h *connectionsHub) serviceForSKI(ski string) *ServiceDetails {
service, ok := h.remoteServices[ski]
if !ok {
service = NewServiceDetails(ski)
service.ConnectionStateDetail.State = ConnectionStateNone
service.ConnectionStateDetail.SetState(ConnectionStateNone)
h.remoteServices[ski] = service
}

Expand All @@ -622,7 +616,7 @@ func (h *connectionsHub) RegisterRemoteSKI(ski string, enable bool) {

h.removeConnectionAttemptCounter(ski)

service.ConnectionStateDetail.State = ConnectionStateNone
service.ConnectionStateDetail.SetState(ConnectionStateNone)

h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail)

Expand All @@ -644,7 +638,7 @@ func (h *connectionsHub) InitiatePairingWithSKI(ski string) {

// locally initiated
service := h.serviceForSKI(ski)
service.ConnectionStateDetail.State = ConnectionStateQueued
service.ConnectionStateDetail.SetState(ConnectionStateQueued)

h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail)

Expand All @@ -663,18 +657,18 @@ func (h *connectionsHub) CancelPairingWithSKI(ski string) {
}

service := h.serviceForSKI(ski)
service.ConnectionStateDetail.State = ConnectionStateNone
service.ConnectionStateDetail.SetState(ConnectionStateNone)
service.Trusted = false

h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail)
}

// Process reported mDNS services
func (h *connectionsHub) ReportMdnsEntries(entries map[string]MdnsEntry) {
func (h *connectionsHub) ReportMdnsEntries(entries map[string]*MdnsEntry) {
h.muxMdns.Lock()
defer h.muxMdns.Unlock()

var mdnsEntries []MdnsEntry
var mdnsEntries []*MdnsEntry

for ski, entry := range entries {
mdnsEntries = append(mdnsEntries, entry)
Expand All @@ -686,8 +680,8 @@ func (h *connectionsHub) ReportMdnsEntries(entries map[string]MdnsEntry) {

// Check if the remote service is paired or queued for connection
service := h.serviceForSKI(ski)
pairingState := service.ConnectionStateDetail.State
if !h.IsRemoteServiceForSKIPaired(ski) && pairingState != ConnectionStateQueued {
if !h.IsRemoteServiceForSKIPaired(ski) &&
service.ConnectionStateDetail.State() != ConnectionStateQueued {
continue
}

Expand Down Expand Up @@ -715,7 +709,7 @@ func (h *connectionsHub) ReportMdnsEntries(entries map[string]MdnsEntry) {
}

// coordinate connection initiation attempts to a remove service
func (h *connectionsHub) coordinateConnectionInitations(ski string, entry MdnsEntry) {
func (h *connectionsHub) coordinateConnectionInitations(ski string, entry *MdnsEntry) {
if h.isConnectionAttemptRunning(ski) {
return
}
Expand All @@ -725,7 +719,7 @@ func (h *connectionsHub) coordinateConnectionInitations(ski string, entry MdnsEn
counter, duration := h.getConnectionInitiationDelayTime(ski)

service := h.serviceForSKI(ski)
if service.ConnectionStateDetail.State == ConnectionStateQueued {
if service.ConnectionStateDetail.State() == ConnectionStateQueued {
go h.prepareConnectionInitation(ski, counter, entry)
return
}
Expand All @@ -744,12 +738,9 @@ func (h *connectionsHub) coordinateConnectionInitations(ski string, entry MdnsEn

// invoked by coordinateConnectionInitations either with a delay or directly
// when initating a pairing process
func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, entry MdnsEntry) {
func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, entry *MdnsEntry) {
h.setConnectionAttemptRunning(ski, false)

// check if the remoteService still exists
service := h.serviceForSKI(ski)

// check if the current counter is still the same, otherwise this counter is irrelevant
currentCounter, exists := h.getCurrentConnectionAttemptCounter(ski)
if !exists || currentCounter != counter {
Expand All @@ -758,7 +749,7 @@ func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, ent

// connection attempt is not relevant if the device is no longer paired
// or it is not queued for pairing
pairingState := h.serviceForSKI(ski).ConnectionStateDetail.State
pairingState := h.serviceForSKI(ski).ConnectionStateDetail.State()
if !h.IsRemoteServiceForSKIPaired(ski) && pairingState != ConnectionStateQueued {
return
}
Expand All @@ -769,21 +760,24 @@ func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, ent
}

// now initiate the connection
// check if the remoteService still exists
service := h.serviceForSKI(ski)

if success := h.initateConnection(service, entry); !success {
h.checkRestartMdnsSearch()
}
}

// attempt to establish a connection to a remote service
// returns true if successful
func (h *connectionsHub) initateConnection(remoteService *ServiceDetails, entry MdnsEntry) bool {
func (h *connectionsHub) initateConnection(remoteService *ServiceDetails, entry *MdnsEntry) bool {
var err error

// try connecting via an IP address first
for _, address := range entry.Addresses {
// connection attempt is not relevant if the device is no longer paired
// or it is not queued for pairing
pairingState := h.serviceForSKI(remoteService.SKI).ConnectionStateDetail.State
pairingState := h.serviceForSKI(remoteService.SKI).ConnectionStateDetail.State()
if !h.IsRemoteServiceForSKIPaired(remoteService.SKI) && pairingState != ConnectionStateQueued {
return false
}
Expand Down
74 changes: 57 additions & 17 deletions service/mdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type MdnsEntry struct {

// implemented by hubConnection, used by mdns
type MdnsSearch interface {
ReportMdnsEntries(entries map[string]MdnsEntry)
ReportMdnsEntries(entries map[string]*MdnsEntry)
}

// implemented by mdns, used by hubConnection
Expand All @@ -56,21 +56,22 @@ type mdnsManager struct {
cancelChan chan bool

// the currently available mDNS entries with the SKI as the key in the map
entries map[string]MdnsEntry
entries map[string]*MdnsEntry

// the registered callback, only connectionsHub is using this
searchDelegate MdnsSearch

mdnsProvider mdns.MdnsProvider

mux sync.Mutex
mux sync.Mutex
entriesMux sync.Mutex
}

func newMDNS(ski string, configuration *Configuration) *mdnsManager {
m := &mdnsManager{
ski: ski,
configuration: configuration,
entries: make(map[string]MdnsEntry),
entries: make(map[string]*MdnsEntry),
cancelChan: make(chan bool),
}

Expand Down Expand Up @@ -203,6 +204,49 @@ func (m *mdnsManager) setIsSearchingServices(enable bool) {
m.isSearchingServices = enable
}

func (m *mdnsManager) mdnsEntries() map[string]*MdnsEntry {
m.entriesMux.Lock()
defer m.entriesMux.Unlock()

return m.entries
}

func (m *mdnsManager) copyMdnsEntries() map[string]*MdnsEntry {
m.entriesMux.Lock()
defer m.entriesMux.Unlock()

mdnsEntries := make(map[string]*MdnsEntry)
for k, v := range m.entries {
newEntry := &MdnsEntry{}
util.DeepCopy[*MdnsEntry](v, newEntry)
mdnsEntries[k] = newEntry
}

return mdnsEntries
}

func (m *mdnsManager) mdnsEntry(ski string) (*MdnsEntry, bool) {
m.entriesMux.Lock()
defer m.entriesMux.Unlock()

entry, ok := m.entries[ski]
return entry, ok
}

func (m *mdnsManager) setMdnsEntry(ski string, entry *MdnsEntry) {
m.entriesMux.Lock()
defer m.entriesMux.Unlock()

m.entries[ski] = entry
}

func (m *mdnsManager) removeMdnsEntry(ski string) {
m.entriesMux.Lock()
defer m.entriesMux.Unlock()

delete(m.entries, ski)
}

// Register a callback to be invoked for found mDNS entries
func (m *mdnsManager) RegisterMdnsSearch(cb MdnsSearch) {
m.mux.Lock()
Expand All @@ -218,12 +262,12 @@ func (m *mdnsManager) RegisterMdnsSearch(cb MdnsSearch) {
}

// do we already know some entries?
if len(m.entries) == 0 {
if len(m.mdnsEntries()) == 0 {
return
}

// maybe entries are already found
mdnsEntries := m.entries
mdnsEntries := m.copyMdnsEntries()

go m.searchDelegate.ReportMdnsEntries(mdnsEntries)
}
Expand Down Expand Up @@ -315,18 +359,17 @@ func (m *mdnsManager) processMdnsEntry(elements map[string]string, name, host st

updated := true

_, exists := m.entries[ski]
entry, exists := m.mdnsEntry(ski)

if remove && exists {
// remove
// there will be a remove for each address with avahi, but we'll delete it right away
delete(m.entries, ski)
m.removeMdnsEntry(ski)
} else if exists {
// update
updated = false

// avahi sends an item for each network address, merge them
entry := m.entries[ski]

// we assume only network addresses are added
for _, address := range addresses {
Expand All @@ -346,10 +389,10 @@ func (m *mdnsManager) processMdnsEntry(elements map[string]string, name, host st
}
}

m.entries[ski] = entry
m.setMdnsEntry(ski, entry)
} else if !exists && !remove {
// new
newEntry := MdnsEntry{
newEntry := &MdnsEntry{
Name: name,
Ski: ski,
Identifier: identifier,
Expand All @@ -362,18 +405,15 @@ func (m *mdnsManager) processMdnsEntry(elements map[string]string, name, host st
Port: port,
Addresses: addresses,
}
m.entries[ski] = newEntry
m.setMdnsEntry(ski, newEntry)

logging.Log.Debug("ski:", ski, "name:", name, "brand:", brand, "model:", model, "typ:", deviceType, "identifier:", identifier, "register:", register, "host:", host, "port:", port, "addresses:", addresses)
} else {
return
}

if m.searchDelegate != nil && updated {
mdnsEntries := make(map[string]MdnsEntry)
for k, v := range m.entries {
mdnsEntries[k] = v
}
go m.searchDelegate.ReportMdnsEntries(mdnsEntries)
entries := m.copyMdnsEntries()
go m.searchDelegate.ReportMdnsEntries(entries)
}
}
Loading

0 comments on commit 50a4423

Please sign in to comment.