diff --git a/packages/jest-mock/src/__tests__/jest_mock.test.js b/packages/jest-mock/src/__tests__/jest_mock.test.js index 5553a0d2eae6..b24a383260b8 100644 --- a/packages/jest-mock/src/__tests__/jest_mock.test.js +++ b/packages/jest-mock/src/__tests__/jest_mock.test.js @@ -341,6 +341,14 @@ describe('moduleMocker', () => { expect(fn1()).not.toEqual('abcd'); expect(fn2()).not.toEqual('abcd'); }); + + it('maintains function arity', () => { + const mockFunctionArity1 = moduleMocker.fn(x => x); + const mockFunctionArity2 = moduleMocker.fn((x, y) => y); + + expect(mockFunctionArity1.length).toBe(1); + expect(mockFunctionArity2.length).toBe(2); + }); }); it('supports mock value returning undefined', () => { diff --git a/packages/jest-mock/src/index.js b/packages/jest-mock/src/index.js index bdfbbf26b2c0..5f93868c173b 100644 --- a/packages/jest-mock/src/index.js +++ b/packages/jest-mock/src/index.js @@ -19,6 +19,7 @@ export type MockFunctionMetadata = { refID?: string | number, type?: string, value?: any, + length?: number, }; type MockFunctionState = { @@ -88,6 +89,65 @@ const RESERVED_KEYWORDS = Object.assign(Object.create(null), { yield: true, }); +function matchArity(fn: any, length: number): any { + let mockConstructor; + + switch (length) { + case 1: + mockConstructor = function(a) { + return fn.apply(this, arguments); + }; + break; + case 2: + mockConstructor = function(a, b) { + return fn.apply(this, arguments); + }; + break; + case 3: + mockConstructor = function(a, b, c) { + return fn.apply(this, arguments); + }; + break; + case 4: + mockConstructor = function(a, b, c, d) { + return fn.apply(this, arguments); + }; + break; + case 5: + mockConstructor = function(a, b, c, d, e) { + return fn.apply(this, arguments); + }; + break; + case 6: + mockConstructor = function(a, b, c, d, e, f) { + return fn.apply(this, arguments); + }; + break; + case 7: + mockConstructor = function(a, b, c, d, e, f, g) { + return fn.apply(this, arguments); + }; + break; + case 8: + mockConstructor = function(a, b, c, d, e, f, g, h) { + return fn.apply(this, arguments); + }; + break; + case 9: + mockConstructor = function(a, b, c, d, e, f, g, h, i) { + return fn.apply(this, arguments); + }; + break; + default: + mockConstructor = function() { + return fn.apply(this, arguments); + }; + break; + } + + return mockConstructor; +} + function isA(typeName: string, value: any): boolean { return Object.prototype.toString.apply(value) === '[object ' + typeName + ']'; } @@ -242,7 +302,7 @@ class ModuleMockerClass { {}; const prototypeSlots = getSlots(prototype); const mocker = this; - const mockConstructor = function() { + const mockConstructor = matchArity(function() { const mockState = mocker._ensureMockState(f); const mockConfig = mocker._ensureMockConfig(f); mockState.instances.push(this); @@ -298,7 +358,7 @@ class ModuleMockerClass { } return returnValue; - }; + }, metadata.length || 0); f = this._createMockFunction(metadata, mockConstructor); f._isMockFunction = true; @@ -431,6 +491,7 @@ class ModuleMockerClass { MOCK_CONSTRUCTOR_NAME, body, ); + return createConstructor(mockConstructor); } @@ -565,7 +626,8 @@ class ModuleMockerClass { } fn(implementation?: any): any { - const fn = this._makeComponent({type: 'function'}); + const length = implementation ? implementation.length : 0; + const fn = this._makeComponent({length, type: 'function'}); if (implementation) { fn.mockImplementation(implementation); }