From 50a44231039b9419bd4ec02da826d3e696b86ad3 Mon Sep 17 00:00:00 2001 From: Andreas Linde Date: Sun, 31 Dec 2023 15:48:49 +0100 Subject: [PATCH] Fix race conditions --- cmd/evse/main.go | 4 +-- cmd/hems/main.go | 4 +-- service/hub.go | 64 +++++++++++++++------------------ service/mdns.go | 74 ++++++++++++++++++++++++++++++--------- service/mock_hub_test.go | 4 +-- service/mock_mdns_test.go | 2 +- service/service.go | 8 ++--- service/types.go | 48 ++++++++++++++++++++++--- 8 files changed, 141 insertions(+), 67 deletions(-) diff --git a/cmd/evse/main.go b/cmd/evse/main.go index e2f1a44a..6d8440c2 100644 --- a/cmd/evse/main.go +++ b/cmd/evse/main.go @@ -100,8 +100,8 @@ func (h *evse) VisibleRemoteServicesUpdated(service *service.EEBUSService, entri func (h *evse) ServiceShipIDUpdate(ski string, shipdID string) {} -func (h *evse) ServicePairingDetailUpdate(ski string, detail service.ConnectionStateDetail) { - if ski == remoteSki && detail.State == service.ConnectionStateRemoteDeniedTrust { +func (h *evse) ServicePairingDetailUpdate(ski string, detail *service.ConnectionStateDetail) { + if ski == remoteSki && detail.State() == service.ConnectionStateRemoteDeniedTrust { fmt.Println("The remote service denied trust. Exiting.") h.myService.RegisterRemoteSKI(ski, false) h.myService.CancelPairingWithSKI(ski) diff --git a/cmd/hems/main.go b/cmd/hems/main.go index b09b7b77..54f49bf2 100644 --- a/cmd/hems/main.go +++ b/cmd/hems/main.go @@ -100,8 +100,8 @@ func (h *hems) VisibleRemoteServicesUpdated(service *service.EEBUSService, entri func (h *hems) ServiceShipIDUpdate(ski string, shipdID string) {} -func (h *hems) ServicePairingDetailUpdate(ski string, detail service.ConnectionStateDetail) { - if ski == remoteSki && detail.State == service.ConnectionStateRemoteDeniedTrust { +func (h *hems) ServicePairingDetailUpdate(ski string, detail *service.ConnectionStateDetail) { + if ski == remoteSki && detail.State() == service.ConnectionStateRemoteDeniedTrust { fmt.Println("The remote service denied trust. Exiting.") h.myService.RegisterRemoteSKI(ski, false) h.myService.CancelPairingWithSKI(ski) diff --git a/service/hub.go b/service/hub.go index 832bbb1c..3ef0c91a 100644 --- a/service/hub.go +++ b/service/hub.go @@ -45,7 +45,7 @@ var connectionInitiationDelayTimeRanges = []connectionInitiationDelayTimeRange{ // interface for reporting data from connectionsHub to the EEBUSService type ServiceProvider interface { // report a newly discovered remote EEBUS service - VisibleMDNSRecordsUpdated(entries []MdnsEntry) + VisibleMDNSRecordsUpdated(entries []*MdnsEntry) // report a connection to a SKI RemoteSKIConnected(ski string) @@ -58,7 +58,7 @@ type ServiceProvider interface { ServiceShipIDUpdate(ski string, shipID string) // provides the current handshake state for a given SKI - ServicePairingDetailUpdate(ski string, detail ConnectionStateDetail) + ServicePairingDetailUpdate(ski string, detail *ConnectionStateDetail) // return if the user is still able to trust the connection AllowWaitingForTrust(ski string) bool @@ -87,7 +87,7 @@ type connectionsHub struct { mdns MdnsService // list of currently known/reported mDNS entries - knownMdnsEntries []MdnsEntry + knownMdnsEntries []*MdnsEntry // the SPINE local device spineLocalDevice *spine.DeviceLocalImpl @@ -104,7 +104,7 @@ func newConnectionsHub(serviceProvider ServiceProvider, mdns MdnsService, spineL connectionAttemptCounter: make(map[string]int), connectionAttemptRunning: make(map[string]bool), remoteServices: make(map[string]*ServiceDetails), - knownMdnsEntries: make([]MdnsEntry, 0), + knownMdnsEntries: make([]*MdnsEntry, 0), serviceProvider: serviceProvider, spineLocalDevice: spineLocalDevice, configuration: configuration, @@ -230,8 +230,6 @@ func (h *connectionsHub) AllowWaitingForTrust(ski string) bool { // Provides the current ship message exchange state for a given SKI and the corresponding error if state is error func (h *connectionsHub) HandleShipHandshakeStateUpdate(ski string, state ship.ShipState) { - service := h.serviceForSKI(ski) - // overwrite service Paired value if state.State == ship.SmeHelloStateOk { h.RegisterRemoteSKI(ski, true) @@ -242,13 +240,12 @@ func (h *connectionsHub) HandleShipHandshakeStateUpdate(ski string, state ship.S pairingState = ConnectionStateError } - pairingDetail := ConnectionStateDetail{ - State: pairingState, - Error: state.Error, - } + pairingDetail := NewConnectionStateDetail(pairingState, state.Error) + + service := h.serviceForSKI(ski) existingDetails := service.ConnectionStateDetail - if existingDetails.State != pairingState || existingDetails.Error != state.Error { + if existingDetails.State() != pairingState || existingDetails.Error() != state.Error { service.ConnectionStateDetail = pairingDetail h.serviceProvider.ServicePairingDetailUpdate(ski, pairingDetail) @@ -261,16 +258,13 @@ func (h *connectionsHub) HandleShipHandshakeStateUpdate(ski string, state ship.S // // ErrNotPaired if the SKI is not in the (to be) paired list // ErrNoConnectionFound if no connection for the SKI was found -func (h *connectionsHub) PairingDetailForSki(ski string) ConnectionStateDetail { +func (h *connectionsHub) PairingDetailForSki(ski string) *ConnectionStateDetail { service := h.serviceForSKI(ski) if conn := h.connectionForSKI(ski); conn != nil { shipState, shipError := conn.ShipHandshakeState() state := h.mapShipMessageExchangeState(shipState, ski) - return ConnectionStateDetail{ - State: state, - Error: shipError, - } + return NewConnectionStateDetail(state, shipError) } return service.ConnectionStateDetail @@ -457,8 +451,8 @@ func (h *connectionsHub) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Check if the remote service is paired service := h.serviceForSKI(remoteService.SKI) - if service.ConnectionStateDetail.State == ConnectionStateQueued { - service.ConnectionStateDetail.State = ConnectionStateReceivedPairingRequest + if service.ConnectionStateDetail.State() == ConnectionStateQueued { + service.ConnectionStateDetail.SetState(ConnectionStateReceivedPairingRequest) h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail) } @@ -601,7 +595,7 @@ func (h *connectionsHub) serviceForSKI(ski string) *ServiceDetails { service, ok := h.remoteServices[ski] if !ok { service = NewServiceDetails(ski) - service.ConnectionStateDetail.State = ConnectionStateNone + service.ConnectionStateDetail.SetState(ConnectionStateNone) h.remoteServices[ski] = service } @@ -622,7 +616,7 @@ func (h *connectionsHub) RegisterRemoteSKI(ski string, enable bool) { h.removeConnectionAttemptCounter(ski) - service.ConnectionStateDetail.State = ConnectionStateNone + service.ConnectionStateDetail.SetState(ConnectionStateNone) h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail) @@ -644,7 +638,7 @@ func (h *connectionsHub) InitiatePairingWithSKI(ski string) { // locally initiated service := h.serviceForSKI(ski) - service.ConnectionStateDetail.State = ConnectionStateQueued + service.ConnectionStateDetail.SetState(ConnectionStateQueued) h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail) @@ -663,18 +657,18 @@ func (h *connectionsHub) CancelPairingWithSKI(ski string) { } service := h.serviceForSKI(ski) - service.ConnectionStateDetail.State = ConnectionStateNone + service.ConnectionStateDetail.SetState(ConnectionStateNone) service.Trusted = false h.serviceProvider.ServicePairingDetailUpdate(ski, service.ConnectionStateDetail) } // Process reported mDNS services -func (h *connectionsHub) ReportMdnsEntries(entries map[string]MdnsEntry) { +func (h *connectionsHub) ReportMdnsEntries(entries map[string]*MdnsEntry) { h.muxMdns.Lock() defer h.muxMdns.Unlock() - var mdnsEntries []MdnsEntry + var mdnsEntries []*MdnsEntry for ski, entry := range entries { mdnsEntries = append(mdnsEntries, entry) @@ -686,8 +680,8 @@ func (h *connectionsHub) ReportMdnsEntries(entries map[string]MdnsEntry) { // Check if the remote service is paired or queued for connection service := h.serviceForSKI(ski) - pairingState := service.ConnectionStateDetail.State - if !h.IsRemoteServiceForSKIPaired(ski) && pairingState != ConnectionStateQueued { + if !h.IsRemoteServiceForSKIPaired(ski) && + service.ConnectionStateDetail.State() != ConnectionStateQueued { continue } @@ -715,7 +709,7 @@ func (h *connectionsHub) ReportMdnsEntries(entries map[string]MdnsEntry) { } // coordinate connection initiation attempts to a remove service -func (h *connectionsHub) coordinateConnectionInitations(ski string, entry MdnsEntry) { +func (h *connectionsHub) coordinateConnectionInitations(ski string, entry *MdnsEntry) { if h.isConnectionAttemptRunning(ski) { return } @@ -725,7 +719,7 @@ func (h *connectionsHub) coordinateConnectionInitations(ski string, entry MdnsEn counter, duration := h.getConnectionInitiationDelayTime(ski) service := h.serviceForSKI(ski) - if service.ConnectionStateDetail.State == ConnectionStateQueued { + if service.ConnectionStateDetail.State() == ConnectionStateQueued { go h.prepareConnectionInitation(ski, counter, entry) return } @@ -744,12 +738,9 @@ func (h *connectionsHub) coordinateConnectionInitations(ski string, entry MdnsEn // invoked by coordinateConnectionInitations either with a delay or directly // when initating a pairing process -func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, entry MdnsEntry) { +func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, entry *MdnsEntry) { h.setConnectionAttemptRunning(ski, false) - // check if the remoteService still exists - service := h.serviceForSKI(ski) - // check if the current counter is still the same, otherwise this counter is irrelevant currentCounter, exists := h.getCurrentConnectionAttemptCounter(ski) if !exists || currentCounter != counter { @@ -758,7 +749,7 @@ func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, ent // connection attempt is not relevant if the device is no longer paired // or it is not queued for pairing - pairingState := h.serviceForSKI(ski).ConnectionStateDetail.State + pairingState := h.serviceForSKI(ski).ConnectionStateDetail.State() if !h.IsRemoteServiceForSKIPaired(ski) && pairingState != ConnectionStateQueued { return } @@ -769,6 +760,9 @@ func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, ent } // now initiate the connection + // check if the remoteService still exists + service := h.serviceForSKI(ski) + if success := h.initateConnection(service, entry); !success { h.checkRestartMdnsSearch() } @@ -776,14 +770,14 @@ func (h *connectionsHub) prepareConnectionInitation(ski string, counter int, ent // attempt to establish a connection to a remote service // returns true if successful -func (h *connectionsHub) initateConnection(remoteService *ServiceDetails, entry MdnsEntry) bool { +func (h *connectionsHub) initateConnection(remoteService *ServiceDetails, entry *MdnsEntry) bool { var err error // try connecting via an IP address first for _, address := range entry.Addresses { // connection attempt is not relevant if the device is no longer paired // or it is not queued for pairing - pairingState := h.serviceForSKI(remoteService.SKI).ConnectionStateDetail.State + pairingState := h.serviceForSKI(remoteService.SKI).ConnectionStateDetail.State() if !h.IsRemoteServiceForSKIPaired(remoteService.SKI) && pairingState != ConnectionStateQueued { return false } diff --git a/service/mdns.go b/service/mdns.go index c377f70d..7649b277 100644 --- a/service/mdns.go +++ b/service/mdns.go @@ -33,7 +33,7 @@ type MdnsEntry struct { // implemented by hubConnection, used by mdns type MdnsSearch interface { - ReportMdnsEntries(entries map[string]MdnsEntry) + ReportMdnsEntries(entries map[string]*MdnsEntry) } // implemented by mdns, used by hubConnection @@ -56,21 +56,22 @@ type mdnsManager struct { cancelChan chan bool // the currently available mDNS entries with the SKI as the key in the map - entries map[string]MdnsEntry + entries map[string]*MdnsEntry // the registered callback, only connectionsHub is using this searchDelegate MdnsSearch mdnsProvider mdns.MdnsProvider - mux sync.Mutex + mux sync.Mutex + entriesMux sync.Mutex } func newMDNS(ski string, configuration *Configuration) *mdnsManager { m := &mdnsManager{ ski: ski, configuration: configuration, - entries: make(map[string]MdnsEntry), + entries: make(map[string]*MdnsEntry), cancelChan: make(chan bool), } @@ -203,6 +204,49 @@ func (m *mdnsManager) setIsSearchingServices(enable bool) { m.isSearchingServices = enable } +func (m *mdnsManager) mdnsEntries() map[string]*MdnsEntry { + m.entriesMux.Lock() + defer m.entriesMux.Unlock() + + return m.entries +} + +func (m *mdnsManager) copyMdnsEntries() map[string]*MdnsEntry { + m.entriesMux.Lock() + defer m.entriesMux.Unlock() + + mdnsEntries := make(map[string]*MdnsEntry) + for k, v := range m.entries { + newEntry := &MdnsEntry{} + util.DeepCopy[*MdnsEntry](v, newEntry) + mdnsEntries[k] = newEntry + } + + return mdnsEntries +} + +func (m *mdnsManager) mdnsEntry(ski string) (*MdnsEntry, bool) { + m.entriesMux.Lock() + defer m.entriesMux.Unlock() + + entry, ok := m.entries[ski] + return entry, ok +} + +func (m *mdnsManager) setMdnsEntry(ski string, entry *MdnsEntry) { + m.entriesMux.Lock() + defer m.entriesMux.Unlock() + + m.entries[ski] = entry +} + +func (m *mdnsManager) removeMdnsEntry(ski string) { + m.entriesMux.Lock() + defer m.entriesMux.Unlock() + + delete(m.entries, ski) +} + // Register a callback to be invoked for found mDNS entries func (m *mdnsManager) RegisterMdnsSearch(cb MdnsSearch) { m.mux.Lock() @@ -218,12 +262,12 @@ func (m *mdnsManager) RegisterMdnsSearch(cb MdnsSearch) { } // do we already know some entries? - if len(m.entries) == 0 { + if len(m.mdnsEntries()) == 0 { return } // maybe entries are already found - mdnsEntries := m.entries + mdnsEntries := m.copyMdnsEntries() go m.searchDelegate.ReportMdnsEntries(mdnsEntries) } @@ -315,18 +359,17 @@ func (m *mdnsManager) processMdnsEntry(elements map[string]string, name, host st updated := true - _, exists := m.entries[ski] + entry, exists := m.mdnsEntry(ski) if remove && exists { // remove // there will be a remove for each address with avahi, but we'll delete it right away - delete(m.entries, ski) + m.removeMdnsEntry(ski) } else if exists { // update updated = false // avahi sends an item for each network address, merge them - entry := m.entries[ski] // we assume only network addresses are added for _, address := range addresses { @@ -346,10 +389,10 @@ func (m *mdnsManager) processMdnsEntry(elements map[string]string, name, host st } } - m.entries[ski] = entry + m.setMdnsEntry(ski, entry) } else if !exists && !remove { // new - newEntry := MdnsEntry{ + newEntry := &MdnsEntry{ Name: name, Ski: ski, Identifier: identifier, @@ -362,7 +405,7 @@ func (m *mdnsManager) processMdnsEntry(elements map[string]string, name, host st Port: port, Addresses: addresses, } - m.entries[ski] = newEntry + m.setMdnsEntry(ski, newEntry) logging.Log.Debug("ski:", ski, "name:", name, "brand:", brand, "model:", model, "typ:", deviceType, "identifier:", identifier, "register:", register, "host:", host, "port:", port, "addresses:", addresses) } else { @@ -370,10 +413,7 @@ func (m *mdnsManager) processMdnsEntry(elements map[string]string, name, host st } if m.searchDelegate != nil && updated { - mdnsEntries := make(map[string]MdnsEntry) - for k, v := range m.entries { - mdnsEntries[k] = v - } - go m.searchDelegate.ReportMdnsEntries(mdnsEntries) + entries := m.copyMdnsEntries() + go m.searchDelegate.ReportMdnsEntries(entries) } } diff --git a/service/mock_hub_test.go b/service/mock_hub_test.go index 25ccdb98..99f8b912 100644 --- a/service/mock_hub_test.go +++ b/service/mock_hub_test.go @@ -72,7 +72,7 @@ func (mr *MockServiceProviderMockRecorder) RemoteSKIDisconnected(arg0 interface{ } // ServicePairingDetailUpdate mocks base method. -func (m *MockServiceProvider) ServicePairingDetailUpdate(arg0 string, arg1 ConnectionStateDetail) { +func (m *MockServiceProvider) ServicePairingDetailUpdate(arg0 string, arg1 *ConnectionStateDetail) { m.ctrl.T.Helper() m.ctrl.Call(m, "ServicePairingDetailUpdate", arg0, arg1) } @@ -96,7 +96,7 @@ func (mr *MockServiceProviderMockRecorder) ServiceShipIDUpdate(arg0, arg1 interf } // VisibleMDNSRecordsUpdated mocks base method. -func (m *MockServiceProvider) VisibleMDNSRecordsUpdated(arg0 []MdnsEntry) { +func (m *MockServiceProvider) VisibleMDNSRecordsUpdated(arg0 []*MdnsEntry) { m.ctrl.T.Helper() m.ctrl.Call(m, "VisibleMDNSRecordsUpdated", arg0) } diff --git a/service/mock_mdns_test.go b/service/mock_mdns_test.go index 45d6efc3..2207663f 100644 --- a/service/mock_mdns_test.go +++ b/service/mock_mdns_test.go @@ -34,7 +34,7 @@ func (m *MockMdnsSearch) EXPECT() *MockMdnsSearchMockRecorder { } // ReportMdnsEntries mocks base method. -func (m *MockMdnsSearch) ReportMdnsEntries(arg0 map[string]MdnsEntry) { +func (m *MockMdnsSearch) ReportMdnsEntries(arg0 map[string]*MdnsEntry) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReportMdnsEntries", arg0) } diff --git a/service/service.go b/service/service.go index cf164be0..15ea9359 100644 --- a/service/service.go +++ b/service/service.go @@ -38,7 +38,7 @@ type EEBUSServiceHandler interface { // Provides the current pairing state for the remote service // This is called whenever the state changes and can be used to // provide user information for the pairing/connection process - ServicePairingDetailUpdate(ski string, detail ConnectionStateDetail) + ServicePairingDetailUpdate(ski string, detail *ConnectionStateDetail) // return if the user is still able to trust the connection AllowWaitingForTrust(ski string) bool @@ -73,7 +73,7 @@ func NewEEBUSService(configuration *Configuration, serviceHandler EEBUSServiceHa var _ ServiceProvider = (*EEBUSService)(nil) -func (s *EEBUSService) VisibleMDNSRecordsUpdated(entries []MdnsEntry) { +func (s *EEBUSService) VisibleMDNSRecordsUpdated(entries []*MdnsEntry) { var remoteServices []RemoteService for _, entry := range entries { @@ -109,12 +109,12 @@ func (s *EEBUSService) ServiceShipIDUpdate(ski string, shipdID string) { // Provides the current pairing state for the remote service // This is called whenever the state changes and can be used to // provide user information for the pairing/connection process -func (s *EEBUSService) ServicePairingDetailUpdate(ski string, detail ConnectionStateDetail) { +func (s *EEBUSService) ServicePairingDetailUpdate(ski string, detail *ConnectionStateDetail) { s.serviceHandler.ServicePairingDetailUpdate(ski, detail) } // Get the current pairing details for a given SKI -func (s *EEBUSService) PairingDetailForSki(ski string) ConnectionStateDetail { +func (s *EEBUSService) PairingDetailForSki(ski string) *ConnectionStateDetail { return s.connectionsHub.PairingDetailForSki(ski) } diff --git a/service/types.go b/service/types.go index ee710041..aa537f54 100644 --- a/service/types.go +++ b/service/types.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "errors" "fmt" + "sync" "time" "github.com/enbility/eebus-go/spine/model" @@ -30,8 +31,45 @@ const ( // the connection state of a service and error if applicable type ConnectionStateDetail struct { - State ConnectionState - Error error + state ConnectionState + error error + + mux sync.Mutex +} + +func NewConnectionStateDetail(state ConnectionState, err error) *ConnectionStateDetail { + return &ConnectionStateDetail{ + state: state, + error: err, + } +} + +func (c *ConnectionStateDetail) State() ConnectionState { + c.mux.Lock() + defer c.mux.Unlock() + + return c.state +} + +func (c *ConnectionStateDetail) SetState(state ConnectionState) { + c.mux.Lock() + defer c.mux.Unlock() + + c.state = state +} + +func (c *ConnectionStateDetail) Error() error { + c.mux.Lock() + defer c.mux.Unlock() + + return c.error +} + +func (c *ConnectionStateDetail) SetError(err error) { + c.mux.Lock() + defer c.mux.Unlock() + + c.error = err } // generic service details about the local or any remote service @@ -63,13 +101,15 @@ type ServiceDetails struct { Trusted bool // the current connection state details - ConnectionStateDetail ConnectionStateDetail + ConnectionStateDetail *ConnectionStateDetail } // create a new ServiceDetails record with a SKI func NewServiceDetails(ski string) *ServiceDetails { + connState := NewConnectionStateDetail(ConnectionStateNone, nil) service := &ServiceDetails{ - SKI: util.NormalizeSKI(ski), // standardize the provided SKI strings + SKI: util.NormalizeSKI(ski), // standardize the provided SKI strings + ConnectionStateDetail: connState, } return service