From 891412dcbffd2ad58c3b26f682b137c3b5718962 Mon Sep 17 00:00:00 2001 From: Martin Herndl Date: Wed, 4 Dec 2024 23:26:10 +0100 Subject: [PATCH] Improve `count()` narrowing of constant arrays --- src/Analyser/TypeSpecifier.php | 92 ++++++++++------------ tests/PHPStan/Analyser/nsrt/bug-4700.php | 2 +- tests/PHPStan/Analyser/nsrt/count-type.php | 23 ++++++ 3 files changed, 67 insertions(+), 50 deletions(-) diff --git a/src/Analyser/TypeSpecifier.php b/src/Analyser/TypeSpecifier.php index 088236bac5..dddf45ce08 100644 --- a/src/Analyser/TypeSpecifier.php +++ b/src/Analyser/TypeSpecifier.php @@ -277,22 +277,20 @@ public function specifyTypesInCondition( ) { $argType = $scope->getType($expr->right->getArgs()[0]->value); - if ($argType instanceof UnionType) { - $sizeType = null; - if ($leftType instanceof ConstantIntegerType) { - if ($orEqual) { - $sizeType = IntegerRangeType::createAllGreaterThanOrEqualTo($leftType->getValue()); - } else { - $sizeType = IntegerRangeType::createAllGreaterThan($leftType->getValue()); - } - } elseif ($leftType instanceof IntegerRangeType) { - $sizeType = $leftType; + $sizeType = null; + if ($leftType instanceof ConstantIntegerType) { + if ($orEqual) { + $sizeType = IntegerRangeType::createAllGreaterThanOrEqualTo($leftType->getValue()); + } else { + $sizeType = IntegerRangeType::createAllGreaterThan($leftType->getValue()); } + } elseif ($leftType instanceof IntegerRangeType) { + $sizeType = $leftType; + } - $narrowed = $this->narrowUnionByArraySize($expr->right, $argType, $sizeType, $context, $scope, $rootExpr); - if ($narrowed !== null) { - return $narrowed; - } + $specifiedTypes = $this->specifyTypesForCountFuncCall($expr->right, $argType, $sizeType, $context, $scope, $rootExpr); + if ($specifiedTypes !== null) { + $result = $result->unionWith($specifiedTypes); } if ( @@ -1010,66 +1008,64 @@ public function specifyTypesInCondition( return new SpecifiedTypes([], [], false, [], $rootExpr); } - private function narrowUnionByArraySize(FuncCall $countFuncCall, UnionType $argType, ?Type $sizeType, TypeSpecifierContext $context, Scope $scope, ?Expr $rootExpr): ?SpecifiedTypes + private function specifyTypesForCountFuncCall(FuncCall $countFuncCall, Type $type, ?Type $sizeType, TypeSpecifierContext $context, Scope $scope, ?Expr $rootExpr): ?SpecifiedTypes { if ($sizeType === null) { return null; } - if (count($countFuncCall->getArgs()) === 1) { - $isNormalCount = TrinaryLogic::createYes(); - } else { - $mode = $scope->getType($countFuncCall->getArgs()[1]->value); - $isNormalCount = (new ConstantIntegerType(COUNT_NORMAL))->isSuperTypeOf($mode)->or($argType->getIterableValueType()->isArray()->negate()); - } - if ( - $isNormalCount->yes() - && $argType->isConstantArray()->yes() + $this->isNormalCount($countFuncCall, $scope)->yes() + && $type->isConstantArray()->yes() ) { - $result = []; - foreach ($argType->getTypes() as $innerType) { - $arraySize = $innerType->getArraySize(); + $resultType = TypeTraverser::map($type, function (Type $type, callable $traverse) use ($sizeType, $context) { + if ($type instanceof UnionType) { + return $traverse($type); + } + + $arraySize = $type->getArraySize(); $isSize = $sizeType->isSuperTypeOf($arraySize); if ($context->truthy()) { if ($isSize->no()) { - continue; + return new NeverType(); } - $constArray = $this->turnListIntoConstantArray($countFuncCall, $innerType, $sizeType, $scope); + $constArray = $this->turnListIntoConstantArray($type, $sizeType); if ($constArray !== null) { - $innerType = $constArray; + $type = $constArray; } } if ($context->falsey()) { if (!$isSize->yes()) { - continue; + return new NeverType(); } } - $result[] = $innerType; - } + return $type; + }); - return $this->create($countFuncCall->getArgs()[0]->value, TypeCombinator::union(...$result), $context, false, $scope, $rootExpr); + return $this->create($countFuncCall->getArgs()[0]->value, $resultType, $context, false, $scope, $rootExpr); } return null; } - private function turnListIntoConstantArray(FuncCall $countFuncCall, Type $type, Type $sizeType, Scope $scope): ?Type + private function isNormalCount(FuncCall $countFuncCall, Scope $scope): TrinaryLogic { $argType = $scope->getType($countFuncCall->getArgs()[0]->value); if (count($countFuncCall->getArgs()) === 1) { - $isNormalCount = TrinaryLogic::createYes(); - } else { - $mode = $scope->getType($countFuncCall->getArgs()[1]->value); - $isNormalCount = (new ConstantIntegerType(COUNT_NORMAL))->isSuperTypeOf($mode)->or($argType->getIterableValueType()->isArray()->negate()); + return TrinaryLogic::createYes(); } + $mode = $scope->getType($countFuncCall->getArgs()[1]->value); + + return (new ConstantIntegerType(COUNT_NORMAL))->isSuperTypeOf($mode)->or($argType->getIterableValueType()->isArray()->negate()); + } + private function turnListIntoConstantArray(Type $type, Type $sizeType): ?Type + { if ( - $isNormalCount->yes() - && $type->isList()->yes() + $type->isList()->yes() && $sizeType instanceof ConstantIntegerType && $sizeType->getValue() < ConstantArrayTypeBuilder::ARRAY_COUNT_LIMIT ) { @@ -1083,8 +1079,7 @@ private function turnListIntoConstantArray(FuncCall $countFuncCall, Type $type, } if ( - $isNormalCount->yes() - && $type->isList()->yes() + $type->isList()->yes() && $sizeType instanceof IntegerRangeType && $sizeType->getMin() !== null ) { @@ -2171,11 +2166,9 @@ public function resolveIdentical(Expr\BinaryOp\Identical $expr, Scope $scope, Ty ); } - if ($argType instanceof UnionType) { - $narrowed = $this->narrowUnionByArraySize($unwrappedLeftExpr, $argType, $rightType, $context, $scope, $rootExpr); - if ($narrowed !== null) { - return $narrowed; - } + $specifiedTypes = $this->specifyTypesForCountFuncCall($unwrappedLeftExpr, $argType, $rightType, $context, $scope, $rootExpr); + if ($specifiedTypes !== null) { + return $specifiedTypes; } if ($context->truthy()) { @@ -2188,7 +2181,8 @@ public function resolveIdentical(Expr\BinaryOp\Identical $expr, Scope $scope, Ty } $funcTypes = $this->create($unwrappedLeftExpr, $rightType, $context, false, $scope, $rootExpr); - $constArray = $this->turnListIntoConstantArray($unwrappedLeftExpr, $argType, $rightType, $scope); + $isNormalCount = $this->isNormalCount($unwrappedLeftExpr, $scope); + $constArray = $isNormalCount->yes() ? $this->turnListIntoConstantArray($argType, $rightType) : null; if ($constArray !== null) { return $funcTypes->unionWith( $this->create($unwrappedLeftExpr->getArgs()[0]->value, $constArray, $context, false, $scope, $rootExpr), diff --git a/tests/PHPStan/Analyser/nsrt/bug-4700.php b/tests/PHPStan/Analyser/nsrt/bug-4700.php index 078ea41b12..202aca765c 100644 --- a/tests/PHPStan/Analyser/nsrt/bug-4700.php +++ b/tests/PHPStan/Analyser/nsrt/bug-4700.php @@ -40,7 +40,7 @@ function(array $array, int $count): void { if (isset($array['d'])) $a[] = $array['d']; if (isset($array['e'])) $a[] = $array['e']; if (count($a) > $count) { - assertType('int<1, 5>', count($a)); + assertType('int<2, 5>', count($a)); assertType('array{0: mixed~null, 1?: mixed~null, 2?: mixed~null, 3?: mixed~null, 4?: mixed~null}', $a); } else { assertType('0', count($a)); diff --git a/tests/PHPStan/Analyser/nsrt/count-type.php b/tests/PHPStan/Analyser/nsrt/count-type.php index 54fb89c2c7..859718b615 100644 --- a/tests/PHPStan/Analyser/nsrt/count-type.php +++ b/tests/PHPStan/Analyser/nsrt/count-type.php @@ -64,6 +64,29 @@ public function doFooBar( } } + /** @param array{0: string, 1?: string} $arr */ + public function doBar(array $arr): void + { + if (count($arr) <= 1) { + assertType('1', count($arr)); + return; + } + + assertType('2', count($arr)); + assertType('array{string, string}', $arr); + } + + /** @param array{0: string, 1?: string} $arr */ + public function doBaz(array $arr): void + { + if (count($arr) > 1) { + assertType('2', count($arr)); + assertType('array{string, string}', $arr); + } + + assertType('1|2', count($arr)); + } + } /**