Skip to content
This repository has been archived by the owner on Apr 13, 2023. It is now read-only.

Fix React.createContext in SSR #2304

Merged
merged 12 commits into from
Sep 27, 2018

Conversation

emmatown
Copy link
Contributor

This fixes React.createContext so that the provider value is reset to the previous value it had after its children are walked.

Checklist:

  • If this PR is a new feature, please reference an issue where a consensus about the design was reached (not necessary for small changes)
  • Make sure all of the significant new logic is covered by tests
  • If this was a change that affects the external API used in GitHunt-React, update GitHunt-React and post a link to the PR in the discussion.

Closes #2291
Closes #2135
Closes #2139

diff --git a/src/getDataFromTree.ts b/src/getDataFromTree.ts
index 8008f09..cc731fb 100755
--- a/src/getDataFromTree.ts
+++ b/src/getDataFromTree.ts
@@ -140,8 +140,14 @@ export function walkTree(
         return;
       }

+      let isProvider = !!(element.type as any)._context;
+      let previousContextValue: any;
+      if (isProvider) {
+        previousContextValue = ((element.type as any)._context as any)._currentValue;
+      }
+
       let child;
-      if ((element.type as any)._context) {
+      if (isProvider) {
         // A provider - sets the context value before rendering children
         ((element.type as any)._context as any)._currentValue = element.props.value;
         child = element.props.children;
@@ -157,6 +163,9 @@ export function walkTree(
           walkTree(child, context, visitor);
         }
       }
+      if (isProvider) {
+        ((element.type as any)._context as any)._currentValue = previousContextValue;
+      }
     } else {
       // A basic string or dom element, just get children
       if (visitor(element, null, context) === false) {
diff --git a/test/server/server.test.tsx b/test/server/server.test.tsx
index 9b2c93c..96bd196 100644
--- a/test/server/server.test.tsx
+++ b/test/server/server.test.tsx
@@ -262,5 +262,32 @@ describe('SSR', () => {
         expect(markup).toMatch(/Tatooine/);
       });
     });
+    it('should work with React.createContext', async () => {
+      let defaultValue = 'default';
+      let Context = React.createContext(defaultValue);
+
+      let providerValue = 'provider';
+
+      expect(
+        await renderToStringWithData(
+          <React.Fragment>
+            <Context.Provider value={providerValue} />
+            <Context.Consumer>{val => val}</Context.Consumer>
+          </React.Fragment>,
+        ),
+      ).toBe('default');
+
+      expect(
+        await renderToStringWithData(
+          <Context.Provider value={providerValue}>
+            <Context.Consumer>{val => val}</Context.Consumer>
+          </Context.Provider>,
+        ),
+      ).toBe(providerValue);
+
+      expect(await renderToStringWithData(<Context.Consumer>{val => val}</Context.Consumer>)).toBe(
+        defaultValue,
+      );
+    });
   });
 });
@apollo-cla
Copy link

@mitchellhamilton: Thank you for submitting a pull request! Before we can merge it, you'll need to sign the Meteor Contributor Agreement here: https://contribute.meteor.com/

@tgriesser
Copy link
Contributor

Tried this out locally and it seems like a good start but I ran into some other issues with losing the context further down in the tree, I think I have a fix that takes care of it - will submit a little later / tomorrow.

diff --git a/src/getDataFromTree.ts b/src/getDataFromTree.ts
index cc731fb..01c8bc1 100755
--- a/src/getDataFromTree.ts
+++ b/src/getDataFromTree.ts
@@ -6,7 +6,8 @@ export interface Context {

 interface PromiseTreeArgument {
   rootElement: React.ReactNode;
-  rootContext?: Context;
+  rootContext: Context;
+  rootNewContext: Map<any, any>;
 }
 interface FetchComponent extends React.Component<any> {
   fetchData(): Promise<void>;
@@ -16,6 +17,7 @@ interface PromiseTreeResult {
   promise: Promise<any>;
   context: Context;
   instance: FetchComponent;
+  newContext: Map<any, any>;
 }

 interface PreactElement<P> {
@@ -49,12 +51,14 @@ export function walkTree(
   visitor: (
     element: React.ReactNode,
     instance: React.Component<any> | null,
+    newContextMap: Map<any, any>,
     context: Context,
     childContext?: Context,
   ) => boolean | void,
+  newContext: Map<any, any>,
 ) {
   if (Array.isArray(element)) {
-    element.forEach(item => walkTree(item, context, visitor));
+    element.forEach(item => walkTree(item, context, visitor, newContext));
     return;
   }

@@ -113,14 +117,14 @@ export function walkTree(
           childContext = Object.assign({}, context, instance.getChildContext());
         }

-        if (visitor(element, instance, context, childContext) === false) {
+        if (visitor(element, instance, newContext, context, childContext) === false) {
           return;
         }

         child = instance.render();
       } else {
         // Just a stateless functional
-        if (visitor(element, null, context) === false) {
+        if (visitor(element, null, newContext, context) === false) {
           return;
         }

@@ -129,60 +133,55 @@ export function walkTree(

       if (child) {
         if (Array.isArray(child)) {
-          child.forEach(item => walkTree(item, childContext, visitor));
+          child.forEach(item => walkTree(item, childContext, visitor, newContext));
         } else {
-          walkTree(child, childContext, visitor);
+          walkTree(child, childContext, visitor, newContext);
         }
       }
     } else if ((element.type as any)._context || (element.type as any).Consumer) {
       // A React context provider or consumer
-      if (visitor(element, null, context) === false) {
+      if (visitor(element, null, newContext, context) === false) {
         return;
       }

-      let isProvider = !!(element.type as any)._context;
-      let previousContextValue: any;
-      if (isProvider) {
-        previousContextValue = ((element.type as any)._context as any)._currentValue;
-      }
-
       let child;
-      if (isProvider) {
+      if (!!(element.type as any)._context) {
         // A provider - sets the context value before rendering children
-        ((element.type as any)._context as any)._currentValue = element.props.value;
+        // this needs to clone the map because this value should only apply to children of the provider
+        newContext = new Map(newContext.entries());
+        newContext.set(element.type, element.props.value);
         child = element.props.children;
       } else {
         // A consumer
-        child = element.props.children((element.type as any)._currentValue);
+        child = element.props.children(
+          newContext.get((element.type as any).Provider) || (element.type as any)._currentValue,
+        );
       }

       if (child) {
         if (Array.isArray(child)) {
-          child.forEach(item => walkTree(item, context, visitor));
+          child.forEach(item => walkTree(item, context, visitor, newContext));
         } else {
-          walkTree(child, context, visitor);
+          walkTree(child, context, visitor, newContext);
         }
       }
-      if (isProvider) {
-        ((element.type as any)._context as any)._currentValue = previousContextValue;
-      }
     } else {
       // A basic string or dom element, just get children
-      if (visitor(element, null, context) === false) {
+      if (visitor(element, null, newContext, context) === false) {
         return;
       }

       if (element.props && element.props.children) {
         React.Children.forEach(element.props.children, (child: any) => {
           if (child) {
-            walkTree(child, context, visitor);
+            walkTree(child, context, visitor, newContext);
           }
         });
       }
     }
   } else if (typeof element === 'string' || typeof element === 'number') {
     // Just visit these, they are leaves so we don't keep traversing.
-    visitor(element, null, context);
+    visitor(element, null, newContext, context);
   }
   // TODO: Portals?
 }
@@ -197,37 +196,49 @@ function isPromise<T>(promise: Object): promise is Promise<T> {

 function getPromisesFromTree({
   rootElement,
-  rootContext = {},
+  rootContext,
+  rootNewContext,
 }: PromiseTreeArgument): PromiseTreeResult[] {
   const promises: PromiseTreeResult[] = [];

-  walkTree(rootElement, rootContext, (_, instance, context, childContext) => {
-    if (instance && hasFetchDataFunction(instance)) {
-      const promise = instance.fetchData();
-      if (isPromise<Object>(promise)) {
-        promises.push({ promise, context: childContext || context, instance });
-        return false;
+  walkTree(
+    rootElement,
+    rootContext,
+    (_, instance, newContext, context, childContext) => {
+      if (instance && hasFetchDataFunction(instance)) {
+        const promise = instance.fetchData();
+        if (isPromise<Object>(promise)) {
+          promises.push({
+            promise,
+            context: childContext || context,
+            instance,
+            newContext,
+          });
+          return false;
+        }
       }
-    }
-  });
+    },
+    rootNewContext,
+  );

   return promises;
 }

 function getDataAndErrorsFromTree(
   rootElement: React.ReactNode,
-  rootContext: any = {},
+  rootContext: Object,
   storeError: Function,
+  rootNewContext: Map<any, any> = new Map(),
 ): Promise<any> {
-  const promises = getPromisesFromTree({ rootElement, rootContext });
+  const promises = getPromisesFromTree({ rootElement, rootContext, rootNewContext });

   if (!promises.length) {
     return Promise.resolve();
   }

-  const mappedPromises = promises.map(({ promise, context, instance }) => {
+  const mappedPromises = promises.map(({ promise, context, instance, newContext }) => {
     return promise
-      .then(_ => getDataAndErrorsFromTree(instance.render(), context, storeError))
+      .then(_ => getDataAndErrorsFromTree(instance.render(), context, storeError, newContext))
       .catch(e => storeError(e));
   });

diff --git a/test/server/server.test.tsx b/test/server/server.test.tsx
index 1082872..907d6f5 100644
--- a/test/server/server.test.tsx
+++ b/test/server/server.test.tsx
@@ -11,97 +11,97 @@ import {
   GraphQLID,
   DocumentNode,
 } from 'graphql';
-import { graphql, ApolloProvider, renderToStringWithData, ChildProps } from '../../src';
+import { graphql, ApolloProvider, renderToStringWithData, ChildProps, Query } from '../../src';
 import gql from 'graphql-tag';
 import { InMemoryCache as Cache } from 'apollo-cache-inmemory';

-describe('SSR', () => {
-  describe('`renderToStringWithData`', () => {
-    // XXX break into smaller tests
-    // XXX mock all queries
-    it('should work on a non trivial example', function() {
-      const planetMap = new Map([['Planet:1', { id: 'Planet:1', name: 'Tatooine' }]]);
+const planetMap = new Map([['Planet:1', { id: 'Planet:1', name: 'Tatooine' }]]);

-      const shipMap = new Map([
-        [
-          'Ship:2',
-          {
-            id: 'Ship:2',
-            name: 'CR90 corvette',
-            films: ['Film:4', 'Film:6', 'Film:3'],
-          },
-        ],
-        [
-          'Ship:3',
-          {
-            id: 'Ship:3',
-            name: 'Star Destroyer',
-            films: ['Film:4', 'Film:5', 'Film:6'],
-          },
-        ],
-      ]);
+const shipMap = new Map([
+  [
+    'Ship:2',
+    {
+      id: 'Ship:2',
+      name: 'CR90 corvette',
+      films: ['Film:4', 'Film:6', 'Film:3'],
+    },
+  ],
+  [
+    'Ship:3',
+    {
+      id: 'Ship:3',
+      name: 'Star Destroyer',
+      films: ['Film:4', 'Film:5', 'Film:6'],
+    },
+  ],
+]);

-      const filmMap = new Map([
-        ['Film:3', { id: 'Film:3', title: 'Revenge of the Sith' }],
-        ['Film:4', { id: 'Film:4', title: 'A New Hope' }],
-        ['Film:5', { id: 'Film:5', title: 'the Empire Strikes Back' }],
-        ['Film:6', { id: 'Film:6', title: 'Return of the Jedi' }],
-      ]);
+const filmMap = new Map([
+  ['Film:3', { id: 'Film:3', title: 'Revenge of the Sith' }],
+  ['Film:4', { id: 'Film:4', title: 'A New Hope' }],
+  ['Film:5', { id: 'Film:5', title: 'the Empire Strikes Back' }],
+  ['Film:6', { id: 'Film:6', title: 'Return of the Jedi' }],
+]);

-      const PlanetType = new GraphQLObjectType({
-        name: 'Planet',
-        fields: {
-          id: { type: GraphQLID },
-          name: { type: GraphQLString },
-        },
-      });
+const PlanetType = new GraphQLObjectType({
+  name: 'Planet',
+  fields: {
+    id: { type: GraphQLID },
+    name: { type: GraphQLString },
+  },
+});

-      const FilmType = new GraphQLObjectType({
-        name: 'Film',
-        fields: {
-          id: { type: GraphQLID },
-          title: { type: GraphQLString },
-        },
-      });
+const FilmType = new GraphQLObjectType({
+  name: 'Film',
+  fields: {
+    id: { type: GraphQLID },
+    title: { type: GraphQLString },
+  },
+});

-      const ShipType = new GraphQLObjectType({
-        name: 'Ship',
-        fields: {
-          id: { type: GraphQLID },
-          name: { type: GraphQLString },
-          films: {
-            type: new GraphQLList(FilmType),
-            resolve: ({ films }) => films.map((id: string) => filmMap.get(id)),
-          },
-        },
-      });
+const ShipType = new GraphQLObjectType({
+  name: 'Ship',
+  fields: {
+    id: { type: GraphQLID },
+    name: { type: GraphQLString },
+    films: {
+      type: new GraphQLList(FilmType),
+      resolve: ({ films }) => films.map((id: string) => filmMap.get(id)),
+    },
+  },
+});

-      const QueryType = new GraphQLObjectType({
-        name: 'Query',
-        fields: {
-          allPlanets: {
-            type: new GraphQLList(PlanetType),
-            resolve: () => Array.from(planetMap.values()),
-          },
-          allShips: {
-            type: new GraphQLList(ShipType),
-            resolve: () => Array.from(shipMap.values()),
-          },
-          ship: {
-            type: ShipType,
-            args: { id: { type: GraphQLID } },
-            resolve: (_, { id }) => shipMap.get(id),
-          },
-          film: {
-            type: FilmType,
-            args: { id: { type: GraphQLID } },
-            resolve: (_, { id }) => filmMap.get(id),
-          },
-        },
-      });
+const QueryType = new GraphQLObjectType({
+  name: 'Query',
+  fields: {
+    allPlanets: {
+      type: new GraphQLList(PlanetType),
+      resolve: () => Array.from(planetMap.values()),
+    },
+    allShips: {
+      type: new GraphQLList(ShipType),
+      resolve: () => Array.from(shipMap.values()),
+    },
+    ship: {
+      type: ShipType,
+      args: { id: { type: GraphQLID } },
+      resolve: (_, { id }) => shipMap.get(id),
+    },
+    film: {
+      type: FilmType,
+      args: { id: { type: GraphQLID } },
+      resolve: (_, { id }) => filmMap.get(id),
+    },
+  },
+});

-      const Schema = new GraphQLSchema({ query: QueryType });
+const Schema = new GraphQLSchema({ query: QueryType });

+describe('SSR', () => {
+  describe('`renderToStringWithData`', () => {
+    // XXX break into smaller tests
+    // XXX mock all queries
+    it('should work on a non trivial example', function() {
       const apolloClient = new ApolloClient({
         link: new ApolloLink(config => {
           return new Observable(observer => {
@@ -305,6 +305,49 @@ describe('SSR', () => {
           </Context.Consumer>,
         ),
       ).toBe(defaultValue);
+
+      const apolloClient = new ApolloClient({
+        link: new ApolloLink(config => {
+          return new Observable(observer => {
+            execute(Schema, print(config.query), null, null, config.variables, config.operationName)
+              .then(result => {
+                observer.next(result);
+                observer.complete();
+              })
+              .catch(e => {
+                observer.error(e);
+              });
+          });
+        }),
+        cache: new Cache(),
+      });
+
+      expect(
+        await renderToStringWithData(
+          <ApolloProvider client={apolloClient}>
+            <Context.Provider value={providerValue}>
+              <Query
+                query={gql`
+                  query ShipIds {
+                    allShips {
+                      id
+                    }
+                  }
+                `}
+              >
+                {() => (
+                  <Context.Consumer>
+                    {val => {
+                      expect(val).toBe(providerValue);
+                      return val;
+                    }}
+                  </Context.Consumer>
+                )}
+              </Query>
+            </Context.Provider>
+          </ApolloProvider>,
+        ),
+      ).toBe(providerValue);
     }
   });
 });
child = element.props.children;
} else {
// A consumer
child = element.props.children((element.type as any)._currentValue);
child = element.props.children(
newContext.get((element.type as any).Provider) || (element.type as any)._currentValue,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to be more specific with the check here:

let currentValue = (element.type as any)._currentValue;
if (childContext.subContexts.has((element.type as any).Provider)) {
  currentValue = childContext.subContexts.get((element.type as any).Provider);
}
child = element.props.children(currentValue);

In theory, someone could set the value of context to a falsy value, and that would be the value that should be respected. From the React docs:

The defaultValue argument is only used by a Consumer when it does not have a matching Provider above it in the tree. This can be helpful for testing components in isolation without wrapping them. Note: passing undefined as a Provider value does not cause Consumers to use defaultValue.

diff --git a/src/getDataFromTree.ts b/src/getDataFromTree.ts
index 01c8bc1..74eeb8d 100755
--- a/src/getDataFromTree.ts
+++ b/src/getDataFromTree.ts
@@ -148,14 +148,16 @@ export function walkTree(
       if (!!(element.type as any)._context) {
         // A provider - sets the context value before rendering children
         // this needs to clone the map because this value should only apply to children of the provider
-        newContext = new Map(newContext.entries());
+        newContext = new Map(newContext);
         newContext.set(element.type, element.props.value);
         child = element.props.children;
       } else {
         // A consumer
-        child = element.props.children(
-          newContext.get((element.type as any).Provider) || (element.type as any)._currentValue,
-        );
+        let value = (element.type as any)._currentValue;
+        if (newContext.has((element.type as any).Provider)) {
+          value = newContext.get((element.type as any).Provider);
+        }
+        child = element.props.children(value);
       }

       if (child) {
diff --git a/test/server/server.test.tsx b/test/server/server.test.tsx
index 907d6f5..8b73f9a 100644
--- a/test/server/server.test.tsx
+++ b/test/server/server.test.tsx
@@ -305,6 +305,20 @@ describe('SSR', () => {
           </Context.Consumer>,
         ),
       ).toBe(defaultValue);
+      let ContextForUndefined = React.createContext<void | string>(defaultValue);
+
+      expect(
+        await renderToStringWithData(
+          <ContextForUndefined.Provider value={undefined}>
+            <ContextForUndefined.Consumer>
+              {val => {
+                expect(val).toBeUndefined();
+                return val === undefined ? 'works' : 'broken';
+              }}
+            </ContextForUndefined.Consumer>
+          </ContextForUndefined.Provider>,
+        ),
+      ).toBe('works');

       const apolloClient = new ApolloClient({
         link: new ApolloLink(config => {
@tadeuszwojcik
Copy link

Thanks @mitchellhamilton , great work. Is there a change to merge it soonish? Or it's still blocked by something?

@emmatown
Copy link
Contributor Author

@tadeuszwojcik Just waiting on the apollo team :)

@hwillson hwillson self-assigned this Sep 26, 2018
Copy link
Member

@hwillson hwillson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this @mitchellhamilton - LGTM!

@hwillson hwillson merged commit 17704e1 into apollographql:master Sep 27, 2018
@emmatown emmatown deleted the fix-react-create-context branch September 27, 2018 11:15
@tadeuszwojcik
Copy link

Thanks @hwillson ! Any chance to publish that fix to npm soon? Could alpha/next tag? Thanks!

@hwillson
Copy link
Member

@tadeuszwojcik I'm pushing out a new release this morning, that will include this.

@tadeuszwojcik
Copy link

@hwillson that's great, thank you!

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
5 participants