Skip to content

Commit

Permalink
Enum support in query type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaud-lb committed Jul 6, 2022
1 parent f855eba commit ef75789
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 29 deletions.
76 changes: 54 additions & 22 deletions src/Type/Doctrine/Query/QueryResultTypeWalker.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

namespace PHPStan\Type\Doctrine\Query;

use BackedEnum;
use Doctrine\ORM\EntityManagerInterface;
use Doctrine\ORM\Mapping\ClassMetadata;
use Doctrine\ORM\Mapping\ClassMetadataInfo;
use Doctrine\ORM\Query;
use Doctrine\ORM\Query\AST;
use Doctrine\ORM\Query\AST\TypedExpression;
Expand All @@ -15,6 +17,7 @@
use PHPStan\Type\Constant\ConstantFloatType;
use PHPStan\Type\Constant\ConstantIntegerType;
use PHPStan\Type\Constant\ConstantStringType;
use PHPStan\Type\ConstantTypeHelper;
use PHPStan\Type\Doctrine\DescriptorNotRegisteredException;
use PHPStan\Type\Doctrine\DescriptorRegistry;
use PHPStan\Type\FloatType;
Expand All @@ -31,6 +34,7 @@
use PHPStan\Type\TypeTraverser;
use PHPStan\Type\TypeUtils;
use PHPStan\Type\UnionType;
use function array_map;
use function assert;
use function class_exists;
use function count;
Expand All @@ -42,6 +46,7 @@
use function is_numeric;
use function is_object;
use function is_string;
use function is_subclass_of;
use function serialize;
use function sprintf;
use function strtolower;
Expand Down Expand Up @@ -231,15 +236,13 @@ public function walkPathExpression($pathExpr)

switch ($pathExpr->type) {
case AST\PathExpression::TYPE_STATE_FIELD:
$typeName = $class->getTypeOfField($fieldName);

assert(is_string($typeName));
[$typeName, $enumType] = $this->getTypeOfField($class, $fieldName);

$nullable = $this->isQueryComponentNullable($dqlAlias)
|| $class->isNullable($fieldName)
|| $this->hasAggregateWithoutGroupBy();

$fieldType = $this->resolveDatabaseInternalType($typeName, $nullable);
$fieldType = $this->resolveDatabaseInternalType($typeName, $enumType, $nullable);

return $this->marshalType($fieldType);

Expand Down Expand Up @@ -273,14 +276,12 @@ public function walkPathExpression($pathExpr)
}

$targetFieldName = $identifierFieldNames[0];
$typeName = $targetClass->getTypeOfField($targetFieldName);

assert(is_string($typeName));
[$typeName] = $this->getTypeOfField($targetClass, $targetFieldName);

$nullable = (bool) ($joinColumn['nullable'] ?? true)
|| $this->hasAggregateWithoutGroupBy();

$fieldType = $this->resolveDatabaseInternalType($typeName, $nullable);
$fieldType = $this->resolveDatabaseInternalType($typeName, null, $nullable);

return $this->marshalType($fieldType);

Expand Down Expand Up @@ -543,7 +544,7 @@ public function walkFunction($function)
$joinColumn = null;

foreach ($assoc['joinColumns'] as $item) {
if ($item['referencedColumnName'] === $fieldMapping['columnName']) {
if ($item['referencedColumnName'] === ($fieldMapping['columnName'] ?? null)) {
$joinColumn = $item;
break;
}
Expand All @@ -556,7 +557,7 @@ public function walkFunction($function)
$nullable = (bool) ($joinColumn['nullable'] ?? true)
|| $this->hasAggregateWithoutGroupBy();

$fieldType = $this->resolveDatabaseInternalType($typeName, $nullable);
$fieldType = $this->resolveDatabaseInternalType($typeName, null, $nullable);

return $this->marshalType($fieldType);

Expand Down Expand Up @@ -783,15 +784,13 @@ public function walkSelectExpression($selectExpression)
$qComp = $this->queryComponents[$dqlAlias];
$class = $qComp['metadata'];

$typeName = $class->getTypeOfField($fieldName);

assert(is_string($typeName));
[$typeName, $enumType] = $this->getTypeOfField($class, $fieldName);

$nullable = $this->isQueryComponentNullable($dqlAlias)
|| $class->isNullable($fieldName)
|| $this->hasAggregateWithoutGroupBy();

$type = $this->resolveDoctrineType($typeName, $nullable);
$type = $this->resolveDoctrineType($typeName, $enumType, $nullable);

$this->typeBuilder->addScalar($resultAlias, $type);

Expand Down Expand Up @@ -1295,14 +1294,37 @@ private function isQueryComponentNullable(string $dqlAlias): bool
return $this->nullableQueryComponents[$dqlAlias] ?? false;
}

private function resolveDoctrineType(string $typeName, bool $nullable = false): Type
/** @return array{string, ?class-string<BackedEnum>} Doctrine type name and enum type of field */
private function getTypeOfField(ClassMetadataInfo $class, string $fieldName): array
{
try {
$type = $this->descriptorRegistry
->get($typeName)
->getWritableToPropertyType();
} catch (DescriptorNotRegisteredException $e) {
$type = new MixedType();
assert(isset($class->fieldMappings[$fieldName]));

/** @var array{type: string, enumType?: ?string} $metadata */
$metadata = $class->fieldMappings[$fieldName];

$type = $metadata['type'];
$enumType = $metadata['enumType'] ?? null;

if (!is_string($enumType) || !class_exists($enumType) || !is_subclass_of($enumType, BackedEnum::class)) {
$enumType = null;
}

return [$type, $enumType];
}

/** @param ?class-string<BackedEnum> $enumType */
private function resolveDoctrineType(string $typeName, ?string $enumType = null, bool $nullable = false): Type
{
if ($enumType !== null) {
$type = new ObjectType($enumType);
} else {
try {
$type = $this->descriptorRegistry
->get($typeName)
->getWritableToPropertyType();
} catch (DescriptorNotRegisteredException $e) {
$type = new MixedType();
}
}

if ($nullable) {
Expand All @@ -1312,7 +1334,8 @@ private function resolveDoctrineType(string $typeName, bool $nullable = false):
return $type;
}

private function resolveDatabaseInternalType(string $typeName, bool $nullable = false): Type
/** @param ?class-string<BackedEnum> $enumType */
private function resolveDatabaseInternalType(string $typeName, ?string $enumType = null, bool $nullable = false): Type
{
try {
$type = $this->descriptorRegistry
Expand All @@ -1322,6 +1345,15 @@ private function resolveDatabaseInternalType(string $typeName, bool $nullable =
$type = new MixedType();
}

if ($enumType !== null) {
$enumTypes = array_map(static function ($enumType) {
return ConstantTypeHelper::getTypeFromValue($enumType->value);
}, $enumType::cases());
$enumType = TypeCombinator::union(...$enumTypes);
$enumType = TypeCombinator::union($enumType, $enumType->toString());
$type = TypeCombinator::intersect($enumType, $type);
}

if ($nullable) {
$type = TypeCombinator::addNull($type);
}
Expand Down
77 changes: 70 additions & 7 deletions tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use DateTimeImmutable;
use Doctrine\Common\Collections\ArrayCollection;
use Doctrine\ORM\EntityManagerInterface;
use Doctrine\ORM\Mapping\Column;
use Doctrine\ORM\Query\AST\TypedExpression;
use Doctrine\ORM\Tools\SchemaTool;
use PHPStan\Testing\PHPStanTestCase;
Expand Down Expand Up @@ -35,11 +36,15 @@
use QueryResult\Entities\One;
use QueryResult\Entities\OneId;
use QueryResult\Entities\SingleTableChild;
use QueryResult\EntitiesEnum\EntityWithEnum;
use QueryResult\EntitiesEnum\IntEnum;
use QueryResult\EntitiesEnum\StringEnum;
use Throwable;
use function array_merge;
use function array_shift;
use function class_exists;
use function count;
use function property_exists;
use function sprintf;

final class QueryResultTypeWalkerTest extends PHPStanTestCase
Expand Down Expand Up @@ -151,6 +156,15 @@ public static function setUpBeforeClass(): void
$em->persist($child);
}

if (property_exists(Column::class, 'enumType')) {
$entityWithEnum = new EntityWithEnum();
$entityWithEnum->id = '1';
$entityWithEnum->stringEnumColumn = StringEnum::A;
$entityWithEnum->intEnumColumn = IntEnum::A;
$entityWithEnum->intEnumOnStringColumn = IntEnum::A;
$em->persist($entityWithEnum);
}

$em->flush();
}

Expand All @@ -173,6 +187,11 @@ public function test(Type $expectedType, string $dql, ?string $expectedException

$typeBuilder = new QueryResultTypeBuilder();

if ($expectedExceptionMessage !== null) {
$this->expectException(Throwable::class);
$this->expectExceptionMessage($expectedExceptionMessage);
}

QueryResultTypeWalker::walk($query, $typeBuilder, $this->descriptorRegistry);

$type = $typeBuilder->getResultType();
Expand All @@ -186,11 +205,6 @@ public function test(Type $expectedType, string $dql, ?string $expectedException

$query = $em->createQuery($dql);

if ($expectedExceptionMessage !== null) {
$this->expectException(Throwable::class);
$this->expectExceptionMessage($expectedExceptionMessage);
}

$result = $query->getResult();
self::assertGreaterThan(0, count($result));

Expand All @@ -199,7 +213,7 @@ public function test(Type $expectedType, string $dql, ?string $expectedException
self::assertTrue(
$type->accepts($rowType, true)->yes(),
sprintf(
"%s\nshould accept\n%s",
"The inferred type\n%s\nshould accept actual type\n%s",
$type->describe(VerbosityLevel::precise()),
$rowType->describe(VerbosityLevel::precise())
)
Expand All @@ -208,7 +222,7 @@ public function test(Type $expectedType, string $dql, ?string $expectedException
}

/**
* @return array<array-key,array{Type,string,2?:string}>
* @return array<array-key,array{Type,string,2?:?string}>
*/
public function getTestData(): array
{
Expand Down Expand Up @@ -467,6 +481,55 @@ public function getTestData(): array
FROM QueryResult\Entities\One o
',
],
'enum' => [
$this->constantArray([
[new ConstantStringType('stringEnumColumn'), new ObjectType(StringEnum::class)],
[new ConstantStringType('intEnumColumn'), new ObjectType(IntEnum::class)],
]),
'
SELECT e.stringEnumColumn, e.intEnumColumn
FROM QueryResult\EntitiesEnum\EntityWithEnum e
',
property_exists(Column::class, 'enumType')
? null
: 'The class \'QueryResult\\EntitiesEnum\\EntityWithEnum\' was not found in the chain configured namespaces QueryResult\\Entities\\',
],
'enum in expression' => [
$this->constantArray([
[
new ConstantIntegerType(1),
TypeCombinator::union(
new ConstantStringType('a'),
new ConstantStringType('b'),
),
],
[
new ConstantIntegerType(2),
TypeCombinator::union(
new ConstantIntegerType(1),
new ConstantIntegerType(2),
new ConstantStringType('1'),
new ConstantStringType('2'),
),
],
[
new ConstantIntegerType(3),
TypeCombinator::union(
new ConstantStringType('1'),
new ConstantStringType('2'),
),
],
]),
'
SELECT COALESCE(e.stringEnumColumn, e.stringEnumColumn),
COALESCE(e.intEnumColumn, e.intEnumColumn),
COALESCE(e.intEnumOnStringColumn, e.intEnumOnStringColumn)
FROM QueryResult\EntitiesEnum\EntityWithEnum e
',
property_exists(Column::class, 'enumType')
? null
: 'The class \'QueryResult\\EntitiesEnum\\EntityWithEnum\' was not found in the chain configured namespaces QueryResult\\Entities\\',
],
'hidden' => [
$this->constantArray([
[new ConstantStringType('intColumn'), new IntegerType()],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
<?php declare(strict_types=1);

namespace QueryResult\EntitiesEnum;

use Doctrine\Common\Collections\Collection;
use Doctrine\ORM\Mapping\Column;
use Doctrine\ORM\Mapping\Embedded as ORMEmbedded;
use Doctrine\ORM\Mapping\Entity;
use Doctrine\ORM\Mapping\Id;
use Doctrine\ORM\Mapping\JoinColumn;
use Doctrine\ORM\Mapping\ManyToOne;
use Doctrine\ORM\Mapping\OneToMany;

/**
* @Entity
*/
class EntityWithEnum
{
/**
* @Column(type="bigint")
* @Id
*
* @var string
*/
public $id;

/**
* @Column(type="string", enumType="QueryResult\EntitiesEnum\StringEnum")
*/
public $stringEnumColumn;

/**
* @Column(type="integer", enumType="QueryResult\EntitiesEnum\IntEnum")
*/
public $intEnumColumn;

/**
* @Column(type="string", enumType="QueryResult\Entities\IntEnum")
*/
public $intEnumOnStringColumn;
}
18 changes: 18 additions & 0 deletions tests/Type/Doctrine/data/QueryResult/EntitiesEnum/IntEnum.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?php declare(strict_types=1); // lint >= 8.1

namespace QueryResult\EntitiesEnum;

use Doctrine\Common\Collections\Collection;
use Doctrine\ORM\Mapping\Column;
use Doctrine\ORM\Mapping\Embedded as ORMEmbedded;
use Doctrine\ORM\Mapping\Entity;
use Doctrine\ORM\Mapping\Id;
use Doctrine\ORM\Mapping\JoinColumn;
use Doctrine\ORM\Mapping\ManyToOne;
use Doctrine\ORM\Mapping\OneToMany;

enum IntEnum: int
{
case A = 1;
case B = 2;
}
18 changes: 18 additions & 0 deletions tests/Type/Doctrine/data/QueryResult/EntitiesEnum/StringEnum.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?php declare(strict_types=1); // lint >= 8.1

namespace QueryResult\EntitiesEnum;

use Doctrine\Common\Collections\Collection;
use Doctrine\ORM\Mapping\Column;
use Doctrine\ORM\Mapping\Embedded as ORMEmbedded;
use Doctrine\ORM\Mapping\Entity;
use Doctrine\ORM\Mapping\Id;
use Doctrine\ORM\Mapping\JoinColumn;
use Doctrine\ORM\Mapping\ManyToOne;
use Doctrine\ORM\Mapping\OneToMany;

enum StringEnum: string
{
case A = 'a';
case B = 'b';
}
Loading

0 comments on commit ef75789

Please sign in to comment.