Skip to content

Commit

Permalink
Infer parameter types in arrow functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ondrejmirtes committed Jun 18, 2021
1 parent d4ded32 commit 8fdc2d3
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 37 deletions.
70 changes: 46 additions & 24 deletions src/Analyser/MutatingScope.php
Original file line number Diff line number Diff line change
Expand Up @@ -1370,30 +1370,31 @@ private function resolveType(Expr $node): Type
);
}

$callableParameters = null;
$arg = $node->getAttribute('parent');
if ($arg instanceof Arg) {
$funcCall = $arg->getAttribute('parent');
$argOrder = $arg->getAttribute('expressionOrder');
if ($funcCall instanceof FuncCall && $funcCall->name instanceof Name) {
$functionName = $this->reflectionProvider->resolveFunctionName($funcCall->name, $this);
if (
$functionName === 'array_map'
&& $argOrder === 0
&& isset($funcCall->args[1])
) {
$callableParameters = [
new DummyParameter('item', $this->getType($funcCall->args[1]->value)->getIterableValueType(), false, PassedByReference::createNo(), false, null),
];
}
}
}

if ($node instanceof Expr\ArrowFunction) {
$returnType = $this->enterArrowFunctionWithoutReflection($node)->getType($node->expr);
$returnType = $this->enterArrowFunctionWithoutReflection($node, $callableParameters)->getType($node->expr);
if ($node->returnType !== null) {
$returnType = TypehintHelper::decideType($this->getFunctionType($node->returnType, false, false), $returnType);
}
} else {
$callableParameters = null;
$arg = $node->getAttribute('parent');
if ($arg instanceof Arg) {
$funcCall = $arg->getAttribute('parent');
$argOrder = $arg->getAttribute('expressionOrder');
if ($funcCall instanceof FuncCall && $funcCall->name instanceof Name) {
$functionName = $this->reflectionProvider->resolveFunctionName($funcCall->name, $this);
if (
$functionName === 'array_map'
&& $argOrder === 0
&& isset($funcCall->args[1])
) {
$callableParameters = [
new DummyParameter('item', $this->getType($funcCall->args[1]->value)->getIterableValueType(), false, PassedByReference::createNo(), false, null),
];
}
}
}
$closureScope = $this->enterAnonymousFunctionWithoutReflection($node, $callableParameters);
$closureReturnStatements = [];
$closureYieldStatements = [];
Expand Down Expand Up @@ -3027,15 +3028,18 @@ private function enterAnonymousFunctionWithoutReflection(
);
}

/** @api */
public function enterArrowFunction(Expr\ArrowFunction $arrowFunction): self
/**
* @api
* @param \PHPStan\Reflection\ParameterReflection[]|null $callableParameters
*/
public function enterArrowFunction(Expr\ArrowFunction $arrowFunction, ?array $callableParameters): self
{
$anonymousFunctionReflection = $this->getType($arrowFunction);
if (!$anonymousFunctionReflection instanceof ClosureType) {
throw new \PHPStan\ShouldNotHappenException();
}

$scope = $this->enterArrowFunctionWithoutReflection($arrowFunction);
$scope = $this->enterArrowFunctionWithoutReflection($arrowFunction, $callableParameters);

return $this->scopeFactory->create(
$scope->context,
Expand All @@ -3057,19 +3061,37 @@ public function enterArrowFunction(Expr\ArrowFunction $arrowFunction): self
);
}

private function enterArrowFunctionWithoutReflection(Expr\ArrowFunction $arrowFunction): self
/**
* @param \PHPStan\Reflection\ParameterReflection[]|null $callableParameters
*/
private function enterArrowFunctionWithoutReflection(Expr\ArrowFunction $arrowFunction, ?array $callableParameters): self
{
$variableTypes = $this->variableTypes;
$mixed = new MixedType();
$parameterVariables = [];
foreach ($arrowFunction->params as $parameter) {
foreach ($arrowFunction->params as $i => $parameter) {
if ($parameter->type === null) {
$parameterType = $mixed;
} else {
$isNullable = $this->isParameterValueNullable($parameter);
$parameterType = $this->getFunctionType($parameter->type, $isNullable, $parameter->variadic);
}

if ($callableParameters !== null) {
if (isset($callableParameters[$i])) {
$parameterType = TypehintHelper::decideType($parameterType, $callableParameters[$i]->getType());
} elseif (count($callableParameters) > 0) {
$lastParameter = $callableParameters[count($callableParameters) - 1];
if ($lastParameter->isVariadic()) {
$parameterType = TypehintHelper::decideType($parameterType, $lastParameter->getType());
} else {
$parameterType = TypehintHelper::decideType($parameterType, new MixedType());
}
} else {
$parameterType = TypehintHelper::decideType($parameterType, new MixedType());
}
}

if (!$parameter->var instanceof Variable || !is_string($parameter->var->name)) {
throw new \PHPStan\ShouldNotHappenException();
}
Expand Down
57 changes: 44 additions & 13 deletions src/Analyser/NodeScopeResolver.php
Original file line number Diff line number Diff line change
Expand Up @@ -2199,19 +2199,7 @@ static function () use ($scope, $expr): MutatingScope {
$hasYield = false;
$throwPoints = [];
} elseif ($expr instanceof Expr\ArrowFunction) {
foreach ($expr->params as $param) {
$this->processParamNode($param, $scope, $nodeCallback);
}
if ($expr->returnType !== null) {
$nodeCallback($expr->returnType, $scope);
}

$arrowFunctionScope = $scope->enterArrowFunction($expr);
$nodeCallback(new InArrowFunctionNode($expr), $arrowFunctionScope);
$this->processExprNode($expr->expr, $arrowFunctionScope, $nodeCallback, ExpressionContext::createTopLevel());
$hasYield = false;
$throwPoints = [];

return $this->processArrowFunctionNode($expr, $scope, $nodeCallback, $context, null);
} elseif ($expr instanceof ErrorSuppress) {
$result = $this->processExprNode($expr->expr, $scope, $nodeCallback, $context);
$hasYield = $result->hasYield();
Expand Down Expand Up @@ -2963,6 +2951,46 @@ private function processClosureNode(
return new ExpressionResult($scope->processClosureScope($closureScope, null, $byRefUses), false, []);
}

/**
* @param \PhpParser\Node\Expr\ArrowFunction $expr
* @param \PHPStan\Analyser\MutatingScope $scope
* @param callable(\PhpParser\Node $node, Scope $scope): void $nodeCallback
* @param ExpressionContext $context
* @param Type|null $passedToType
* @return \PHPStan\Analyser\ExpressionResult
*/
private function processArrowFunctionNode(
Expr\ArrowFunction $expr,
MutatingScope $scope,
callable $nodeCallback,
ExpressionContext $context,
?Type $passedToType
): ExpressionResult
{
foreach ($expr->params as $param) {
$this->processParamNode($param, $scope, $nodeCallback);
}
if ($expr->returnType !== null) {
$nodeCallback($expr->returnType, $scope);
}

if ($passedToType !== null && !$passedToType->isCallable()->no()) {
$callableParameters = null;
$acceptors = $passedToType->getCallableParametersAcceptors($scope);
if (count($acceptors) === 1) {
$callableParameters = $acceptors[0]->getParameters();
}
} else {
$callableParameters = null;
}

$arrowFunctionScope = $scope->enterArrowFunction($expr, $callableParameters);
$nodeCallback(new InArrowFunctionNode($expr), $arrowFunctionScope);
$this->processExprNode($expr->expr, $arrowFunctionScope, $nodeCallback, ExpressionContext::createTopLevel());

return new ExpressionResult($scope, false, []);
}

private function lookForArrayDestructuringArray(MutatingScope $scope, Expr $expr, Type $valueType): MutatingScope
{
if ($expr instanceof Array_ || $expr instanceof List_) {
Expand Down Expand Up @@ -3138,6 +3166,9 @@ private function processArgs(
if ($arg->value instanceof Expr\Closure) {
$this->callNodeCallbackWithExpression($nodeCallback, $arg->value, $scopeToPass, $context);
$result = $this->processClosureNode($arg->value, $scopeToPass, $nodeCallback, $context, $parameterType ?? null);
} elseif ($arg->value instanceof Expr\ArrowFunction) {
$this->callNodeCallbackWithExpression($nodeCallback, $arg->value, $scopeToPass, $context);
$result = $this->processArrowFunctionNode($arg->value, $scopeToPass, $nodeCallback, $context, $parameterType ?? null);
} else {
$result = $this->processExprNode($arg->value, $scopeToPass, $nodeCallback, $context->enterDeep());
}
Expand Down
5 changes: 5 additions & 0 deletions tests/PHPStan/Analyser/NodeScopeResolverTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,11 @@ public function dataFileAsserts(): iterable
yield from $this->gatherAssertTypes(__DIR__ . '/../Rules/Functions/data/varying-acceptor.php');

yield from $this->gatherAssertTypes(__DIR__ . '/data/uksort-bug.php');

if (self::$useStaticReflectionProvider || PHP_VERSION_ID >= 70400) {
yield from $this->gatherAssertTypes(__DIR__ . '/data/arrow-function-types.php');
}

yield from $this->gatherAssertTypes(__DIR__ . '/data/closure-types.php');
}

Expand Down
44 changes: 44 additions & 0 deletions tests/PHPStan/Analyser/data/arrow-function-types.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
<?php // lint >= 7.4

namespace ArrowFunctionTypes;

use function PHPStan\Testing\assertType;

class Foo
{

/** @var array<int, array{foo: string, bar: int}> */
private $arrayShapes;

public function doFoo(): void
{
array_map(fn(array $a): array => assertType('array(\'foo\' => string, \'bar\' => int)', $a), $this->arrayShapes);
$a = array_map(fn(array $a) => $a, $this->arrayShapes);
assertType('array<int, array(\'foo\' => string, \'bar\' => int)>', $a);

array_map(fn($b) => assertType('array(\'foo\' => string, \'bar\' => int)', $b), $this->arrayShapes);
$b = array_map(fn($b) => $b['foo'], $this->arrayShapes);
assertType('array<int, string>', $b);
}

public function doBar(): void
{
usort($this->arrayShapes, fn(array $a, array $b): int => assertType('array(\'foo\' => string, \'bar\' => int)', $a));
}

public function doBar2(): void
{
usort($this->arrayShapes, fn (array $a, array $b): int => assertType('array(\'foo\' => string, \'bar\' => int)', $b));
}

public function doBaz(): void
{
usort($this->arrayShapes, fn ($a, $b): int => assertType('array(\'foo\' => string, \'bar\' => int)', $a));
}

public function doBaz2(): void
{
usort($this->arrayShapes, fn ($a, $b): int => assertType('array(\'foo\' => string, \'bar\' => int)', $b));
}

}

0 comments on commit 8fdc2d3

Please sign in to comment.