Skip to content

Commit

Permalink
feat(credential-provider-imds): support static stability (#3402)
Browse files Browse the repository at this point in the history
* feat(credential-provider-imds): support static stability

* fix(credential-provider-imds): remove static stability for container provider

* feat(credential-provider-imds): add jitter to static stable refresh interval

Co-authored-by: Trivikram Kamat <[email protected]>
  • Loading branch information
AllanZhengYP and trivikr authored Mar 9, 2022
1 parent 813ff8a commit a4beeba
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { httpRequest } from "./remoteProvider/httpRequest";
import { fromImdsCredentials, ImdsCredentials } from "./remoteProvider/ImdsCredentials";

const mockHttpRequest = <any>httpRequest;
jest.mock("./remoteProvider/httpRequest", () => ({ httpRequest: jest.fn() }));
jest.mock("./remoteProvider/httpRequest");

const relativeUri = process.env[ENV_CMDS_RELATIVE_URI];
const fullUri = process.env[ENV_CMDS_FULL_URI];
Expand Down
62 changes: 18 additions & 44 deletions packages/credential-provider-imds/src/fromInstanceMetadata.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import { fromImdsCredentials, isImdsCredentials } from "./remoteProvider/ImdsCre
import { providerConfigFromInit } from "./remoteProvider/RemoteProviderInit";
import { retry } from "./remoteProvider/retry";
import { getInstanceMetadataEndpoint } from "./utils/getInstanceMetadataEndpoint";
import { staticStabilityProvider } from "./utils/staticStabilityProvider";

jest.mock("./remoteProvider/httpRequest");
jest.mock("./remoteProvider/ImdsCredentials");
jest.mock("./remoteProvider/retry");
jest.mock("./remoteProvider/RemoteProviderInit");
jest.mock("./utils/getInstanceMetadataEndpoint");
jest.mock("./utils/staticStabilityProvider");

describe("fromInstanceMetadata", () => {
const hostname = "127.0.0.1";
Expand Down Expand Up @@ -39,11 +41,12 @@ describe("fromInstanceMetadata", () => {
},
};

const ONE_HOUR_IN_FUTURE = new Date(Date.now() + 60 * 60 * 1000);
const mockImdsCreds = Object.freeze({
AccessKeyId: "foo",
SecretAccessKey: "bar",
Token: "baz",
Expiration: new Date().toISOString(),
Expiration: ONE_HOUR_IN_FUTURE.toISOString(),
});

const mockCreds = Object.freeze({
Expand All @@ -54,6 +57,7 @@ describe("fromInstanceMetadata", () => {
});

beforeEach(() => {
(staticStabilityProvider as jest.Mock).mockImplementation((input) => input);
(getInstanceMetadataEndpoint as jest.Mock).mockResolvedValue({ hostname });
(isImdsCredentials as unknown as jest.Mock).mockReturnValue(true);
(providerConfigFromInit as jest.Mock).mockReturnValue({
Expand Down Expand Up @@ -192,6 +196,19 @@ describe("fromInstanceMetadata", () => {
await expect(fromInstanceMetadata()()).rejects.toEqual(tokenError);
});

it("should call staticStabilityProvider with the credential loader", async () => {
(httpRequest as jest.Mock)
.mockResolvedValueOnce(mockToken)
.mockResolvedValueOnce(mockProfile)
.mockResolvedValueOnce(JSON.stringify(mockImdsCreds));

(retry as jest.Mock).mockImplementation((fn: any) => fn());
(fromImdsCredentials as jest.Mock).mockReturnValue(mockCreds);

await fromInstanceMetadata()();
expect(staticStabilityProvider as jest.Mock).toBeCalledTimes(1);
});

describe("disables fetching of token", () => {
beforeEach(() => {
(retry as jest.Mock).mockImplementation((fn: any) => fn());
Expand Down Expand Up @@ -268,47 +285,4 @@ describe("fromInstanceMetadata", () => {
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
});

describe("re-enables fetching of token", () => {
const error401 = Object.assign(new Error("error"), { statusCode: 401 });

beforeEach(() => {
const tokenError = new Error("TimeoutError");

(httpRequest as jest.Mock)
.mockRejectedValueOnce(tokenError)
.mockResolvedValueOnce(mockProfile)
.mockResolvedValueOnce(JSON.stringify(mockImdsCreds));

(retry as jest.Mock).mockImplementation((fn: any) => fn());
(fromImdsCredentials as jest.Mock).mockReturnValue(mockCreds);
});

it("when profile error with 401", async () => {
(httpRequest as jest.Mock)
.mockRejectedValueOnce(error401)
.mockResolvedValueOnce(mockToken)
.mockResolvedValueOnce(mockProfile)
.mockResolvedValueOnce(JSON.stringify(mockImdsCreds));

const fromInstanceMetadataFunc = fromInstanceMetadata();
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
await expect(fromInstanceMetadataFunc()).rejects.toEqual(error401);
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
});

it("when creds error with 401", async () => {
(httpRequest as jest.Mock)
.mockResolvedValueOnce(mockProfile)
.mockRejectedValueOnce(error401)
.mockResolvedValueOnce(mockToken)
.mockResolvedValueOnce(mockProfile)
.mockResolvedValueOnce(JSON.stringify(mockImdsCreds));

const fromInstanceMetadataFunc = fromInstanceMetadata();
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
await expect(fromInstanceMetadataFunc()).rejects.toEqual(error401);
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
});
});
});
9 changes: 7 additions & 2 deletions packages/credential-provider-imds/src/fromInstanceMetadata.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { CredentialsProviderError } from "@aws-sdk/property-provider";
import { CredentialProvider, Credentials } from "@aws-sdk/types";
import { Credentials, Provider } from "@aws-sdk/types";
import { RequestOptions } from "http";

import { httpRequest } from "./remoteProvider/httpRequest";
import { fromImdsCredentials, isImdsCredentials } from "./remoteProvider/ImdsCredentials";
import { providerConfigFromInit, RemoteProviderInit } from "./remoteProvider/RemoteProviderInit";
import { retry } from "./remoteProvider/retry";
import { InstanceMetadataCredentials } from "./types";
import { getInstanceMetadataEndpoint } from "./utils/getInstanceMetadataEndpoint";
import { staticStabilityProvider } from "./utils/staticStabilityProvider";

const IMDS_PATH = "/latest/meta-data/iam/security-credentials/";
const IMDS_TOKEN_PATH = "/latest/api/token";
Expand All @@ -15,7 +17,10 @@ const IMDS_TOKEN_PATH = "/latest/api/token";
* Creates a credential provider that will source credentials from the EC2
* Instance Metadata Service
*/
export const fromInstanceMetadata = (init: RemoteProviderInit = {}): CredentialProvider => {
export const fromInstanceMetadata = (init: RemoteProviderInit = {}): Provider<InstanceMetadataCredentials> =>
staticStabilityProvider(getInstanceImdsProvider(init));

const getInstanceImdsProvider = (init: RemoteProviderInit) => {
// when set to true, metadata service will not fetch token
let disableFetchToken = false;
const { timeout, maxRetries } = providerConfigFromInit(init);
Expand Down
1 change: 1 addition & 0 deletions packages/credential-provider-imds/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export * from "./fromContainerMetadata";
export * from "./fromInstanceMetadata";
export * from "./remoteProvider/RemoteProviderInit";
export * from "./types";
export { httpRequest } from "./remoteProvider/httpRequest";
export { getInstanceMetadataEndpoint } from "./utils/getInstanceMetadataEndpoint";
5 changes: 5 additions & 0 deletions packages/credential-provider-imds/src/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import { Credentials } from "@aws-sdk/types";

export interface InstanceMetadataCredentials extends Credentials {
readonly originalExpiration?: Date;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { getExtendedInstanceMetadataCredentials } from "./getExtendedInstanceMetadataCredentials";

describe("getExtendedInstanceMetadataCredentials()", () => {
let nowMock: jest.SpyInstance;
const staticSecret = {
accessKeyId: "key",
secretAccessKey: "secret",
};

beforeEach(() => {
jest.spyOn(global.console, "warn").mockImplementation(() => {});
jest.spyOn(global.Math, "random");
nowMock = jest.spyOn(Date, "now").mockReturnValueOnce(new Date("2022-02-22T00:00:00Z").getTime());
});

afterEach(() => {
nowMock.mockRestore();
});

it("should extend the expiration random time(~15 mins) from now", () => {
const anyDate: Date = "any date" as unknown as Date;
(Math.random as jest.Mock).mockReturnValue(0.5);
expect(getExtendedInstanceMetadataCredentials({ ...staticSecret, expiration: anyDate })).toEqual({
...staticSecret,
originalExpiration: anyDate,
expiration: new Date("2022-02-22T00:17:30Z"),
});
expect(Math.random).toBeCalledTimes(1);
});

it("should print warning message when extending the credentials", () => {
const anyDate: Date = "any date" as unknown as Date;
getExtendedInstanceMetadataCredentials({ ...staticSecret, expiration: anyDate });
// TODO: fill the doc link
expect(console.warn).toBeCalledWith(expect.stringContaining("Attempting credential expiration extension"));
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { InstanceMetadataCredentials } from "../types";

const STATIC_STABILITY_REFRESH_INTERVAL_SECONDS = 15 * 60;
const STATIC_STABILITY_REFRESH_INTERVAL_JITTER_WINDOW_SECONDS = 5 * 60;
// TODO
const STATIC_STABILITY_DOC_URL = "https://docs.aws.amazon.com/sdkref/latest/guide/feature-static-credentials.html";

export const getExtendedInstanceMetadataCredentials = (
credentials: InstanceMetadataCredentials
): InstanceMetadataCredentials => {
const refreshInterval =
STATIC_STABILITY_REFRESH_INTERVAL_SECONDS +
Math.floor(Math.random() * STATIC_STABILITY_REFRESH_INTERVAL_JITTER_WINDOW_SECONDS);
const newExpiration = new Date(Date.now() + refreshInterval * 1000);
// ToDo: Call warn function on logger from configuration
console.warn(
"Attempting credential expiration extension due to a credential service availability issue. A refresh of these " +
"credentials will be attempted after ${new Date(newExpiration)}.\nFor more information, please visit: " +
STATIC_STABILITY_DOC_URL
);
const originalExpiration = credentials.originalExpiration ?? credentials.expiration;
return {
...credentials,
...(originalExpiration ? { originalExpiration } : {}),
expiration: newExpiration,
};
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import { getExtendedInstanceMetadataCredentials } from "./getExtendedInstanceMetadataCredentials";
import { staticStabilityProvider } from "./staticStabilityProvider";

jest.mock("./getExtendedInstanceMetadataCredentials");

describe("staticStabilityProvider", () => {
const ONE_HOUR_IN_FUTURE = new Date(Date.now() + 60 * 60 * 1000);
const mockCreds = {
accessKeyId: "key",
secretAccessKey: "secret",
sessionToken: "settion",
expiration: ONE_HOUR_IN_FUTURE,
};

beforeEach(() => {
(getExtendedInstanceMetadataCredentials as jest.Mock).mockImplementation(
(() => {
let extensionCount = 0;
return (input) => {
extensionCount++;
return {
...input,
expiration: `Extending expiration count: ${extensionCount}`,
};
};
})()
);
jest.spyOn(global.console, "warn").mockImplementation(() => {});
});

afterEach(() => {
jest.resetAllMocks();
});

it("should refresh credentials if provider is functional", async () => {
const provider = jest.fn();
const stableProvider = staticStabilityProvider(provider);
const repeat = 3;
for (let i = 0; i < repeat; i++) {
const newCreds = { ...mockCreds, accessKeyId: String(i + 1) };
provider.mockReset().mockResolvedValue(newCreds);
expect(await stableProvider()).toEqual(newCreds);
}
});

it("should throw if cannot load credentials at 1st load", async () => {
const provider = jest.fn().mockRejectedValue("Error");
try {
await staticStabilityProvider(provider)();
fail("This provider should throw");
} catch (e) {
expect(getExtendedInstanceMetadataCredentials).not.toBeCalled();
expect(provider).toBeCalledTimes(1);
expect(e).toEqual("Error");
}
});

it("should extend expired credentials if refresh fails", async () => {
const provider = jest.fn().mockResolvedValueOnce(mockCreds).mockRejectedValue("Error");
const stableProvider = staticStabilityProvider(provider);
expect(await stableProvider()).toEqual(mockCreds);
const repeat = 3;
for (let i = 0; i < repeat; i++) {
const newCreds = await stableProvider();
expect(newCreds).toMatchObject({ ...mockCreds, expiration: expect.stringContaining(`count: ${i + 1}`) });
expect(console.warn).toHaveBeenLastCalledWith(
expect.stringContaining("Credential renew failed:"),
expect.anything()
);
}
expect(getExtendedInstanceMetadataCredentials).toBeCalledTimes(repeat);
expect(console.warn).toBeCalledTimes(repeat);
});

it("should extend expired credentials if loaded expired credentials", async () => {
const ONE_HOUR_AGO = new Date(Date.now() - 60 * 60 * 1000);
const provider = jest.fn().mockResolvedValue({ ...mockCreds, expiration: ONE_HOUR_AGO });
const stableProvider = staticStabilityProvider(provider);
const repeat = 3;
for (let i = 0; i < repeat; i++) {
const newCreds = await stableProvider();
expect(newCreds).toMatchObject({ ...mockCreds, expiration: expect.stringContaining(`count: ${i + 1}`) });
}
expect(getExtendedInstanceMetadataCredentials).toBeCalledTimes(repeat);
expect(console.warn).not.toBeCalled();
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { Credentials, Provider } from "@aws-sdk/types";

import { InstanceMetadataCredentials } from "../types";
import { getExtendedInstanceMetadataCredentials } from "./getExtendedInstanceMetadataCredentials";

/**
* IMDS credential supports static stability feature. When used, the expiration
* of recently issued credentials is extended. The server side allows using
* the recently expired credentials. This mitigates impact when clients using
* refreshable credentials are unable to retrieve updates.
*
* @param provider Credential provider
* @returns A credential provider that supports static stability
*/
export const staticStabilityProvider = (
provider: Provider<InstanceMetadataCredentials>
): Provider<InstanceMetadataCredentials> => {
let pastCredentials: InstanceMetadataCredentials;
return async () => {
let credentials: InstanceMetadataCredentials;
try {
credentials = await provider();
if (credentials.expiration && credentials.expiration.getTime() < Date.now()) {
credentials = getExtendedInstanceMetadataCredentials(credentials);
}
} catch (e) {
if (pastCredentials) {
// ToDo: Call warn function on logger from configuration
console.warn("Credential renew failed: ", e);
credentials = getExtendedInstanceMetadataCredentials(pastCredentials);
} else {
throw e;
}
}
pastCredentials = credentials;
return credentials;
};
};

0 comments on commit a4beeba

Please sign in to comment.