diff --git a/spring-cloud-gcp-pubsub-stream-binder/src/test/java/com/google/cloud/spring/stream/binder/pubsub/PubSubMessageChannelBinderTests.java b/spring-cloud-gcp-pubsub-stream-binder/src/test/java/com/google/cloud/spring/stream/binder/pubsub/PubSubMessageChannelBinderTests.java index 3678b79891..25a2a0face 100644 --- a/spring-cloud-gcp-pubsub-stream-binder/src/test/java/com/google/cloud/spring/stream/binder/pubsub/PubSubMessageChannelBinderTests.java +++ b/spring-cloud-gcp-pubsub-stream-binder/src/test/java/com/google/cloud/spring/stream/binder/pubsub/PubSubMessageChannelBinderTests.java @@ -24,10 +24,13 @@ import com.google.cloud.spring.core.GcpProjectIdProvider; import com.google.cloud.spring.pubsub.PubSubAdmin; import com.google.cloud.spring.pubsub.core.PubSubTemplate; +import com.google.cloud.spring.pubsub.core.health.HealthTrackerRegistry; +import com.google.cloud.spring.pubsub.core.subscriber.PubSubSubscriberTemplate; import com.google.cloud.spring.pubsub.integration.AckMode; import com.google.cloud.spring.pubsub.integration.inbound.PubSubInboundChannelAdapter; import com.google.cloud.spring.pubsub.integration.inbound.PubSubMessageSource; import com.google.cloud.spring.pubsub.integration.outbound.PubSubMessageHandler; +import com.google.cloud.spring.pubsub.support.SubscriberFactory; import com.google.cloud.spring.stream.binder.pubsub.config.PubSubBinderConfiguration; import com.google.cloud.spring.stream.binder.pubsub.properties.PubSubConsumerProperties; import com.google.cloud.spring.stream.binder.pubsub.properties.PubSubExtendedBindingProperties; @@ -106,6 +109,9 @@ public class PubSubMessageChannelBinderTests { @Mock MessageChannel errorChannel; + @Mock + HealthTrackerRegistry healthTrackerRegistry; + ApplicationContextRunner baseContext = new ApplicationContextRunner() .withBean(PubSubTemplate.class, () -> pubSubTemplate) .withBean(PubSubAdmin.class, () -> pubSubAdmin) @@ -175,6 +181,32 @@ public void consumerMaxFetchPropertyPropagatesToMessageSource() { }); } + @Test + public void testCreateConsumerWithRegistry() { + SubscriberFactory subscriberFactory = mock(SubscriberFactory.class); + when(subscriberFactory.getProjectId()).thenReturn("test-project-id"); + PubSubSubscriberTemplate subSubscriberTemplate = mock(PubSubSubscriberTemplate.class); + when(subSubscriberTemplate.getSubscriberFactory()).thenReturn(subscriberFactory); + when(pubSubTemplate.getPubSubSubscriberTemplate()).thenReturn(subSubscriberTemplate); + + baseContext + .run(ctx -> { + PubSubMessageChannelBinder binder = ctx.getBean(PubSubMessageChannelBinder.class); + PubSubExtendedBindingProperties props = ctx.getBean("pubSubExtendedBindingProperties", PubSubExtendedBindingProperties.class); + binder.setHealthTrackerRegistry(healthTrackerRegistry); + + MessageProducer messageProducer = binder + .createConsumerEndpoint(consumerDestination, "testGroup", + new ExtendedConsumerProperties<>(props.getExtendedConsumerProperties("test")) + ); + + assertThat(messageProducer).isInstanceOf(PubSubInboundChannelAdapter.class); + PubSubInboundChannelAdapter inboundChannelAdapter = (PubSubInboundChannelAdapter) messageProducer; + assertThat(inboundChannelAdapter.getAckMode()).isSameAs(AckMode.AUTO); + assertThat(inboundChannelAdapter.healthCheckEnabled()).isEqualTo(true); + }); + } + @Test public void testProducerAndConsumerCustomizers() { baseContext.withUserConfiguration(PubSubBinderTestConfig.class) diff --git a/spring-cloud-gcp-pubsub/src/main/java/com/google/cloud/spring/pubsub/integration/inbound/PubSubInboundChannelAdapter.java b/spring-cloud-gcp-pubsub/src/main/java/com/google/cloud/spring/pubsub/integration/inbound/PubSubInboundChannelAdapter.java index 8da2a48c9b..ea17059c1b 100644 --- a/spring-cloud-gcp-pubsub/src/main/java/com/google/cloud/spring/pubsub/integration/inbound/PubSubInboundChannelAdapter.java +++ b/spring-cloud-gcp-pubsub/src/main/java/com/google/cloud/spring/pubsub/integration/inbound/PubSubInboundChannelAdapter.java @@ -19,11 +19,14 @@ import java.util.Map; import com.google.cloud.pubsub.v1.Subscriber; +import com.google.cloud.spring.pubsub.core.health.HealthTrackerRegistry; import com.google.cloud.spring.pubsub.core.subscriber.PubSubSubscriberOperations; import com.google.cloud.spring.pubsub.integration.AckMode; import com.google.cloud.spring.pubsub.integration.PubSubHeaderMapper; import com.google.cloud.spring.pubsub.support.GcpPubSubHeaders; +import com.google.cloud.spring.pubsub.support.PubSubSubscriptionUtils; import com.google.cloud.spring.pubsub.support.converter.ConvertedBasicAcknowledgeablePubsubMessage; +import com.google.pubsub.v1.ProjectSubscriptionName; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -57,6 +60,10 @@ public class PubSubInboundChannelAdapter extends MessageProducerSupport { private Class payloadType = byte[].class; + private HealthTrackerRegistry healthTrackerRegistry; + + private String projectId; + public PubSubInboundChannelAdapter(PubSubSubscriberOperations pubSubSubscriberOperations, String subscriptionName) { Assert.notNull(pubSubSubscriberOperations, "Pub/Sub subscriber template can't be null."); Assert.notNull(subscriptionName, "Pub/Sub subscription name can't be null."); @@ -73,6 +80,16 @@ public void setAckMode(AckMode ackMode) { this.ackMode = ackMode; } + public void setHealthTrackerRegistry( + HealthTrackerRegistry healthTrackerRegistry) { + Assert.notNull(projectId, "HealthTrackerRegistry requires a projectId"); + this.healthTrackerRegistry = healthTrackerRegistry; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + public Class getPayloadType() { return this.payloadType; } @@ -105,8 +122,12 @@ public void setHeaderMapper(HeaderMapper> headerMapper) { protected void doStart() { super.doStart(); + addToHealthRegistry(); + this.subscriber = this.pubSubSubscriberOperations.subscribeAndConvert( this.subscriptionName, this::consumeMessage, this.payloadType); + + addListeners(); } @Override @@ -132,6 +153,8 @@ private void consumeMessage(ConvertedBasicAcknowledgeablePubsubMessage messag .copyHeaders(messageHeaders) .build()); + processedMessage(message.getProjectSubscriptionName()); + if (this.ackMode == AckMode.AUTO_ACK || this.ackMode == AckMode.AUTO) { message.ack(); } @@ -148,4 +171,28 @@ private void consumeMessage(ConvertedBasicAcknowledgeablePubsubMessage messag } } } + + private void addToHealthRegistry() { + if (healthCheckEnabled()) { + ProjectSubscriptionName projectSubscriptionName = PubSubSubscriptionUtils.toProjectSubscriptionName(subscriptionName, this.projectId); + healthTrackerRegistry.registerTracker(projectSubscriptionName); + } + } + + private void addListeners() { + if (healthCheckEnabled()) { + healthTrackerRegistry.addListener(subscriber); + } + } + + private void processedMessage(ProjectSubscriptionName projectSubscriptionName) { + if (healthCheckEnabled()) { + healthTrackerRegistry.processedMessage(projectSubscriptionName); + } + } + + public boolean healthCheckEnabled() { + return healthTrackerRegistry != null; + } + } diff --git a/spring-cloud-gcp-pubsub/src/main/java/com/google/cloud/spring/pubsub/support/DefaultSubscriberFactory.java b/spring-cloud-gcp-pubsub/src/main/java/com/google/cloud/spring/pubsub/support/DefaultSubscriberFactory.java index 9f5d6d84c3..cb07d20589 100644 --- a/spring-cloud-gcp-pubsub/src/main/java/com/google/cloud/spring/pubsub/support/DefaultSubscriberFactory.java +++ b/spring-cloud-gcp-pubsub/src/main/java/com/google/cloud/spring/pubsub/support/DefaultSubscriberFactory.java @@ -37,6 +37,8 @@ import com.google.cloud.pubsub.v1.stub.SubscriberStubSettings; import com.google.cloud.spring.core.GcpProjectIdProvider; import com.google.cloud.spring.pubsub.core.PubSubConfiguration; +import com.google.cloud.spring.pubsub.core.health.HealthTrackerRegistry; +import com.google.pubsub.v1.ProjectSubscriptionName; import com.google.pubsub.v1.PullRequest; import org.threeten.bp.Duration; @@ -79,6 +81,8 @@ public class DefaultSubscriberFactory implements SubscriberFactory { private RetrySettings subscriberStubRetrySettings; + private HealthTrackerRegistry healthTrackerRegistry; + private PubSubConfiguration pubSubConfiguration; private ConcurrentMap threadPoolTaskSchedulerMap = new ConcurrentHashMap<>(); @@ -210,6 +214,14 @@ public void setSubscriberStubRetrySettings(RetrySettings subscriberStubRetrySett this.subscriberStubRetrySettings = subscriberStubRetrySettings; } + /** + * Set the health tracker chain for the generated subscriptions. + * @param healthTrackerRegistry parameter for registering health trackers when creating subscriptions + */ + public void setHealthTrackerRegistry(HealthTrackerRegistry healthTrackerRegistry) { + this.healthTrackerRegistry = healthTrackerRegistry; + } + @Override public Subscriber createSubscriber(String subscriptionName, MessageReceiver receiver) { boolean shouldAddToHealthCheck = shouldAddToHealthCheck(subscriptionName); @@ -511,4 +523,14 @@ public void setThreadPoolTaskSchedulerMap( public void setGlobalScheduler(ThreadPoolTaskScheduler threadPoolTaskScheduler) { this.globalScheduler = threadPoolTaskScheduler; } + + private boolean shouldAddToHealthCheck(String subscriptionName) { + if (healthTrackerRegistry == null) { + return false; + } + + ProjectSubscriptionName projectSubscriptionName = PubSubSubscriptionUtils.toProjectSubscriptionName(subscriptionName, this.projectId); + return !healthTrackerRegistry.isTracked(projectSubscriptionName); + } + } diff --git a/spring-cloud-gcp-pubsub/src/test/java/com/google/cloud/spring/pubsub/support/DefaultSubscriberFactoryTests.java b/spring-cloud-gcp-pubsub/src/test/java/com/google/cloud/spring/pubsub/support/DefaultSubscriberFactoryTests.java index b97d66317e..30a8367b2a 100644 --- a/spring-cloud-gcp-pubsub/src/test/java/com/google/cloud/spring/pubsub/support/DefaultSubscriberFactoryTests.java +++ b/spring-cloud-gcp-pubsub/src/test/java/com/google/cloud/spring/pubsub/support/DefaultSubscriberFactoryTests.java @@ -26,6 +26,8 @@ import com.google.cloud.pubsub.v1.Subscriber; import com.google.cloud.spring.core.GcpProjectIdProvider; import com.google.cloud.spring.pubsub.core.PubSubConfiguration; +import com.google.cloud.spring.pubsub.core.health.HealthTrackerRegistry; +import com.google.pubsub.v1.ProjectSubscriptionName; import com.google.pubsub.v1.PullRequest; import org.junit.Rule; import org.junit.Test; @@ -38,6 +40,10 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -67,6 +73,10 @@ public class DefaultSubscriberFactoryTests { private ThreadPoolTaskScheduler mockScheduler; @Mock private ThreadPoolTaskScheduler mockGlobalScheduler; + @Mock + private HealthTrackerRegistry healthTrackerRegistry; + @Mock + private ExecutorProvider executorProvider; /** * used to check exception messages and types. @@ -519,4 +529,43 @@ public void testGetPullEndpoint_configurationIsNull() { assertThat(factory.getPullEndpoint("subscription-name")).isNull(); } + + + @Test + public void testNewSubscriber_shouldNotAddToHealthCheck() { + ProjectSubscriptionName subscriptionName = ProjectSubscriptionName.of("angeldust", "midnight cowboy"); + + when(healthTrackerRegistry.isTracked(subscriptionName)).thenReturn(true); + + DefaultSubscriberFactory factory = new DefaultSubscriberFactory(() -> "angeldust"); + factory.setCredentialsProvider(this.credentialsProvider); + factory.setHealthTrackerRegistry(healthTrackerRegistry); + + Subscriber subscriber = factory.createSubscriber("midnight cowboy", (message, consumer) -> { }); + assertThat(subscriber.getSubscriptionNameString()) + .isEqualTo("projects/angeldust/subscriptions/midnight cowboy"); + + verify(healthTrackerRegistry, times(1)).isTracked(subscriptionName); + verify(healthTrackerRegistry, times(0)).wrap(eq(subscriptionName), any()); + } + + @Test + public void testNewSubscriber_shouldAddToHealthCheck() { + ProjectSubscriptionName subscriptionName = ProjectSubscriptionName.of("angeldust", "midnight cowboy"); + + when(healthTrackerRegistry.isTracked(subscriptionName)).thenReturn(false); + + DefaultSubscriberFactory factory = new DefaultSubscriberFactory(() -> "angeldust"); + factory.setCredentialsProvider(this.credentialsProvider); + factory.setHealthTrackerRegistry(healthTrackerRegistry); + + Subscriber subscriber = factory.createSubscriber("midnight cowboy", (message, consumer) -> { }); + assertThat(subscriber.getSubscriptionNameString()) + .isEqualTo("projects/angeldust/subscriptions/midnight cowboy"); + + verify(healthTrackerRegistry).isTracked(subscriptionName); + verify(healthTrackerRegistry).wrap(any(), any()); + verify(healthTrackerRegistry).addListener(any()); + } + }