diff --git a/in_session.go b/in_session.go index 6edbb8409..5445967cd 100644 --- a/in_session.go +++ b/in_session.go @@ -87,7 +87,7 @@ func (state inSession) Timeout(session *session, event internal.Event) (nextStat } func (state inSession) handleLogout(session *session, msg *Message) (nextState sessionState) { - if err := session.verifySelect(msg, false, false); err != nil { + if err := session.verifySelect(msg, false, false, true); err != nil { return state.processReject(session, msg, err) } @@ -154,7 +154,7 @@ func (state inSession) handleSequenceReset(session *session, msg *Message) (next } } - if err := session.verifySelect(msg, bool(gapFillFlag), bool(gapFillFlag)); err != nil { + if err := session.verifySelect(msg, bool(gapFillFlag), bool(gapFillFlag), true); err != nil { return state.processReject(session, msg, err) } diff --git a/logon_state_test.go b/logon_state_test.go index ee47a2be6..f5c871ea6 100644 --- a/logon_state_test.go +++ b/logon_state_test.go @@ -226,6 +226,27 @@ func (s *LogonStateTestSuite) TestFixMsgInLogonInitiateLogonExpectResetSeqNum() s.NextSenderMsgSeqNum(2) } +func (s *LogonStateTestSuite) TestFixMsgInLogonInitiateLogonRejectedSeqNumNotReset() { + s.session.InitiateLogon = true + s.session.sentReset = true + s.Require().Nil(s.store.IncrNextSenderMsgSeqNum()) + + logon := s.Logon() + logon.Body.SetField(tagHeartBtInt, FIXInt(32)) + logon.Body.SetField(tagResetSeqNumFlag, FIXBoolean(true)) + + s.MockApp.On("FromAdmin").Return(RejectLogon{"reject message"}) + s.MockApp.On("OnLogout") + s.MockApp.On("ToAdmin") + s.fixMsgIn(s.session, logon) + + s.MockApp.AssertExpectations(s.T()) + s.State(latentState{}) + + s.NextTargetMsgSeqNum(2) + s.NextSenderMsgSeqNum(3) +} + func (s *LogonStateTestSuite) TestFixMsgInLogonInitiateLogonUnExpectedResetSeqNum() { s.session.InitiateLogon = true s.session.sentReset = false @@ -358,6 +379,7 @@ func (s *LogonStateTestSuite) TestFixMsgInLogonSeqNumTooLow() { logon.Body.SetField(tagHeartBtInt, FIXInt(32)) logon.Header.SetInt(tagMsgSeqNum, 1) + s.MockApp.On("FromAdmin").Return(nil) s.MockApp.On("ToAdmin") s.NextTargetMsgSeqNum(2) s.fixMsgIn(s.session, logon) diff --git a/session.go b/session.go index 5320f49b9..cdcb8603d 100644 --- a/session.go +++ b/session.go @@ -507,6 +507,13 @@ func (s *session) handleLogon(msg *Message) error { } } + nextSenderMsgNumAtLogonReceived := s.store.NextSenderMsgSeqNum() + + // Make sure this is a valid session before resetting the store. + if err := s.verifyMsgAgainstAppImpl(msg); err != nil { + return err + } + var resetSeqNumFlag FIXBoolean if err := msg.Body.GetField(tagResetSeqNumFlag, &resetSeqNumFlag); err == nil { if resetSeqNumFlag { @@ -517,14 +524,14 @@ func (s *session) handleLogon(msg *Message) error { } } - nextSenderMsgNumAtLogonReceived := s.store.NextSenderMsgSeqNum() - if resetStore { if err := s.store.Reset(); err != nil { return err } } + // Verify seq num too high but dont check against app implementation since we just did that. + // Don't need to double check. if err := s.verifyIgnoreSeqNumTooHigh(msg); err != nil { return err } @@ -586,18 +593,18 @@ func (s *session) initiateLogoutInReplyTo(reason string, inReplyTo *Message) (er } func (s *session) verify(msg *Message) MessageRejectError { - return s.verifySelect(msg, true, true) + return s.verifySelect(msg, true, true, true) } func (s *session) verifyIgnoreSeqNumTooHigh(msg *Message) MessageRejectError { - return s.verifySelect(msg, false, true) + return s.verifySelect(msg, false, true, false) } func (s *session) verifyIgnoreSeqNumTooHighOrLow(msg *Message) MessageRejectError { - return s.verifySelect(msg, false, false) + return s.verifySelect(msg, false, false, true) } -func (s *session) verifySelect(msg *Message, checkTooHigh bool, checkTooLow bool) MessageRejectError { +func (s *session) verifySelect(msg *Message, checkTooHigh bool, checkTooLow bool, checkAppImpl bool) MessageRejectError { if reject := s.checkBeginString(msg); reject != nil { return reject } @@ -626,6 +633,14 @@ func (s *session) verifySelect(msg *Message, checkTooHigh bool, checkTooLow bool } } + if checkAppImpl { + return s.verifyMsgAgainstAppImpl(msg) + } + + return nil +} + +func (s *session) verifyMsgAgainstAppImpl(msg *Message) MessageRejectError { if s.Validator != nil { if reject := s.Validator.Validate(msg); reject != nil { return reject