Skip to content

Commit

Permalink
feature(transformer): Support overloaded functions by attaching signa…
Browse files Browse the repository at this point in the history
…tures on use

Add overloads feature flag that enables this feature.

Enabling it makes the transformer process function calls if their declaration
was previously marked for mocking (via getMethod).

From a type-perspective, typed methods shouldn't bother to consider their inputs
in order to determine the output in runtime. At transformation time, the type
checker resolves the matching overload and that information can be used to
attach to the function, by utilizing the "instance" (`this`) of it. The
transformer changes transform functions in the following way.

```
mockedFunction() -> mockedFunction.apply(<signature>, [])
```

As for constructor instantiation signatures in interfaces, those can be wrapped
by an intermediate function that will copy the mocked properties to preserve the
instantiation behavior.

```
new mockedNewFunction()
    |
    `-> new (mockedNewFunction[<signature>] || (mockedNewFunction[<signature>] = function() {
               Object.assign(this, mockedNewFunction.apply(<signature>, []));
            }))()
```

These attached interfaces will determine the branching at runtime and to reduce
as much overhead as possible, all signatures of an overloaded function are
mapped to the resolved return type and stored in a jump table, i.e.:

```
getMethod("functionName", function () {
  const jt = {
    ['<signature-1>']: () => <signature-1-return-descriptor>,
    ['<signature-2>']: () => <signature-2-return-descriptor>,
    ...
  };

  return jt[this]();
})
```

It should be noted, that if spies are introduced using the method provider, then
`this` will be occupied by the signature key.
  • Loading branch information
martinjlowm committed Jun 13, 2020
1 parent f28570c commit 340b562
Show file tree
Hide file tree
Showing 22 changed files with 373 additions and 95 deletions.
3 changes: 2 additions & 1 deletion config/utils/features.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ function DetermineFeaturesFromEnvironment() {

if (features) {
return [
'random'
'overloads',
'random',
];
}

Expand Down
6 changes: 4 additions & 2 deletions src/extension/method/provider/functionMethod.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function functionMethod(name: string, value: () => any): any {
export function functionMethod(name: string, value: (...args: any[]) => any): any {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (): any => value();
return function(...args: any[]): any {
return value.apply(this, args);
};
}
2 changes: 1 addition & 1 deletion src/merge/merge.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { merge} from 'lodash-es';
import { merge } from 'lodash-es';
import { DeepPartial } from '../partial/deepPartial';

export class Merge {
Expand Down
2 changes: 1 addition & 1 deletion src/options/features.ts
Original file line number Diff line number Diff line change
@@ -1 +1 @@
export type TsAutoMockFeaturesOption = 'random';
export type TsAutoMockFeaturesOption = 'random' | 'overloads';
5 changes: 5 additions & 0 deletions src/options/overloads.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import { GetOptionByKey } from './options';

export function IsTsAutoMockOverloadsEnabled(): boolean {
return GetOptionByKey('features').includes('overloads');
}
85 changes: 79 additions & 6 deletions src/transformer/base/base.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as ts from 'typescript';
import ts from 'typescript';
import { IsTsAutoMockOverloadsEnabled } from '../../options/overloads';
import { SetTsAutoMockOptions, TsAutoMockOptions } from '../../options/options';
import { SetTypeChecker } from '../typeChecker/typeChecker';
import { MockDefiner } from '../mockDefiner/mockDefiner';
Expand All @@ -9,7 +10,7 @@ import {
isFunctionFromThisLibrary,
} from '../matcher/matcher';

export type Visitor = (node: ts.CallExpression & { typeArguments: ts.NodeArray<ts.TypeNode> }, declaration: ts.FunctionDeclaration) => ts.Node;
export type Visitor = (node: ts.CallLikeExpression & { typeArguments: ts.NodeArray<ts.TypeNode> }, declaration: ts.SignatureDeclaration) => ts.Node;

export function baseTransformer(visitor: Visitor, customFunctions: CustomFunction[]): (program: ts.Program, options?: TsAutoMockOptions) => ts.TransformerFactory<ts.SourceFile> {
return (program: ts.Program, options?: TsAutoMockOptions): ts.TransformerFactory<ts.SourceFile> => {
Expand Down Expand Up @@ -47,14 +48,88 @@ function isObjectWithProperty<T extends {}, K extends keyof T>(
return typeof obj[key] !== 'undefined';
}

function isMockedByThisLibrary(declaration: ts.Declaration): boolean {
return MockDefiner.instance.hasKeyForDeclaration(declaration);
}

function visitNode(node: ts.Node, visitor: Visitor, customFunctions: CustomFunction[]): ts.Node {
if (!ts.isCallExpression(node)) {
if (!ts.isCallExpression(node) && !ts.isNewExpression(node)) {
return node;
}

const signature: ts.Signature | undefined = TypescriptHelper.getSignatureOfCallExpression(node);
const declaration: ts.Declaration | undefined = signature?.declaration;

if (!signature || !isFunctionFromThisLibrary(signature, customFunctions)) {
if (!declaration || !ts.isFunctionLike(declaration)) {
return node;
}

if (IsTsAutoMockOverloadsEnabled() && isMockedByThisLibrary(declaration)) {
const mockKey: string = MockDefiner.instance.getDeclarationKeyMap(declaration);
const mockKeyLiteral: ts.StringLiteral = ts.createStringLiteral(mockKey);

const boundSignatureCall: ts.CallExpression = ts.createCall(
ts.createPropertyAccess(
node.expression,
ts.createIdentifier('apply'),
),
undefined,
[mockKeyLiteral, ts.createArrayLiteral(node.arguments)],
);

if (ts.isCallExpression(node)) {
return boundSignatureCall;
}

const cachedConstructor: ts.ElementAccessExpression = ts.createElementAccess(
node.expression,
mockKeyLiteral,
);

return ts.createNew(
ts.createParen(
ts.createBinary(
cachedConstructor,
ts.SyntaxKind.BarBarToken,
ts.createParen(
ts.createBinary(
cachedConstructor,
ts.SyntaxKind.EqualsToken,
ts.createFunctionExpression(
undefined,
undefined,
'',
undefined,
undefined,
undefined,
ts.createBlock(
[
ts.createExpressionStatement(
ts.createCall(
ts.createPropertyAccess(
ts.createIdentifier('Object'),
ts.createIdentifier('assign'),
),
undefined,
[
ts.createIdentifier('this'),
boundSignatureCall,
]
),
),
],
),
),
),
),
),
),
undefined,
undefined,
);
}

if (!isFunctionFromThisLibrary(declaration, customFunctions)) {
return node;
}

Expand All @@ -73,7 +148,5 @@ function visitNode(node: ts.Node, visitor: Visitor, customFunctions: CustomFunct
MockDefiner.instance.setFileNameFromNode(nodeToMock);
MockDefiner.instance.setTsAutoMockImportIdentifier();

const declaration: ts.FunctionDeclaration = signature.declaration as ts.FunctionDeclaration;

return visitor(node, declaration);
}
2 changes: 1 addition & 1 deletion src/transformer/descriptor/helper/helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ export namespace TypescriptHelper {
}


export function getSignatureOfCallExpression(node: ts.CallExpression): ts.Signature | undefined {
export function getSignatureOfCallExpression(node: ts.CallLikeExpression): ts.Signature | undefined {
const typeChecker: ts.TypeChecker = TypeChecker();

return typeChecker.getResolvedSignature(node);
Expand Down
32 changes: 15 additions & 17 deletions src/transformer/descriptor/method/bodyReturnType.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,28 @@ export function GetReturnTypeFromBodyDescriptor(node: ts.ArrowFunction | ts.Func
return GetDescriptor(GetReturnNodeFromBody(node), scope);
}

export function GetReturnNodeFromBody(node: ts.FunctionLikeDeclaration): ts.Node {
let returnValue: ts.Node | undefined;

export function GetReturnNodeFromBody<T extends ts.Node & { body?: ts.ConciseBody }>(node: T): ts.Expression {
const functionBody: ts.ConciseBody | undefined = node.body;

if (functionBody && ts.isBlock(functionBody)) {
const returnStatement: ts.ReturnStatement = GetReturnStatement(functionBody);
if (!functionBody) {
return GetNullDescriptor();
}

if (returnStatement) {
returnValue = returnStatement.expression;
} else {
returnValue = GetNullDescriptor();
}
} else {
returnValue = node.body;
if (!ts.isBlock(functionBody)) {
return functionBody;
}

if (!returnValue) {
throw new Error(`Failed to determine the return value of ${node.getText()}.`);
const returnStatement: ts.ReturnStatement | undefined = GetReturnStatement(functionBody);

if (!returnStatement?.expression) {
return GetNullDescriptor();
}

return returnValue;
return returnStatement.expression;
}

function GetReturnStatement(body: ts.FunctionBody): ts.ReturnStatement {
return body.statements.find((statement: ts.Statement) => statement.kind === ts.SyntaxKind.ReturnStatement) as ts.ReturnStatement;
function GetReturnStatement(body: ts.FunctionBody): ts.ReturnStatement | undefined {
return body.statements.find(
(statement: ts.Statement): statement is ts.ReturnStatement => statement.kind === ts.SyntaxKind.ReturnStatement,
);
}
6 changes: 2 additions & 4 deletions src/transformer/descriptor/method/functionAssignment.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import * as ts from 'typescript';
import ts from 'typescript';
import { Scope } from '../../scope/scope';
import { PropertySignatureCache } from '../property/cache';
import { GetReturnTypeFromBodyDescriptor } from './bodyReturnType';
import { GetMethodDescriptor } from './method';

type functionAssignment = ts.ArrowFunction | ts.FunctionExpression;

export function GetFunctionAssignmentDescriptor(node: functionAssignment, scope: Scope): ts.Expression {
const property: ts.PropertyName = PropertySignatureCache.instance.get();
const returnValue: ts.Expression = GetReturnTypeFromBodyDescriptor(node, scope);

return GetMethodDescriptor(property, returnValue);
return GetMethodDescriptor(property, [node], scope);
}
5 changes: 1 addition & 4 deletions src/transformer/descriptor/method/functionType.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import * as ts from 'typescript';
import { Scope } from '../../scope/scope';
import { GetDescriptor } from '../descriptor';
import { PropertySignatureCache } from '../property/cache';
import { GetMethodDescriptor } from './method';

Expand All @@ -11,7 +10,5 @@ export function GetFunctionTypeDescriptor(node: ts.FunctionTypeNode | ts.CallSig
throw new Error(`No type was declared for ${node.getText()}.`);
}

const returnValue: ts.Expression = GetDescriptor(node.type, scope);

return GetMethodDescriptor(property, returnValue);
return GetMethodDescriptor(property, [node], scope);
}
96 changes: 89 additions & 7 deletions src/transformer/descriptor/method/method.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,100 @@
import * as ts from 'typescript';
import ts from 'typescript';
import { IsTsAutoMockOverloadsEnabled } from '../../../options/overloads';
import { TypescriptCreator } from '../../helper/creator';
import { MockDefiner } from '../../mockDefiner/mockDefiner';
import { ModuleName } from '../../mockDefiner/modules/moduleName';
import { MockIdentifierJumpTable } from '../../mockIdentifier/mockIdentifier';
import { Scope } from '../../scope/scope';
import { TypescriptHelper } from '../helper/helper';
import { GetDescriptor } from '../descriptor';
import { GetReturnNodeFromBody } from './bodyReturnType';

export function GetMethodDescriptor(propertyName: ts.PropertyName, returnValue: ts.Expression): ts.Expression {
type MethodDeclaration =
| ts.ArrowFunction
| ts.FunctionExpression
| ts.MethodSignature
| ts.FunctionTypeNode
| ts.CallSignatureDeclaration
| ts.ConstructSignatureDeclaration
| ts.MethodDeclaration
| ts.FunctionDeclaration;

function GetDeclarationType(declaration: ts.SignatureDeclaration): ts.TypeNode {
if (declaration.type) {
return declaration.type;
}

return ts.createLiteralTypeNode(GetReturnNodeFromBody(declaration) as ts.LiteralExpression);
}

export function GetMethodDescriptor(
propertyName: ts.PropertyName,
methodDeclarations: ReadonlyArray<MethodDeclaration>,
scope: Scope,
): ts.CallExpression {
const providerGetMethod: ts.PropertyAccessExpression = CreateProviderGetMethod();

const propertyNameString: string = TypescriptHelper.GetStringPropertyName(propertyName);
const propertyNameStringLiteral: ts.StringLiteral = ts.createStringLiteral(propertyNameString);

const propertyValueFunction: ts.ArrowFunction = TypescriptCreator.createArrowFunction(ts.createBlock(
[ts.createReturn(returnValue)],
true,
));
const statements: ts.Statement[] = [];

const [primaryDeclaration, ...remainingDeclarations]: ReadonlyArray<MethodDeclaration> = methodDeclarations;

if (remainingDeclarations.length && IsTsAutoMockOverloadsEnabled()) {
const jumpTableEntries: ts.PropertyAssignment[] = methodDeclarations.map((declaration: ts.FunctionDeclaration) =>
ts.createPropertyAssignment(
ts.createComputedPropertyName(
ts.createStringLiteral(
MockDefiner.instance.getDeclarationKeyMap(declaration),
),
),
ts.createArrowFunction(
undefined,
undefined,
[],
undefined,
undefined,
GetDescriptor(GetDeclarationType(declaration), scope),
),
),
);

statements.push(
TypescriptCreator.createVariableStatement([
TypescriptCreator.createVariableDeclaration(
MockIdentifierJumpTable,
ts.createObjectLiteral(jumpTableEntries),
),
]),
);

statements.push(
ts.createReturn(
ts.createCall(
ts.createElementAccess(
MockIdentifierJumpTable,
ts.createIdentifier('this'),
),
undefined,
undefined,
),
),
);
} else {
statements.push(
ts.createReturn(
GetDescriptor(GetDeclarationType(primaryDeclaration), scope),
),
);
}

const block: ts.Block = ts.createBlock(statements, true);

const propertyValueFunction: ts.FunctionExpression = TypescriptCreator.createFunctionExpression(
block,
[],
);

return TypescriptCreator.createCall(providerGetMethod, [propertyNameStringLiteral, propertyValueFunction]);
}
Expand All @@ -26,5 +107,6 @@ function CreateProviderGetMethod(): ts.PropertyAccessExpression {
ts.createIdentifier('Provider'),
),
ts.createIdentifier('instance')),
ts.createIdentifier('getMethod'));
ts.createIdentifier('getMethod'),
);
}
20 changes: 14 additions & 6 deletions src/transformer/descriptor/method/methodDeclaration.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import * as ts from 'typescript';
import ts from 'typescript';
import { Scope } from '../../scope/scope';
import { GetDescriptor } from '../descriptor';
import { GetFunctionReturnType } from './functionReturnType';
import { TypeChecker } from '../../typeChecker/typeChecker';

import { GetMethodDescriptor } from './method';

export function GetMethodDeclarationDescriptor(node: ts.MethodDeclaration | ts.FunctionDeclaration, scope: Scope): ts.Expression {
const returnTypeNode: ts.Node = GetFunctionReturnType(node);
const returnType: ts.Expression = GetDescriptor(returnTypeNode, scope);
const declarationType: ts.Type | undefined = TypeChecker().getTypeAtLocation(node);
const methodDeclarations: Array<ts.MethodDeclaration | ts.FunctionDeclaration> = declarationType.symbol.declarations
.filter(
(declaration: ts.Declaration): declaration is ts.MethodDeclaration | ts.FunctionDeclaration =>
ts.isMethodDeclaration(declaration) || ts.isFunctionDeclaration(declaration)
);

if (!methodDeclarations.length) {
methodDeclarations.push(node);
}

if (!node.name) {
throw new Error(
`The transformer couldn't determine the name of ${node.getText()}. Please report this incident.`,
);
}

return GetMethodDescriptor(node.name, returnType);
return GetMethodDescriptor(node.name, methodDeclarations, scope);
}
Loading

0 comments on commit 340b562

Please sign in to comment.