diff --git a/src/modules/Sticky/Sticky.js b/src/modules/Sticky/Sticky.js index f413ab0d7a..b257ea7b39 100644 --- a/src/modules/Sticky/Sticky.js +++ b/src/modules/Sticky/Sticky.js @@ -103,15 +103,23 @@ export default class Sticky extends Component { } componentWillReceiveProps(nextProps) { - const { active: current } = this.props - const { active: next } = nextProps + const { active: current, scrollContext: currentScrollContext } = this.props + const { active: next, scrollContext: nextScrollContext } = nextProps + + if (current === next) { + if (currentScrollContext !== nextScrollContext) { + this.removeListeners() + this.addListeners(nextProps) + } + return + } - if (current === next) return if (next) { this.handleUpdate() this.addListeners(nextProps) return } + this.removeListeners() this.setState({ sticky: false }) } @@ -130,15 +138,19 @@ export default class Sticky extends Component { addListeners = (props) => { const { scrollContext } = props - eventStack.sub('resize', this.handleUpdate, { target: scrollContext }) - eventStack.sub('scroll', this.handleUpdate, { target: scrollContext }) + if (scrollContext) { + eventStack.sub('resize', this.handleUpdate, { target: scrollContext }) + eventStack.sub('scroll', this.handleUpdate, { target: scrollContext }) + } } removeListeners = () => { const { scrollContext } = this.props - eventStack.unsub('resize', this.handleUpdate, { target: scrollContext }) - eventStack.unsub('scroll', this.handleUpdate, { target: scrollContext }) + if (scrollContext) { + eventStack.unsub('resize', this.handleUpdate, { target: scrollContext }) + eventStack.unsub('scroll', this.handleUpdate, { target: scrollContext }) + } } // ---------------------------------------- diff --git a/test/specs/modules/Sticky/Sticky-test.js b/test/specs/modules/Sticky/Sticky-test.js index a6790e72b6..aea08f3b18 100644 --- a/test/specs/modules/Sticky/Sticky-test.js +++ b/test/specs/modules/Sticky/Sticky-test.js @@ -305,6 +305,46 @@ describe('Sticky', () => { domEvent.scroll(div) onStick.should.have.been.called() }) + + it('should not call onStick when context is null', () => { + const onStick = sandbox.spy() + const instance = mount().instance() + + instance.triggerRef = mockRect({ top: -1 }) + + domEvent.scroll(document) + onStick.should.not.have.been.called() + }) + + it('should call onStick when scrollContext changes', () => { + const div = document.createElement('div') + const onStick = sandbox.spy() + const renderedComponent = mount() + const instance = renderedComponent.instance() + + instance.triggerRef = mockRect({ top: -1 }) + renderedComponent.setProps({ scrollContext: div }) + + domEvent.scroll(div) + onStick.should.have.been.called() + }) + + it('should not call onStick when scrollContext changes and component is unmounted', () => { + const div = document.createElement('div') + const onStick = sandbox.spy() + const renderedComponent = mount() + const instance = renderedComponent.instance() + + instance.triggerRef = mockRect({ top: -1 }) + renderedComponent.setProps({ scrollContext: div }) + renderedComponent.unmount() + + domEvent.scroll(div) + onStick.should.not.have.been.called() + + domEvent.scroll(document) + onStick.should.not.have.been.called() + }) }) describe('update', () => {