From add5d98f9d9a1331e3fb98efce4c9c0d8b8bfd39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Thu, 9 Nov 2023 14:12:47 +0100 Subject: [PATCH] Hacking --- ...ocalContainerEntityManagerFactoryBean.java | 12 +++++ .../DefaultPersistenceUnitManager.java | 20 +++++++- .../PersistenceManagedTypesScanner.java | 20 ++++++-- .../PersistenceManagedTypesScannerTests.java | 49 ++++++++++++++++++- 4 files changed, 94 insertions(+), 7 deletions(-) diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/LocalContainerEntityManagerFactoryBean.java b/spring-orm/src/main/java/org/springframework/orm/jpa/LocalContainerEntityManagerFactoryBean.java index f977c277bd1d..69e3b4fd63e7 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/LocalContainerEntityManagerFactoryBean.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/LocalContainerEntityManagerFactoryBean.java @@ -16,6 +16,8 @@ package org.springframework.orm.jpa; +import java.util.function.Predicate; + import javax.sql.DataSource; import jakarta.persistence.EntityManagerFactory; @@ -200,6 +202,16 @@ public void setPackagesToScan(String... packagesToScan) { this.internalPersistenceUnitManager.setPackagesToScan(packagesToScan); } + /** + * Set the {@linkplain Predicate filter} to apply on entity classes discovered + * using {@linkplain #setPackagesToScan(String...) classpath scanning}. + * @param managedClassNameFilter the predicate to filter entity classes + * @since 6.2 + */ + public void setManagedClassNameFilter(Predicate managedClassNameFilter) { + this.internalPersistenceUnitManager.setManagedClassNameFilter(managedClassNameFilter); + } + /** * Specify one or more mapping resources (equivalent to {@code } * entries in {@code persistence.xml}) for the default persistence unit. diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/DefaultPersistenceUnitManager.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/DefaultPersistenceUnitManager.java index e85b783bf61c..1333387902e9 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/DefaultPersistenceUnitManager.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/DefaultPersistenceUnitManager.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Predicate; import javax.sql.DataSource; @@ -122,6 +123,9 @@ public class DefaultPersistenceUnitManager @Nullable private String[] packagesToScan; + @Nullable + private Predicate managedClassNameFilter; + @Nullable private String[] mappingResources; @@ -229,6 +233,7 @@ public void setManagedTypes(PersistenceManagedTypes managedTypes) { * resource for the default unit if the mapping file is not co-located with a * {@code persistence.xml} file (in which case we assume it is only meant to be * used with the persistence units defined there, like in standard JPA). + * @see #setManagedClassNameFilter(Predicate) * @see #setManagedTypes(PersistenceManagedTypes) * @see #setDefaultPersistenceUnitName * @see #setMappingResources @@ -237,6 +242,16 @@ public void setPackagesToScan(String... packagesToScan) { this.packagesToScan = packagesToScan; } + /** + * Set the {@linkplain Predicate filter} to apply on entity classes discovered + * using {@linkplain #setPackagesToScan(String...) classpath scanning}. + * @param managedClassNameFilter the predicate to filter entity classes + * @since 6.2 + */ + public void setManagedClassNameFilter(Predicate managedClassNameFilter) { + this.managedClassNameFilter = managedClassNameFilter; + } + /** * Specify one or more mapping resources (equivalent to {@code } * entries in {@code persistence.xml}) for the default persistence unit. @@ -546,8 +561,9 @@ private SpringPersistenceUnitInfo buildDefaultPersistenceUnitInfo() { applyManagedTypes(scannedUnit, this.managedTypes); } else if (this.packagesToScan != null) { - applyManagedTypes(scannedUnit, new PersistenceManagedTypesScanner( - this.resourcePatternResolver).scan(this.packagesToScan)); + PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner( + this.resourcePatternResolver, this.managedClassNameFilter); + applyManagedTypes(scannedUnit, scanner.scan(this.packagesToScan)); } if (this.mappingResources != null) { diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScanner.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScanner.java index 1686b9c2fb22..849bb7c18fee 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScanner.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScanner.java @@ -24,6 +24,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import java.util.function.Predicate; import jakarta.persistence.Converter; import jakarta.persistence.Embeddable; @@ -73,10 +74,20 @@ public final class PersistenceManagedTypesScanner { @Nullable private final CandidateComponentsIndex componentsIndex; + private final Predicate managedClassNameFilter; + + + public PersistenceManagedTypesScanner(ResourceLoader resourceLoader, + @Nullable Predicate managedClassNameFilter) { - public PersistenceManagedTypesScanner(ResourceLoader resourceLoader) { this.resourcePatternResolver = ResourcePatternUtils.getResourcePatternResolver(resourceLoader); this.componentsIndex = CandidateComponentsIndexLoader.loadIndex(resourceLoader.getClassLoader()); + this.managedClassNameFilter = (managedClassNameFilter != null ? managedClassNameFilter + : className -> true); + } + + public PersistenceManagedTypesScanner(ResourceLoader resourceLoader) { + this(resourceLoader, null); } /** @@ -99,7 +110,7 @@ private void scanPackage(String pkg, ScanResult scanResult) { for (AnnotationTypeFilter filter : entityTypeFilters) { candidates.addAll(this.componentsIndex.getCandidateTypes(pkg, filter.getAnnotationType().getName())); } - scanResult.managedClassNames.addAll(candidates); + scanResult.managedClassNames.addAll(candidates.stream().filter(this.managedClassNameFilter).toList()); scanResult.managedPackages.addAll(this.componentsIndex.getCandidateTypes(pkg, "package-info")); return; } @@ -113,7 +124,8 @@ private void scanPackage(String pkg, ScanResult scanResult) { try { MetadataReader reader = readerFactory.getMetadataReader(resource); String className = reader.getClassMetadata().getClassName(); - if (matchesFilter(reader, readerFactory)) { + if (matchesEntityTypeFilter(reader, readerFactory) + && this.managedClassNameFilter.test(className)) { scanResult.managedClassNames.add(className); if (scanResult.persistenceUnitRootUrl == null) { URL url = resource.getURL(); @@ -141,7 +153,7 @@ else if (className.endsWith(PACKAGE_INFO_SUFFIX)) { * Check whether any of the configured entity type filters matches * the current class descriptor contained in the metadata reader. */ - private boolean matchesFilter(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException { + private boolean matchesEntityTypeFilter(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException { for (TypeFilter filter : entityTypeFilters) { if (filter.match(reader, readerFactory)) { return true; diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScannerTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScannerTests.java index bf88cc096279..ff82233145ab 100644 --- a/spring-orm/src/test/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScannerTests.java +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScannerTests.java @@ -16,6 +16,9 @@ package org.springframework.orm.jpa.persistenceunit; +import java.util.List; +import java.util.function.Predicate; + import org.junit.jupiter.api.Test; import org.springframework.context.testfixture.index.CandidateComponentsTestClassLoader; @@ -28,6 +31,11 @@ import org.springframework.orm.jpa.domain2.entity.User; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; /** * Tests for {@link PersistenceManagedTypesScanner}. @@ -36,7 +44,9 @@ */ class PersistenceManagedTypesScannerTests { - private final PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner(new DefaultResourceLoader()); + public static final DefaultResourceLoader RESOURCE_LOADER = new DefaultResourceLoader(); + + private final PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner(RESOURCE_LOADER); @Test void scanPackageWithOnlyEntities() { @@ -47,6 +57,30 @@ void scanPackageWithOnlyEntities() { assertThat(managedTypes.getManagedPackages()).isEmpty(); } + @Test + @SuppressWarnings("unchecked") + void scanPackageInvokesManagedClassNamesFilter() { + Predicate filter = mock(Predicate.class); + given(filter.test(anyString())).willReturn(true); + new PersistenceManagedTypesScanner(RESOURCE_LOADER, filter) + .scan("org.springframework.orm.jpa.domain"); + verify(filter).test(Person.class.getName()); + verify(filter).test(DriversLicense.class.getName()); + verify(filter).test(Employee.class.getName()); + verify(filter).test(EmployeeLocationConverter.class.getName()); + verifyNoMoreInteractions(filter); + } + + @Test + void scanPackageWithUseManagedClassNamesFilter() { + List candidates = List.of(Person.class.getName(), DriversLicense.class.getName()); + PersistenceManagedTypes managedTypes = new PersistenceManagedTypesScanner( + RESOURCE_LOADER, candidates::contains).scan("org.springframework.orm.jpa.domain"); + assertThat(managedTypes.getManagedClassNames()).containsExactlyInAnyOrder( + Person.class.getName(), DriversLicense.class.getName()); + assertThat(managedTypes.getManagedPackages()).isEmpty(); + } + @Test void scanPackageWithEntitiesAndManagedPackages() { PersistenceManagedTypes managedTypes = this.scanner.scan("org.springframework.orm.jpa.domain2"); @@ -65,7 +99,20 @@ void scanPackageUsesIndexIfPresent() { "com.example.domain.Person", "com.example.domain.Address"); assertThat(managedTypes.getManagedPackages()).containsExactlyInAnyOrder( "com.example.domain"); + } + @Test + void scanPackageUsesIndexAndClassNameFilterIfPresent() { + List candidates = List.of("com.example.domain.Address"); + DefaultResourceLoader resourceLoader = new DefaultResourceLoader( + CandidateComponentsTestClassLoader.index(getClass().getClassLoader(), + new ClassPathResource("test-spring.components", getClass()))); + PersistenceManagedTypes managedTypes = new PersistenceManagedTypesScanner( + resourceLoader, candidates::contains).scan("com.example"); + assertThat(managedTypes.getManagedClassNames()).containsExactlyInAnyOrder( + "com.example.domain.Address"); + assertThat(managedTypes.getManagedPackages()).containsExactlyInAnyOrder( + "com.example.domain"); } }