Skip to content

Commit

Permalink
Merge pull request #656 from Iterable/evan/MOB-7154-jwt-retry-omni
Browse files Browse the repository at this point in the history
[MOB-7154] Implement JWT Retry Logic for Android SDK
  • Loading branch information
evantk91 authored Nov 15, 2023
2 parents 7ee5748 + 7a8d798 commit 4bed13d
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import androidx.test.filters.MediumTest;

import org.hamcrest.CoreMatchers;
import org.json.JSONException;
import org.json.JSONObject;
import org.junit.After;
import org.junit.Before;
Expand All @@ -25,6 +26,7 @@
import static com.iterable.iterableapi.IterableTestUtils.createIterableApi;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertNotNull;
import static junit.framework.Assert.assertNull;
import static junit.framework.Assert.assertTrue;
import static org.junit.Assert.assertThat;

Expand Down Expand Up @@ -198,6 +200,72 @@ public void onFailure(@NonNull String reason, @Nullable JSONObject data) {
assertTrue("onFailure is called", signal.await(1, TimeUnit.SECONDS));
}

@Test
public void testRetryOnInvalidJwtPayload() throws Exception {
final CountDownLatch signal = new CountDownLatch(3);
stubAnyRequestReturningStatusCode(401, "{\"msg\":\"JWT Authorization header error\",\"code\":\"InvalidJwtPayload\"}");

IterableApiRequest request = new IterableApiRequest("fake_key", "", new JSONObject(), IterableApiRequest.POST, null, null, new IterableHelper.FailureHandler() {
@Override
public void onFailure(@NonNull String reason, @Nullable JSONObject data) {
try {
if (data != null && "InvalidJwtPayload".equals(data.optString("code"))) {
final JSONObject responseData = new JSONObject("{\n" +
" \"key\":\"Success\",\n" +
" \"message\":\"Event tracked successfully.\"\n" +
"}");
stubAnyRequestReturningStatusCode(200, responseData);

new IterableRequestTask().execute(new IterableApiRequest("fake_key", "", new JSONObject(), IterableApiRequest.POST, null, new IterableHelper.SuccessHandler() {
@Override
public void onSuccess(@NonNull JSONObject successData) {
try {
assertEquals(responseData.toString(), successData.toString());
} catch (AssertionError e) {
e.printStackTrace();
} finally {
signal.countDown();
}
}
}, null));
server.takeRequest(2, TimeUnit.SECONDS);
}
} catch (JSONException e) {
e.printStackTrace();
} catch (Exception e) {
e.printStackTrace();
} finally {
signal.countDown();
}
}
});

new IterableRequestTask().execute(request);
server.takeRequest(1, TimeUnit.SECONDS);

// Await for the background tasks to complete
signal.await(5, TimeUnit.SECONDS);
}

@Test
public void testMaxRetriesOnMultipleInvalidJwtPayloads() throws Exception {
for (int i = 0; i < 5; i++) {
stubAnyRequestReturningStatusCode(401, "{\"msg\":\"JWT Authorization header error\",\"code\":\"InvalidJwtPayload\"}");
}

IterableApiRequest request = new IterableApiRequest("fake_key", "", new JSONObject(), IterableApiRequest.POST, null, null, null);
IterableRequestTask task = new IterableRequestTask();
task.execute(request);

RecordedRequest request1 = server.takeRequest(1, TimeUnit.SECONDS);
RecordedRequest request2 = server.takeRequest(5, TimeUnit.SECONDS);
RecordedRequest request3 = server.takeRequest(5, TimeUnit.SECONDS);
RecordedRequest request4 = server.takeRequest(5, TimeUnit.SECONDS);
RecordedRequest request5 = server.takeRequest(5, TimeUnit.SECONDS);
RecordedRequest request6 = server.takeRequest(5, TimeUnit.SECONDS);
assertNull("Request should be null since retries hit the max of 5", request6);
}

@Test
public void testResponseCode500() throws Exception {
for (int i = 0; i < 5; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import androidx.annotation.VisibleForTesting;

import com.iterable.iterableapi.util.Future;

import org.json.JSONException;
import org.json.JSONObject;

import java.io.UnsupportedEncodingException;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class IterableAuthManager {
private static final String TAG = "IterableAuth";
Expand All @@ -20,55 +20,53 @@ public class IterableAuthManager {
private final IterableApi api;
private final IterableAuthHandler authHandler;
private final long expiringAuthTokenRefreshPeriod;

private final long scheduledRefreshPeriod = 10000;
@VisibleForTesting
Timer timer;
private boolean hasFailedPriorAuth;
private boolean pendingAuth;
private boolean requiresAuthRefresh;

private final ExecutorService executor = Executors.newSingleThreadExecutor();

IterableAuthManager(IterableApi api, IterableAuthHandler authHandler, long expiringAuthTokenRefreshPeriod) {
this.api = api;
this.authHandler = authHandler;
this.expiringAuthTokenRefreshPeriod = expiringAuthTokenRefreshPeriod;
}

public synchronized void requestNewAuthToken(boolean hasFailedPriorAuth) {
requestNewAuthToken(hasFailedPriorAuth, null);
}

private void handleSuccessForAuthToken(String authToken, IterableHelper.SuccessHandler successCallback) {
try {
JSONObject object = new JSONObject();
object.put("newAuthToken", authToken);
successCallback.onSuccess(object);
} catch (JSONException e) {
e.printStackTrace();
}
}

public synchronized void requestNewAuthToken(
boolean hasFailedPriorAuth,
final IterableHelper.SuccessHandler successCallback) {
if (authHandler != null) {
if (!pendingAuth) {
if (!(this.hasFailedPriorAuth && hasFailedPriorAuth)) {
this.hasFailedPriorAuth = hasFailedPriorAuth;
pendingAuth = true;
Future.runAsync(new Callable<String>() {
@Override
public String call() throws Exception {
return authHandler.onAuthTokenRequested();
}
}).onSuccess(new Future.SuccessCallback<String>() {

executor.submit(new Runnable() {
@Override
public void onSuccess(String authToken) {
if (authToken != null) {
queueExpirationRefresh(authToken);
} else {
IterableLogger.w(TAG, "Auth token received as null. Calling the handler in 10 seconds");
//TODO: Make this time configurable and in sync with SDK initialization flow for auth null scenario
scheduleAuthTokenRefresh(10000);
authHandler.onTokenRegistrationFailed(new Throwable("Auth token null"));
return;
public void run() {
try {
final String authToken = authHandler.onAuthTokenRequested();
handleAuthTokenSuccess(authToken, successCallback);
} catch (final Exception e) {
handleAuthTokenFailure(e);
}
IterableApi.getInstance().setAuthToken(authToken);
pendingAuth = false;
reSyncAuth();
authHandler.onTokenRegistrationSuccessful(authToken);
}
})
.onFailure(new Future.FailureCallback() {
@Override
public void onFailure(Throwable throwable) {
IterableLogger.e(TAG, "Error while requesting Auth Token", throwable);
authHandler.onTokenRegistrationFailed(throwable);
pendingAuth = false;
reSyncAuth();
}
});
}
Expand All @@ -82,6 +80,32 @@ public void onFailure(Throwable throwable) {
}
}

private void handleAuthTokenSuccess(String authToken, IterableHelper.SuccessHandler successCallback) {
if (authToken != null) {
if (successCallback != null) {
handleSuccessForAuthToken(authToken, successCallback);
}
queueExpirationRefresh(authToken);
} else {
IterableLogger.w(TAG, "Auth token received as null. Calling the handler in 10 seconds");
//TODO: Make this time configurable and in sync with SDK initialization flow for auth null scenario
scheduleAuthTokenRefresh(scheduledRefreshPeriod);
authHandler.onTokenRegistrationFailed(new Throwable("Auth token null"));
return;
}
IterableApi.getInstance().setAuthToken(authToken);
pendingAuth = false;
reSyncAuth();
authHandler.onTokenRegistrationSuccessful(authToken);
}

private void handleAuthTokenFailure(Throwable throwable) {
IterableLogger.e(TAG, "Error while requesting Auth Token", throwable);
authHandler.onTokenRegistrationFailed(throwable);
pendingAuth = false;
reSyncAuth();
}

public void queueExpirationRefresh(String encodedJWT) {
clearRefreshTimer();
try {
Expand All @@ -96,7 +120,7 @@ public void queueExpirationRefresh(String encodedJWT) {
IterableLogger.e(TAG, "Error while parsing JWT for the expiration", e);
authHandler.onTokenRegistrationFailed(new Throwable("Auth token decode failure. Scheduling auth token refresh in 10 seconds..."));
//TODO: Sync with configured time duration once feature is available.
scheduleAuthTokenRefresh(10000);
scheduleAuthTokenRefresh(scheduledRefreshPeriod);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import android.net.Uri;
import android.os.AsyncTask;
import android.os.Handler;
import android.os.Looper;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.WorkerThread;
Expand Down Expand Up @@ -38,11 +40,13 @@ class IterableRequestTask extends AsyncTask<IterableApiRequest, Void, IterableAp
static final String ERROR_CODE_INVALID_JWT_PAYLOAD = "InvalidJwtPayload";

int retryCount = 0;
boolean shouldRetryWhileJwtInvalid = true;
IterableApiRequest iterableApiRequest;

/**
* Sends the given request to Iterable using a HttpUserConnection
* Reference - http://developer.android.com/reference/java/net/HttpURLConnection.html
*
* @param params
* @return
*/
Expand All @@ -54,6 +58,23 @@ protected IterableApiResponse doInBackground(IterableApiRequest... params) {
return executeApiRequest(iterableApiRequest);
}

public void setShouldRetryWhileJwtInvalid(boolean shouldRetryWhileJwtInvalid) {
this.shouldRetryWhileJwtInvalid = shouldRetryWhileJwtInvalid;
}

private void retryRequestWithNewAuthToken(String newAuthToken) {
IterableApiRequest request = new IterableApiRequest(
iterableApiRequest.apiKey,
iterableApiRequest.resourcePath,
iterableApiRequest.json,
iterableApiRequest.requestType,
newAuthToken,
iterableApiRequest.legacyCallback);
IterableRequestTask requestTask = new IterableRequestTask();
requestTask.setShouldRetryWhileJwtInvalid(false);
requestTask.execute(request);
}

@WorkerThread
static IterableApiResponse executeApiRequest(IterableApiRequest iterableApiRequest) {
IterableApiResponse apiResponse = null;
Expand Down Expand Up @@ -269,50 +290,75 @@ private static boolean isSensitive(String key) {
return (key.equals(IterableConstants.HEADER_API_KEY)) || key.equals(IterableConstants.HEADER_SDK_AUTHORIZATION);
}

private static final Handler handler = new Handler(Looper.getMainLooper());

@Override
protected void onPostExecute(IterableApiResponse response) {
boolean retryRequest = !response.success && response.responseCode >= 500;

if (retryRequest && retryCount <= MAX_RETRY_COUNT) {
final IterableRequestTask requestTask = new IterableRequestTask();
requestTask.setRetryCount(retryCount + 1);

long delay = 0;
if (retryCount > 2) {
delay = RETRY_DELAY_MS * retryCount;
}

Handler handler = new Handler();
handler.postDelayed(new Runnable() {
@Override
public void run() {
requestTask.execute(iterableApiRequest);
}
}, delay);
if (shouldRetry(response)) {
retryRequestWithDelay();
return;
} else if (response.success) {
IterableApi.getInstance().getAuthManager().resetFailedAuth();
if (iterableApiRequest.successCallback != null) {
iterableApiRequest.successCallback.onSuccess(response.responseJson);
}
handleSuccessResponse(response);
} else {
if (matchesErrorCode(response.responseJson, ERROR_CODE_INVALID_JWT_PAYLOAD)) {
IterableApi.getInstance().getAuthManager().requestNewAuthToken(true);
}
if (iterableApiRequest.failureCallback != null) {
iterableApiRequest.failureCallback.onFailure(response.errorMessage, response.responseJson);
}
handleErrorResponse(response);
}

if (iterableApiRequest.legacyCallback != null) {
iterableApiRequest.legacyCallback.execute(response.responseBody);
}
super.onPostExecute(response);
}

private boolean shouldRetry(IterableApiResponse response) {
return !response.success && response.responseCode >= 500 && retryCount <= MAX_RETRY_COUNT;
}

private void retryRequestWithDelay() {
final IterableRequestTask requestTask = new IterableRequestTask();
requestTask.setRetryCount(retryCount + 1);

long delay = (retryCount > 2) ? RETRY_DELAY_MS * retryCount : 0;

handler.postDelayed(new Runnable() {
@Override
public void run() {
requestTask.execute(iterableApiRequest);
}
}, delay);
}

private void handleSuccessResponse(IterableApiResponse response) {
IterableApi.getInstance().getAuthManager().resetFailedAuth();
if (iterableApiRequest.successCallback != null) {
iterableApiRequest.successCallback.onSuccess(response.responseJson);
}
}

private void handleErrorResponse(IterableApiResponse response) {
if (matchesErrorCode(response.responseJson, ERROR_CODE_INVALID_JWT_PAYLOAD) && shouldRetryWhileJwtInvalid) {
requestNewAuthTokenAndRetry(response);
}

if (iterableApiRequest.failureCallback != null) {
iterableApiRequest.failureCallback.onFailure(response.errorMessage, response.responseJson);
}
}

private void requestNewAuthTokenAndRetry(IterableApiResponse response) {
IterableApi.getInstance().getAuthManager().requestNewAuthToken(false, data -> {
try {
String newAuthToken = data.getString("newAuthToken");
retryRequestWithNewAuthToken(newAuthToken);
} catch (JSONException e) {
e.printStackTrace();
}
});
}

protected void setRetryCount(int count) {
retryCount = count;
}

}

/**
Expand Down

0 comments on commit 4bed13d

Please sign in to comment.