Skip to content

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
gkatzioura committed Oct 23, 2021
1 parent 9400179 commit 9ab58d0
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
import com.google.cloud.spring.core.GcpProjectIdProvider;
import com.google.cloud.spring.pubsub.PubSubAdmin;
import com.google.cloud.spring.pubsub.core.PubSubTemplate;
import com.google.cloud.spring.pubsub.core.health.HealthTrackerRegistry;
import com.google.cloud.spring.pubsub.core.subscriber.PubSubSubscriberTemplate;
import com.google.cloud.spring.pubsub.integration.AckMode;
import com.google.cloud.spring.pubsub.integration.inbound.PubSubInboundChannelAdapter;
import com.google.cloud.spring.pubsub.integration.inbound.PubSubMessageSource;
import com.google.cloud.spring.pubsub.integration.outbound.PubSubMessageHandler;
import com.google.cloud.spring.pubsub.support.SubscriberFactory;
import com.google.cloud.spring.stream.binder.pubsub.config.PubSubBinderConfiguration;
import com.google.cloud.spring.stream.binder.pubsub.properties.PubSubConsumerProperties;
import com.google.cloud.spring.stream.binder.pubsub.properties.PubSubExtendedBindingProperties;
Expand Down Expand Up @@ -106,6 +109,9 @@ public class PubSubMessageChannelBinderTests {
@Mock
MessageChannel errorChannel;

@Mock
HealthTrackerRegistry healthTrackerRegistry;

ApplicationContextRunner baseContext = new ApplicationContextRunner()
.withBean(PubSubTemplate.class, () -> pubSubTemplate)
.withBean(PubSubAdmin.class, () -> pubSubAdmin)
Expand Down Expand Up @@ -175,6 +181,32 @@ public void consumerMaxFetchPropertyPropagatesToMessageSource() {
});
}

@Test
public void testCreateConsumerWithRegistry() {
SubscriberFactory subscriberFactory = mock(SubscriberFactory.class);
when(subscriberFactory.getProjectId()).thenReturn("test-project-id");
PubSubSubscriberTemplate subSubscriberTemplate = mock(PubSubSubscriberTemplate.class);
when(subSubscriberTemplate.getSubscriberFactory()).thenReturn(subscriberFactory);
when(pubSubTemplate.getPubSubSubscriberTemplate()).thenReturn(subSubscriberTemplate);

baseContext
.run(ctx -> {
PubSubMessageChannelBinder binder = ctx.getBean(PubSubMessageChannelBinder.class);
PubSubExtendedBindingProperties props = ctx.getBean("pubSubExtendedBindingProperties", PubSubExtendedBindingProperties.class);
binder.setHealthTrackerRegistry(healthTrackerRegistry);

MessageProducer messageProducer = binder
.createConsumerEndpoint(consumerDestination, "testGroup",
new ExtendedConsumerProperties<>(props.getExtendedConsumerProperties("test"))
);

assertThat(messageProducer).isInstanceOf(PubSubInboundChannelAdapter.class);
PubSubInboundChannelAdapter inboundChannelAdapter = (PubSubInboundChannelAdapter) messageProducer;
assertThat(inboundChannelAdapter.getAckMode()).isSameAs(AckMode.AUTO);
assertThat(inboundChannelAdapter.healthCheckEnabled()).isEqualTo(true);
});
}

@Test
public void testProducerAndConsumerCustomizers() {
baseContext.withUserConfiguration(PubSubBinderTestConfig.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
import java.util.Map;

import com.google.cloud.pubsub.v1.Subscriber;
import com.google.cloud.spring.pubsub.core.health.HealthTrackerRegistry;
import com.google.cloud.spring.pubsub.core.subscriber.PubSubSubscriberOperations;
import com.google.cloud.spring.pubsub.integration.AckMode;
import com.google.cloud.spring.pubsub.integration.PubSubHeaderMapper;
import com.google.cloud.spring.pubsub.support.GcpPubSubHeaders;
import com.google.cloud.spring.pubsub.support.PubSubSubscriptionUtils;
import com.google.cloud.spring.pubsub.support.converter.ConvertedBasicAcknowledgeablePubsubMessage;
import com.google.pubsub.v1.ProjectSubscriptionName;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

Expand Down Expand Up @@ -57,6 +60,10 @@ public class PubSubInboundChannelAdapter extends MessageProducerSupport {

private Class<?> payloadType = byte[].class;

private HealthTrackerRegistry healthTrackerRegistry;

private String projectId;

public PubSubInboundChannelAdapter(PubSubSubscriberOperations pubSubSubscriberOperations, String subscriptionName) {
Assert.notNull(pubSubSubscriberOperations, "Pub/Sub subscriber template can't be null.");
Assert.notNull(subscriptionName, "Pub/Sub subscription name can't be null.");
Expand All @@ -73,6 +80,16 @@ public void setAckMode(AckMode ackMode) {
this.ackMode = ackMode;
}

public void setHealthTrackerRegistry(
HealthTrackerRegistry healthTrackerRegistry) {
Assert.notNull(projectId, "HealthTrackerRegistry requires a projectId");
this.healthTrackerRegistry = healthTrackerRegistry;
}

public void setProjectId(String projectId) {
this.projectId = projectId;
}

public Class<?> getPayloadType() {
return this.payloadType;
}
Expand Down Expand Up @@ -105,8 +122,12 @@ public void setHeaderMapper(HeaderMapper<Map<String, String>> headerMapper) {
protected void doStart() {
super.doStart();

addToHealthRegistry();

this.subscriber = this.pubSubSubscriberOperations.subscribeAndConvert(
this.subscriptionName, this::consumeMessage, this.payloadType);

addListeners();
}

@Override
Expand All @@ -132,6 +153,8 @@ private void consumeMessage(ConvertedBasicAcknowledgeablePubsubMessage<?> messag
.copyHeaders(messageHeaders)
.build());

processedMessage(message.getProjectSubscriptionName());

if (this.ackMode == AckMode.AUTO_ACK || this.ackMode == AckMode.AUTO) {
message.ack();
}
Expand All @@ -148,4 +171,28 @@ private void consumeMessage(ConvertedBasicAcknowledgeablePubsubMessage<?> messag
}
}
}

private void addToHealthRegistry() {
if (healthCheckEnabled()) {
ProjectSubscriptionName projectSubscriptionName = PubSubSubscriptionUtils.toProjectSubscriptionName(subscriptionName, this.projectId);
healthTrackerRegistry.registerTracker(projectSubscriptionName);
}
}

private void addListeners() {
if (healthCheckEnabled()) {
healthTrackerRegistry.addListener(subscriber);
}
}

private void processedMessage(ProjectSubscriptionName projectSubscriptionName) {
if (healthCheckEnabled()) {
healthTrackerRegistry.processedMessage(projectSubscriptionName);
}
}

public boolean healthCheckEnabled() {
return healthTrackerRegistry != null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import com.google.cloud.pubsub.v1.stub.SubscriberStubSettings;
import com.google.cloud.spring.core.GcpProjectIdProvider;
import com.google.cloud.spring.pubsub.core.PubSubConfiguration;
import com.google.cloud.spring.pubsub.core.health.HealthTrackerRegistry;
import com.google.pubsub.v1.ProjectSubscriptionName;
import com.google.pubsub.v1.PullRequest;
import org.threeten.bp.Duration;

Expand Down Expand Up @@ -79,6 +81,8 @@ public class DefaultSubscriberFactory implements SubscriberFactory {

private RetrySettings subscriberStubRetrySettings;

private HealthTrackerRegistry healthTrackerRegistry;

private PubSubConfiguration pubSubConfiguration;

private ConcurrentMap<String, ThreadPoolTaskScheduler> threadPoolTaskSchedulerMap = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -210,6 +214,14 @@ public void setSubscriberStubRetrySettings(RetrySettings subscriberStubRetrySett
this.subscriberStubRetrySettings = subscriberStubRetrySettings;
}

/**
* Set the health tracker chain for the generated subscriptions.
* @param healthTrackerRegistry parameter for registering health trackers when creating subscriptions
*/
public void setHealthTrackerRegistry(HealthTrackerRegistry healthTrackerRegistry) {
this.healthTrackerRegistry = healthTrackerRegistry;
}

@Override
public Subscriber createSubscriber(String subscriptionName, MessageReceiver receiver) {
boolean shouldAddToHealthCheck = shouldAddToHealthCheck(subscriptionName);
Expand Down Expand Up @@ -511,4 +523,14 @@ public void setThreadPoolTaskSchedulerMap(
public void setGlobalScheduler(ThreadPoolTaskScheduler threadPoolTaskScheduler) {
this.globalScheduler = threadPoolTaskScheduler;
}

private boolean shouldAddToHealthCheck(String subscriptionName) {
if (healthTrackerRegistry == null) {
return false;
}

ProjectSubscriptionName projectSubscriptionName = PubSubSubscriptionUtils.toProjectSubscriptionName(subscriptionName, this.projectId);
return !healthTrackerRegistry.isTracked(projectSubscriptionName);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.google.cloud.pubsub.v1.Subscriber;
import com.google.cloud.spring.core.GcpProjectIdProvider;
import com.google.cloud.spring.pubsub.core.PubSubConfiguration;
import com.google.cloud.spring.pubsub.core.health.HealthTrackerRegistry;
import com.google.pubsub.v1.ProjectSubscriptionName;
import com.google.pubsub.v1.PullRequest;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -38,6 +40,10 @@
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
Expand Down Expand Up @@ -67,6 +73,10 @@ public class DefaultSubscriberFactoryTests {
private ThreadPoolTaskScheduler mockScheduler;
@Mock
private ThreadPoolTaskScheduler mockGlobalScheduler;
@Mock
private HealthTrackerRegistry healthTrackerRegistry;
@Mock
private ExecutorProvider executorProvider;

/**
* used to check exception messages and types.
Expand Down Expand Up @@ -519,4 +529,43 @@ public void testGetPullEndpoint_configurationIsNull() {

assertThat(factory.getPullEndpoint("subscription-name")).isNull();
}


@Test
public void testNewSubscriber_shouldNotAddToHealthCheck() {
ProjectSubscriptionName subscriptionName = ProjectSubscriptionName.of("angeldust", "midnight cowboy");

when(healthTrackerRegistry.isTracked(subscriptionName)).thenReturn(true);

DefaultSubscriberFactory factory = new DefaultSubscriberFactory(() -> "angeldust");
factory.setCredentialsProvider(this.credentialsProvider);
factory.setHealthTrackerRegistry(healthTrackerRegistry);

Subscriber subscriber = factory.createSubscriber("midnight cowboy", (message, consumer) -> { });
assertThat(subscriber.getSubscriptionNameString())
.isEqualTo("projects/angeldust/subscriptions/midnight cowboy");

verify(healthTrackerRegistry, times(1)).isTracked(subscriptionName);
verify(healthTrackerRegistry, times(0)).wrap(eq(subscriptionName), any());
}

@Test
public void testNewSubscriber_shouldAddToHealthCheck() {
ProjectSubscriptionName subscriptionName = ProjectSubscriptionName.of("angeldust", "midnight cowboy");

when(healthTrackerRegistry.isTracked(subscriptionName)).thenReturn(false);

DefaultSubscriberFactory factory = new DefaultSubscriberFactory(() -> "angeldust");
factory.setCredentialsProvider(this.credentialsProvider);
factory.setHealthTrackerRegistry(healthTrackerRegistry);

Subscriber subscriber = factory.createSubscriber("midnight cowboy", (message, consumer) -> { });
assertThat(subscriber.getSubscriptionNameString())
.isEqualTo("projects/angeldust/subscriptions/midnight cowboy");

verify(healthTrackerRegistry).isTracked(subscriptionName);
verify(healthTrackerRegistry).wrap(any(), any());
verify(healthTrackerRegistry).addListener(any());
}

}

0 comments on commit 9ab58d0

Please sign in to comment.