diff --git a/src/ShallowWrapper.js b/src/ShallowWrapper.js index 38d723331..7514e20c8 100644 --- a/src/ShallowWrapper.js +++ b/src/ShallowWrapper.js @@ -11,6 +11,7 @@ import { withSetStateAllowed, propsOfNode, typeOfNode, + isReactElementAlike, } from './Utils'; import { debugNodes, @@ -209,6 +210,13 @@ export default class ShallowWrapper { * @returns {Boolean} */ contains(nodeOrNodes) { + if (!isReactElementAlike(nodeOrNodes)) { + throw new Error( + 'ShallowWrapper::contains() can only be called with ReactElement (or array of them), ' + + 'string or number as argument.' + ); + } + const predicate = Array.isArray(nodeOrNodes) ? other => containsChildrenSubArray(nodeEqual, other, nodeOrNodes) : other => nodeEqual(nodeOrNodes, other); diff --git a/src/Utils.js b/src/Utils.js index 001c37ad6..a40eaf82f 100644 --- a/src/Utils.js +++ b/src/Utils.js @@ -1,5 +1,6 @@ /* eslint no-use-before-define:0 */ import isEqual from 'lodash/isEqual'; +import React from 'react'; import { isDOMComponent, findDOMNode, @@ -94,7 +95,7 @@ export function nodeEqual(a, b) { } } - if (typeof a !== 'string' && typeof a !== 'number') { + if (!isTextualNode(a)) { return leftKeys.length === Object.keys(right).length; } @@ -117,6 +118,13 @@ function childrenOfNode(node) { return childrenToArray(children); } +function isTextualNode(node) { + return typeof node === 'string' || typeof node === 'number'; +} + +export function isReactElementAlike(arg) { + return React.isValidElement(arg) || isTextualNode(arg) || Array.isArray(arg); +} // 'click' => 'onClick' // 'mouseEnter' => 'onMouseEnter' diff --git a/test/ShallowWrapper-spec.js b/test/ShallowWrapper-spec.js index 684503a71..30eca18c2 100644 --- a/test/ShallowWrapper-spec.js +++ b/test/ShallowWrapper-spec.js @@ -193,6 +193,13 @@ describe('shallow', () => { expect(wrapper.contains(passes2)).to.equal(true); }); + it('should throw on invalid argument', () => { + const wrapper = shallow(
); + + expect(() => wrapper.contains({})).to.throw(); + expect(() => wrapper.contains(() => ({}))).to.throw(); + }); + describeIf(!REACT013, 'stateless function components', () => { it('should match composite components', () => { const Foo = () => (