diff --git a/agent.go b/agent.go index b9268b8b..0364de70 100644 --- a/agent.go +++ b/agent.go @@ -145,6 +145,8 @@ type Agent struct { insecureSkipVerify bool proxyDialer proxy.Dialer + + enableUseCandidateCheckPriority bool } // NewAgent creates a new Agent @@ -219,6 +221,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit disableActiveTCP: config.DisableActiveTCP, userBindingRequestHandler: config.BindingRequestHandler, + + enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority, } a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange, done: make(chan struct{})} a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate, done: make(chan struct{})} @@ -1219,3 +1223,7 @@ func (a *Agent) setGatheringState(newState GatheringState) error { <-done return nil } + +func (a *Agent) needsToCheckPriorityOnNominated() bool { + return !a.lite || a.enableUseCandidateCheckPriority +} diff --git a/agent_config.go b/agent_config.go index 93e88890..ad3dd497 100644 --- a/agent_config.go +++ b/agent_config.go @@ -200,6 +200,13 @@ type AgentConfig struct { // * Implement draft-thatcher-ice-renomination // * Implement custom CandidatePair switching logic BindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool + + // EnableUseCandidateCheckPriority can be used to enable checking for equal or higher priority to + // switch selected candidate pair if the peer requests USE-CANDIDATE and agent is a lite agent. + // This is disabled by default, i. e. when peer requests USE-CANDIDATE, the selected pair will be + // switched to that irrespective of relative priority between current selected pair + // and priority of the pair being switched to. + EnableUseCandidateCheckPriority bool } // initWithDefaults populates an agent and falls back to defaults if fields are unset diff --git a/agent_test.go b/agent_test.go index 168fcddb..4b64976e 100644 --- a/agent_test.go +++ b/agent_test.go @@ -1844,99 +1844,148 @@ func TestAcceptAggressiveNomination(t *testing.T) { require.NoError(t, wan.Start()) - aNotifier, aConnected := onConnected() - bNotifier, bConnected := onConnected() - - KeepaliveInterval := time.Hour - cfg0 := &AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, - MulticastDNSMode: MulticastDNSModeDisabled, - Net: net0, - - KeepaliveInterval: &KeepaliveInterval, - CheckInterval: &KeepaliveInterval, - AcceptAggressiveNomination: true, - } - - var aAgent, bAgent *Agent - aAgent, err = NewAgent(cfg0) - require.NoError(t, err) - defer func() { - require.NoError(t, aAgent.Close()) - }() - require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) - - cfg1 := &AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, - MulticastDNSMode: MulticastDNSModeDisabled, - Net: net1, - KeepaliveInterval: &KeepaliveInterval, - CheckInterval: &KeepaliveInterval, - } - - bAgent, err = NewAgent(cfg1) - require.NoError(t, err) - defer func() { - require.NoError(t, bAgent.Close()) - }() - require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) - - connect(aAgent, bAgent) - - // Ensure pair selected - // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair - <-aConnected - <-bConnected + testCases := []struct { + name string + isLite bool + enableUseCandidateCheckPriority bool + useHigherPriority bool + isExpectedToSwitch bool + }{ + {"should accept higher priority - full agent", false, false, true, true}, + {"should not accept lower priority - full agent", false, false, false, false}, + {"should accept higher priority - no use-candidate priority check - lite agent", true, false, true, true}, + {"should accept lower priority - no use-candidate priority check - lite agent", true, false, false, true}, + {"should accept higher priority - use-candidate priority check - lite agent", true, true, true, true}, + {"should not accept lower priority - use-candidate priority check - lite agent", true, true, false, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + aNotifier, aConnected := onConnected() + bNotifier, bConnected := onConnected() + + KeepaliveInterval := time.Hour + cfg0 := &AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, + MulticastDNSMode: MulticastDNSModeDisabled, + Net: net0, + KeepaliveInterval: &KeepaliveInterval, + CheckInterval: &KeepaliveInterval, + Lite: tc.isLite, + EnableUseCandidateCheckPriority: tc.enableUseCandidateCheckPriority, + } + if tc.isLite { + cfg0.CandidateTypes = []CandidateType{CandidateTypeHost} + } - // Send new USE-CANDIDATE message with higher priority to update the selected pair - buildMsg := func(class stun.MessageClass, username, key string, priority uint32) *stun.Message { - msg, err1 := stun.Build(stun.NewType(stun.MethodBinding, class), stun.TransactionID, - stun.NewUsername(username), - stun.NewShortTermIntegrity(key), - UseCandidate(), - PriorityAttr(priority), - stun.Fingerprint, - ) - require.NoError(t, err1) + var aAgent, bAgent *Agent + aAgent, err = NewAgent(cfg0) + require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() + require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) + + cfg1 := &AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, + MulticastDNSMode: MulticastDNSModeDisabled, + Net: net1, + KeepaliveInterval: &KeepaliveInterval, + CheckInterval: &KeepaliveInterval, + } - return msg - } + bAgent, err = NewAgent(cfg1) + require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() + require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) + + connect(aAgent, bAgent) + + // Ensure pair selected + // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair + <-aConnected + <-bConnected + + // Send new USE-CANDIDATE message with priority to update the selected pair + buildMsg := func(class stun.MessageClass, username, key string, priority uint32) *stun.Message { + msg, err1 := stun.Build(stun.NewType(stun.MethodBinding, class), stun.TransactionID, + stun.NewUsername(username), + stun.NewShortTermIntegrity(key), + UseCandidate(), + PriorityAttr(priority), + stun.Fingerprint, + ) + require.NoError(t, err1) + + return msg + } - selectedCh := make(chan Candidate, 1) - var expectNewSelectedCandidate Candidate - err = aAgent.OnSelectedCandidatePairChange(func(_, remote Candidate) { - selectedCh <- remote - }) - require.NoError(t, err) - var bcandidates []Candidate - bcandidates, err = bAgent.GetLocalCandidates() - require.NoError(t, err) + selectedCh := make(chan Candidate, 1) + var expectNewSelectedCandidate Candidate + err = aAgent.OnSelectedCandidatePairChange(func(_, remote Candidate) { + selectedCh <- remote + }) + require.NoError(t, err) + var bcandidates []Candidate + bcandidates, err = bAgent.GetLocalCandidates() + require.NoError(t, err) - for _, c := range bcandidates { - if c != bAgent.getSelectedPair().Local { - if expectNewSelectedCandidate == nil { - incr_priority: - for _, candidates := range aAgent.remoteCandidates { - for _, candidate := range candidates { - if candidate.Equal(c) { - candidate.(*CandidateHost).priorityOverride += 1000 //nolint:forcetypeassert - break incr_priority + for _, c := range bcandidates { + if c != bAgent.getSelectedPair().Local { + if expectNewSelectedCandidate == nil { + expected_change_priority: + for _, candidates := range aAgent.remoteCandidates { + for _, candidate := range candidates { + if candidate.Equal(c) { + if tc.useHigherPriority { + candidate.(*CandidateHost).priorityOverride += 1000 //nolint:forcetypeassert + } else { + candidate.(*CandidateHost).priorityOverride -= 1000 //nolint:forcetypeassert + } + break expected_change_priority + } + } + } + if tc.isExpectedToSwitch { + expectNewSelectedCandidate = c + } else { + expectNewSelectedCandidate = aAgent.getSelectedPair().Remote + } + } else { + // a smaller change for other candidates other the new expected one + change_priority: + for _, candidates := range aAgent.remoteCandidates { + for _, candidate := range candidates { + if candidate.Equal(c) { + if tc.useHigherPriority { + candidate.(*CandidateHost).priorityOverride += 500 //nolint:forcetypeassert + } else { + candidate.(*CandidateHost).priorityOverride -= 500 //nolint:forcetypeassert + } + break change_priority + } + } } } + _, err = c.writeTo(buildMsg(stun.ClassRequest, aAgent.localUfrag+":"+aAgent.remoteUfrag, aAgent.localPwd, c.Priority()).Raw, bAgent.getSelectedPair().Remote) + require.NoError(t, err) } - expectNewSelectedCandidate = c } - _, err = c.writeTo(buildMsg(stun.ClassRequest, aAgent.localUfrag+":"+aAgent.remoteUfrag, aAgent.localPwd, c.Priority()).Raw, bAgent.getSelectedPair().Remote) - require.NoError(t, err) - } - } - time.Sleep(1 * time.Second) - select { - case selected := <-selectedCh: - require.True(t, selected.Equal(expectNewSelectedCandidate)) - default: - t.Fatal("No selected candidate pair") + time.Sleep(1 * time.Second) + select { + case selected := <-selectedCh: + require.True(t, selected.Equal(expectNewSelectedCandidate)) + default: + if !tc.isExpectedToSwitch { + require.True(t, aAgent.getSelectedPair().Remote.Equal(expectNewSelectedCandidate)) + } else { + t.Fatal("No selected candidate pair") + } + } + }) } require.NoError(t, wan.Stop()) diff --git a/selection.go b/selection.go index d3105301..9aa4cad4 100644 --- a/selection.go +++ b/selection.go @@ -241,7 +241,7 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot s.log.Tracef("Found valid candidate pair: %s", p) if p.nominateOnBindingSuccess { if selectedPair := s.agent.getSelectedPair(); selectedPair == nil || - (selectedPair != p && selectedPair.priority() <= p.priority()) { + (selectedPair != p && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= p.priority())) { s.agent.setSelectedPair(p) } else if selectedPair != p { s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair) @@ -266,7 +266,7 @@ func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote // generated a valid pair (Section 7.2.5.3.2). The agent sets the // nominated flag value of the valid pair to true. selectedPair := s.agent.getSelectedPair() - if selectedPair == nil || (selectedPair != p && selectedPair.priority() <= p.priority()) { + if selectedPair == nil || (selectedPair != p && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= p.priority())) { s.agent.setSelectedPair(p) } else if selectedPair != p { s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair)