Skip to content

Commit

Permalink
Hacking
Browse files Browse the repository at this point in the history
  • Loading branch information
snicoll committed Nov 9, 2023
1 parent d8ed7c7 commit add5d98
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.orm.jpa;

import java.util.function.Predicate;

import javax.sql.DataSource;

import jakarta.persistence.EntityManagerFactory;
Expand Down Expand Up @@ -200,6 +202,16 @@ public void setPackagesToScan(String... packagesToScan) {
this.internalPersistenceUnitManager.setPackagesToScan(packagesToScan);
}

/**
* Set the {@linkplain Predicate filter} to apply on entity classes discovered
* using {@linkplain #setPackagesToScan(String...) classpath scanning}.
* @param managedClassNameFilter the predicate to filter entity classes
* @since 6.2
*/
public void setManagedClassNameFilter(Predicate<String> managedClassNameFilter) {
this.internalPersistenceUnitManager.setManagedClassNameFilter(managedClassNameFilter);
}

/**
* Specify one or more mapping resources (equivalent to {@code <mapping-file>}
* entries in {@code persistence.xml}) for the default persistence unit.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;

import javax.sql.DataSource;

Expand Down Expand Up @@ -122,6 +123,9 @@ public class DefaultPersistenceUnitManager
@Nullable
private String[] packagesToScan;

@Nullable
private Predicate<String> managedClassNameFilter;

@Nullable
private String[] mappingResources;

Expand Down Expand Up @@ -229,6 +233,7 @@ public void setManagedTypes(PersistenceManagedTypes managedTypes) {
* resource for the default unit if the mapping file is not co-located with a
* {@code persistence.xml} file (in which case we assume it is only meant to be
* used with the persistence units defined there, like in standard JPA).
* @see #setManagedClassNameFilter(Predicate)
* @see #setManagedTypes(PersistenceManagedTypes)
* @see #setDefaultPersistenceUnitName
* @see #setMappingResources
Expand All @@ -237,6 +242,16 @@ public void setPackagesToScan(String... packagesToScan) {
this.packagesToScan = packagesToScan;
}

/**
* Set the {@linkplain Predicate filter} to apply on entity classes discovered
* using {@linkplain #setPackagesToScan(String...) classpath scanning}.
* @param managedClassNameFilter the predicate to filter entity classes
* @since 6.2
*/
public void setManagedClassNameFilter(Predicate<String> managedClassNameFilter) {
this.managedClassNameFilter = managedClassNameFilter;
}

/**
* Specify one or more mapping resources (equivalent to {@code <mapping-file>}
* entries in {@code persistence.xml}) for the default persistence unit.
Expand Down Expand Up @@ -546,8 +561,9 @@ private SpringPersistenceUnitInfo buildDefaultPersistenceUnitInfo() {
applyManagedTypes(scannedUnit, this.managedTypes);
}
else if (this.packagesToScan != null) {
applyManagedTypes(scannedUnit, new PersistenceManagedTypesScanner(
this.resourcePatternResolver).scan(this.packagesToScan));
PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner(
this.resourcePatternResolver, this.managedClassNameFilter);
applyManagedTypes(scannedUnit, scanner.scan(this.packagesToScan));
}

if (this.mappingResources != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;

import jakarta.persistence.Converter;
import jakarta.persistence.Embeddable;
Expand Down Expand Up @@ -73,10 +74,20 @@ public final class PersistenceManagedTypesScanner {
@Nullable
private final CandidateComponentsIndex componentsIndex;

private final Predicate<String> managedClassNameFilter;


public PersistenceManagedTypesScanner(ResourceLoader resourceLoader,
@Nullable Predicate<String> managedClassNameFilter) {

public PersistenceManagedTypesScanner(ResourceLoader resourceLoader) {
this.resourcePatternResolver = ResourcePatternUtils.getResourcePatternResolver(resourceLoader);
this.componentsIndex = CandidateComponentsIndexLoader.loadIndex(resourceLoader.getClassLoader());
this.managedClassNameFilter = (managedClassNameFilter != null ? managedClassNameFilter
: className -> true);
}

public PersistenceManagedTypesScanner(ResourceLoader resourceLoader) {
this(resourceLoader, null);
}

/**
Expand All @@ -99,7 +110,7 @@ private void scanPackage(String pkg, ScanResult scanResult) {
for (AnnotationTypeFilter filter : entityTypeFilters) {
candidates.addAll(this.componentsIndex.getCandidateTypes(pkg, filter.getAnnotationType().getName()));
}
scanResult.managedClassNames.addAll(candidates);
scanResult.managedClassNames.addAll(candidates.stream().filter(this.managedClassNameFilter).toList());
scanResult.managedPackages.addAll(this.componentsIndex.getCandidateTypes(pkg, "package-info"));
return;
}
Expand All @@ -113,7 +124,8 @@ private void scanPackage(String pkg, ScanResult scanResult) {
try {
MetadataReader reader = readerFactory.getMetadataReader(resource);
String className = reader.getClassMetadata().getClassName();
if (matchesFilter(reader, readerFactory)) {
if (matchesEntityTypeFilter(reader, readerFactory)
&& this.managedClassNameFilter.test(className)) {
scanResult.managedClassNames.add(className);
if (scanResult.persistenceUnitRootUrl == null) {
URL url = resource.getURL();
Expand Down Expand Up @@ -141,7 +153,7 @@ else if (className.endsWith(PACKAGE_INFO_SUFFIX)) {
* Check whether any of the configured entity type filters matches
* the current class descriptor contained in the metadata reader.
*/
private boolean matchesFilter(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException {
private boolean matchesEntityTypeFilter(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException {
for (TypeFilter filter : entityTypeFilters) {
if (filter.match(reader, readerFactory)) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package org.springframework.orm.jpa.persistenceunit;

import java.util.List;
import java.util.function.Predicate;

import org.junit.jupiter.api.Test;

import org.springframework.context.testfixture.index.CandidateComponentsTestClassLoader;
Expand All @@ -28,6 +31,11 @@
import org.springframework.orm.jpa.domain2.entity.User;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

/**
* Tests for {@link PersistenceManagedTypesScanner}.
Expand All @@ -36,7 +44,9 @@
*/
class PersistenceManagedTypesScannerTests {

private final PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner(new DefaultResourceLoader());
public static final DefaultResourceLoader RESOURCE_LOADER = new DefaultResourceLoader();

private final PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner(RESOURCE_LOADER);

@Test
void scanPackageWithOnlyEntities() {
Expand All @@ -47,6 +57,30 @@ void scanPackageWithOnlyEntities() {
assertThat(managedTypes.getManagedPackages()).isEmpty();
}

@Test
@SuppressWarnings("unchecked")
void scanPackageInvokesManagedClassNamesFilter() {
Predicate<String> filter = mock(Predicate.class);
given(filter.test(anyString())).willReturn(true);
new PersistenceManagedTypesScanner(RESOURCE_LOADER, filter)
.scan("org.springframework.orm.jpa.domain");
verify(filter).test(Person.class.getName());
verify(filter).test(DriversLicense.class.getName());
verify(filter).test(Employee.class.getName());
verify(filter).test(EmployeeLocationConverter.class.getName());
verifyNoMoreInteractions(filter);
}

@Test
void scanPackageWithUseManagedClassNamesFilter() {
List<String> candidates = List.of(Person.class.getName(), DriversLicense.class.getName());
PersistenceManagedTypes managedTypes = new PersistenceManagedTypesScanner(
RESOURCE_LOADER, candidates::contains).scan("org.springframework.orm.jpa.domain");
assertThat(managedTypes.getManagedClassNames()).containsExactlyInAnyOrder(
Person.class.getName(), DriversLicense.class.getName());
assertThat(managedTypes.getManagedPackages()).isEmpty();
}

@Test
void scanPackageWithEntitiesAndManagedPackages() {
PersistenceManagedTypes managedTypes = this.scanner.scan("org.springframework.orm.jpa.domain2");
Expand All @@ -65,7 +99,20 @@ void scanPackageUsesIndexIfPresent() {
"com.example.domain.Person", "com.example.domain.Address");
assertThat(managedTypes.getManagedPackages()).containsExactlyInAnyOrder(
"com.example.domain");
}

@Test
void scanPackageUsesIndexAndClassNameFilterIfPresent() {
List<String> candidates = List.of("com.example.domain.Address");
DefaultResourceLoader resourceLoader = new DefaultResourceLoader(
CandidateComponentsTestClassLoader.index(getClass().getClassLoader(),
new ClassPathResource("test-spring.components", getClass())));
PersistenceManagedTypes managedTypes = new PersistenceManagedTypesScanner(
resourceLoader, candidates::contains).scan("com.example");
assertThat(managedTypes.getManagedClassNames()).containsExactlyInAnyOrder(
"com.example.domain.Address");
assertThat(managedTypes.getManagedPackages()).containsExactlyInAnyOrder(
"com.example.domain");
}

}

0 comments on commit add5d98

Please sign in to comment.