Skip to content

Commit

Permalink
Merge pull request #331 from chromaui/312-api-channel-proxy
Browse files Browse the repository at this point in the history
Relay client-side fetch requests to the server using the Storybook channel API
  • Loading branch information
ghengeveld authored Sep 6, 2024
2 parents b781bdf + 67b13f2 commit f0c702f
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 72 deletions.
10 changes: 5 additions & 5 deletions src/Panel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import {
IS_OFFLINE,
IS_OUTDATED,
LOCAL_BUILD_PROGRESS,
PANEL_ID,
REMOVE_ADDON,
TELEMETRY,
} from "./constants";
Expand All @@ -28,9 +27,10 @@ import { ControlsProvider } from "./screens/VisualTests/ControlsContext";
import { RunBuildProvider } from "./screens/VisualTests/RunBuildContext";
import { VisualTests } from "./screens/VisualTests/VisualTests";
import { GitInfoPayload, LocalBuildProgress, UpdateStatusFunction } from "./types";
import { client, Provider, useAccessToken } from "./utils/graphQLClient";
import { createClient, GraphQLClientProvider, useAccessToken } from "./utils/graphQLClient";
import { TelemetryProvider } from "./utils/TelemetryContext";
import { useBuildEvents } from "./utils/useBuildEvents";
import { useChannelFetch } from "./utils/useChannelFetch";
import { useProjectId } from "./utils/useProjectId";
import { clearSessionState, useSessionState } from "./utils/useSessionState";
import { useSharedState } from "./utils/useSharedState";
Expand Down Expand Up @@ -93,8 +93,9 @@ export const Panel = ({ active, api }: PanelProps) => {
const trackEvent = useCallback((data: any) => emit(TELEMETRY, data), [emit]);
const { isRunning, startBuild, stopBuild } = useBuildEvents({ localBuildProgress, accessToken });

const fetch = useChannelFetch();
const withProviders = (children: React.ReactNode) => (
<Provider key={PANEL_ID} value={client}>
<GraphQLClientProvider value={createClient({ fetch })}>
<TelemetryProvider value={trackEvent}>
<AuthProvider value={{ accessToken, setAccessToken }}>
<UninstallProvider
Expand All @@ -111,7 +112,7 @@ export const Panel = ({ active, api }: PanelProps) => {
</UninstallProvider>
</AuthProvider>
</TelemetryProvider>
</Provider>
</GraphQLClientProvider>
);

if (!active) {
Expand All @@ -134,7 +135,6 @@ export const Panel = ({ active, api }: PanelProps) => {
if (!accessToken) {
return withProviders(
<Authentication
key={PANEL_ID}
setAccessToken={setAccessToken}
setCreatedProjectId={setCreatedProjectId}
hasProjectId={!!projectId}
Expand Down
4 changes: 4 additions & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ export const ENABLE_FILTER = `${ADDON_ID}/enableFilter`;
export const REMOVE_ADDON = `${ADDON_ID}/removeAddon`;
export const PARAM_KEY = "chromatic";

export const FETCH_ABORTED = `${ADDON_ID}/ChannelFetch/aborted`;
export const FETCH_REQUEST = `${ADDON_ID}ChannelFetch/request`;
export const FETCH_RESPONSE = `${ADDON_ID}ChannelFetch/response`;

export const CONFIG_OVERRIDES = {
// Local changes should never be auto-accepted
autoAcceptChanges: false,
Expand Down
4 changes: 4 additions & 0 deletions src/preset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import {
LocalBuildProgress,
ProjectInfoPayload,
} from "./types";
import { ChannelFetch } from "./utils/ChannelFetch";
import { SharedState } from "./utils/SharedState";
import { updateChromaticConfig } from "./utils/updateChromaticConfig";

Expand Down Expand Up @@ -160,6 +161,9 @@ const watchConfigFile = async (
async function serverChannel(channel: Channel, options: Options & { configFile?: string }) {
const { configFile, presets } = options;

// Handle relayed fetch requests from the client
ChannelFetch.subscribe(ADDON_ID, channel);

// Lazy load these APIs since we don't need them right away
const apiPromise = presets.apply<any>("experimental_serverAPI");
const corePromise = presets.apply("core");
Expand Down
96 changes: 96 additions & 0 deletions src/utils/ChannelFetch.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import { beforeEach, describe, expect, it, vi } from "vitest";

import { FETCH_ABORTED, FETCH_REQUEST, FETCH_RESPONSE } from "../constants";
import { ChannelFetch } from "./ChannelFetch";
import { MockChannel } from "./MockChannel";

const resolveAfter = (ms: number, value: any) =>
new Promise((resolve) => setTimeout(resolve, ms, value));

const rejectAfter = (ms: number, reason: any) =>
new Promise((_, reject) => setTimeout(reject, ms, reason));

describe("ChannelFetch", () => {
let channel: MockChannel;

beforeEach(() => {
channel = new MockChannel();
});

it("should handle fetch requests", async () => {
const fetch = vi.fn(() => resolveAfter(100, { headers: [], text: async () => "data" }));
ChannelFetch.subscribe("req", channel, fetch as any);

channel.emit(FETCH_REQUEST, {
requestId: "req",
input: "https://example.com",
init: { headers: { foo: "bar" } },
});

await vi.waitFor(() => {
expect(fetch).toHaveBeenCalledWith("https://example.com", {
headers: { foo: "bar" },
signal: expect.any(AbortSignal),
});
});
});

it("should send fetch responses", async () => {
const fetch = vi.fn(() => resolveAfter(100, { headers: [], text: async () => "data" }));
const instance = ChannelFetch.subscribe("res", channel, fetch as any);

const promise = new Promise<void>((resolve) => {
channel.on(FETCH_RESPONSE, ({ response, error }) => {
expect(response.body).toBe("data");
expect(error).toBeUndefined();
resolve();
});
});

channel.emit(FETCH_REQUEST, { requestId: "res", input: "https://example.com" });
await vi.waitFor(() => {
expect(instance.abortControllers.size).toBe(1);
});

await promise;

expect(instance.abortControllers.size).toBe(0);
});

it("should send fetch error responses", async () => {
const fetch = vi.fn(() => rejectAfter(100, new Error("oops")));
const instance = ChannelFetch.subscribe("err", channel, fetch as any);

const promise = new Promise<void>((resolve) => {
channel.on(FETCH_RESPONSE, ({ response, error }) => {
expect(response).toBeUndefined();
expect(error).toMatch(/oops/);
resolve();
});
});

channel.emit(FETCH_REQUEST, { requestId: "err", input: "https://example.com" });
await vi.waitFor(() => {
expect(instance.abortControllers.size).toBe(1);
});

await promise;
expect(instance.abortControllers.size).toBe(0);
});

it("should abort fetch requests", async () => {
const fetch = vi.fn((input, init) => new Promise<Response>(() => {}));
const instance = ChannelFetch.subscribe("abort", channel, fetch);

channel.emit(FETCH_REQUEST, { requestId: "abort", input: "https://example.com" });
await vi.waitFor(() => {
expect(instance.abortControllers.size).toBe(1);
});

channel.emit(FETCH_ABORTED, { requestId: "abort" });
await vi.waitFor(() => {
expect(fetch.mock.lastCall?.[1].signal.aborted).toBe(true);
expect(instance.abortControllers.size).toBe(0);
});
});
});
47 changes: 47 additions & 0 deletions src/utils/ChannelFetch.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import type { Channel } from "@storybook/channels";

import { FETCH_ABORTED, FETCH_REQUEST, FETCH_RESPONSE } from "../constants";

type ChannelLike = Pick<Channel, "emit" | "on" | "off">;

const instances = new Map<string, ChannelFetch>();

export class ChannelFetch {
channel: ChannelLike;

abortControllers: Map<string, AbortController>;

constructor(channel: ChannelLike, _fetch = fetch) {
this.channel = channel;
this.abortControllers = new Map<string, AbortController>();

this.channel.on(FETCH_ABORTED, ({ requestId }) => {
this.abortControllers.get(requestId)?.abort();
this.abortControllers.delete(requestId);
});

this.channel.on(FETCH_REQUEST, async ({ requestId, input, init }) => {
const controller = new AbortController();
this.abortControllers.set(requestId, controller);

try {
const res = await _fetch(input as RequestInfo, { ...init, signal: controller.signal });
const body = await res.text();
const headers = Array.from(res.headers as any);
const response = { body, headers, status: res.status, statusText: res.statusText };
this.channel.emit(FETCH_RESPONSE, { requestId, response });
} catch (err) {
const error = err instanceof Error ? err.message : String(err);
this.channel.emit(FETCH_RESPONSE, { requestId, error });
} finally {
this.abortControllers.delete(requestId);
}
});
}

static subscribe(key: string, channel: ChannelLike, _fetch = fetch) {
const instance = instances.get(key) || new ChannelFetch(channel, _fetch);
if (!instances.has(key)) instances.set(key, instance);
return instance;
}
}
16 changes: 16 additions & 0 deletions src/utils/MockChannel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
export class MockChannel {
private listeners: Record<string, ((...args: any[]) => void)[]> = {};

on(event: string, listener: (...args: any[]) => void) {
this.listeners[event] = [...(this.listeners[event] ?? []), listener];
}

off(event: string, listener: (...args: any[]) => void) {
this.listeners[event] = (this.listeners[event] ?? []).filter((l) => l !== listener);
}

emit(event: string, ...args: any[]) {
// setTimeout is used to simulate the asynchronous nature of the real channel
(this.listeners[event] || []).forEach((listener) => setTimeout(() => listener(...args)));
}
}
18 changes: 1 addition & 17 deletions src/utils/SharedState.test.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,8 @@
import { beforeEach, describe, expect, it } from "vitest";

import { MockChannel } from "./MockChannel";
import { SharedState } from "./SharedState";

class MockChannel {
private listeners: Record<string, ((...args: any[]) => void)[]> = {};

on(event: string, listener: (...args: any[]) => void) {
this.listeners[event] = [...(this.listeners[event] ?? []), listener];
}

off(event: string, listener: (...args: any[]) => void) {
this.listeners[event] = (this.listeners[event] ?? []).filter((l) => l !== listener);
}

emit(event: string, ...args: any[]) {
// setTimeout is used to simulate the asynchronous nature of the real channel
(this.listeners[event] || []).forEach((listener) => setTimeout(() => listener(...args)));
}
}

const tick = () => new Promise((resolve) => setTimeout(resolve, 0));

describe("SharedState", () => {
Expand Down
104 changes: 54 additions & 50 deletions src/utils/graphQLClient.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { useAddonState } from "@storybook/manager-api";
import { authExchange } from "@urql/exchange-auth";
import React from "react";
import { Client, fetchExchange, mapExchange, Provider } from "urql";
import { Client, ClientOptions, fetchExchange, mapExchange, Provider } from "urql";
import { v4 as uuid } from "uuid";

import { ACCESS_TOKEN_KEY, ADDON_ID, CHROMATIC_API_URL } from "../constants";

export { Provider };

let currentToken: string | null;
let currentTokenExpiration: number | null;
const setCurrentToken = (token: string | null) => {
Expand Down Expand Up @@ -56,56 +54,62 @@ export const getFetchOptions = (token?: string) => ({
},
});

export const client = new Client({
url: CHROMATIC_API_URL,
exchanges: [
// We don't use cacheExchange, because it would inadvertently share data between stories.
mapExchange({
onResult(result) {
// Not all queries contain the `viewer` field, in which case it will be `undefined`.
// When we do retrieve the field but the token is invalid, it will be `null`.
if (result.data?.viewer === null) setCurrentToken(null);
},
}),
authExchange(async (utils) => {
return {
addAuthToOperation(operation) {
if (!currentToken) return operation;
return utils.appendHeaders(operation, { Authorization: `Bearer ${currentToken}` });
export const createClient = (options?: Partial<ClientOptions>) =>
new Client({
url: CHROMATIC_API_URL,
exchanges: [
// We don't use cacheExchange, because it would inadvertently share data between stories.
mapExchange({
onResult(result) {
// Not all queries contain the `viewer` field, in which case it will be `undefined`.
// When we do retrieve the field but the token is invalid, it will be `null`.
if (result.data?.viewer === null) setCurrentToken(null);
},
}),
authExchange(async (utils) => {
return {
addAuthToOperation(operation) {
if (!currentToken) return operation;
return utils.appendHeaders(operation, { Authorization: `Bearer ${currentToken}` });
},

// Determine if the current error is an authentication error.
didAuthError: (error) =>
error.response.status === 401 ||
error.graphQLErrors.some((e) => e.message.includes("Must login")),
// Determine if the current error is an authentication error.
didAuthError: (error) =>
error.response.status === 401 ||
error.graphQLErrors.some((e) => e.message.includes("Must login")),

// If didAuthError returns true, clear the token. Ideally we should refresh the token here.
// The operation will be retried automatically.
async refreshAuth() {
setCurrentToken(null);
},
// If didAuthError returns true, clear the token. Ideally we should refresh the token here.
// The operation will be retried automatically.
async refreshAuth() {
setCurrentToken(null);
},

// Prevent making a request if we know the token is missing, invalid or expired.
// This handler is called repeatedly so we avoid parsing the token each time.
willAuthError() {
if (!currentToken) return true;
try {
if (!currentTokenExpiration) {
const { exp } = JSON.parse(atob(currentToken.split(".")[1]));
currentTokenExpiration = exp;
// Prevent making a request if we know the token is missing, invalid or expired.
// This handler is called repeatedly so we avoid parsing the token each time.
willAuthError() {
if (!currentToken) return true;
try {
if (!currentTokenExpiration) {
const { exp } = JSON.parse(atob(currentToken.split(".")[1]));
currentTokenExpiration = exp;
}
return Date.now() / 1000 > (currentTokenExpiration || 0);
} catch (e) {
return true;
}
return Date.now() / 1000 > (currentTokenExpiration || 0);
} catch (e) {
return true;
}
},
};
}),
fetchExchange,
],
fetchOptions: getFetchOptions(), // Auth header (token) is handled by authExchange
});
},
};
}),
fetchExchange,
],
fetchOptions: getFetchOptions(), // Auth header (token) is handled by authExchange
...options,
});

export const GraphQLClientProvider = ({ children }: { children: React.ReactNode }) => {
return <Provider value={client}>{children}</Provider>;
};
export const GraphQLClientProvider = ({
children,
value = createClient(),
}: {
children: React.ReactNode;
value?: Client;
}) => <Provider value={value}>{children}</Provider>;
Loading

0 comments on commit f0c702f

Please sign in to comment.