diff --git a/packages/async-rewriter2/src/async-writer-babel.spec.ts b/packages/async-rewriter2/src/async-writer-babel.spec.ts index fc5c4847f..4a3527662 100644 --- a/packages/async-rewriter2/src/async-writer-babel.spec.ts +++ b/packages/async-rewriter2/src/async-writer-babel.spec.ts @@ -57,6 +57,24 @@ describe('AsyncWriter', function () { [Symbol.for('@@mongosh.uncatchable')]: true, }); }, + regularIterable: function* () { + yield* [1, 2, 3]; + }, + regularAsyncIterable: async function* () { + await Promise.resolve(); + yield* [1, 2, 3]; + }, + implicitlyAsyncIterable: function () { + return Object.assign( + (async function* () { + await Promise.resolve(); + yield* [1, 2, 3]; + })(), + { + [Symbol.for('@@mongosh.syntheticAsyncIterable')]: true, + } + ); + }, }); runTranspiledCode = (code: string, context?: any) => { const transpiled = asyncWriter.process(code); @@ -543,6 +561,44 @@ describe('AsyncWriter', function () { expect(await ret).to.equal('bar'); }); + context('for-of', function () { + it('can iterate over implicit iterables', async function () { + expect( + await runTranspiledCode(`(function() { + let sum = 0; + for (const value of implicitlyAsyncIterable()) + sum += value; + return sum; + })()`) + ).to.equal(6); + }); + + it('can iterate over implicit iterables in async functions', async function () { + expect( + await runTranspiledCode(`(async function() { + let sum = 0; + for (const value of implicitlyAsyncIterable()) + sum += value; + return sum; + })()`) + ).to.equal(6); + }); + + it('can implicitly yield* inside of async generator functions', async function () { + expect( + await runTranspiledCode(`(async function() { + const gen = (async function*() { + yield* implicitlyAsyncIterable(); + })(); + let sum = 0; + for await (const value of gen) + sum += value; + return sum; + })()`) + ).to.equal(6); + }); + }); + context('invalid implicit awaits', function () { beforeEach(function () { runUntranspiledCode(asyncWriter.runtimeSupportCode()); @@ -594,6 +650,45 @@ describe('AsyncWriter', function () { '[ASYNC-10012] Result of expression "compareFn(...args)" cannot be used in this context' ); }); + + context('for-of', function () { + it('cannot implicitly yield* inside of generator functions', function () { + expect(() => + runTranspiledCode(`(function() { + const gen = (function*() { + yield* implicitlyAsyncIterable(); + })(); + for (const value of gen) return value; + })()`) + ).to.throw( + '[ASYNC-10013] Result of expression "implicitlyAsyncIterable()" cannot be iterated in this context' + ); + }); + + it('cannot implicitly for-of inside of generator functions', function () { + expect(() => + runTranspiledCode(`(function() { + const gen = (function*() { + for (const item of implicitlyAsyncIterable()) yield item; + })(); + for (const value of gen) return value; + })()`) + ).to.throw( + '[ASYNC-10013] Result of expression "implicitlyAsyncIterable()" cannot be iterated in this context' + ); + }); + + it('cannot implicitly for-of await inside of class constructors', function () { + expect( + () => + runTranspiledCode(`class A { + constructor() { for (this.foo of implicitlyAsyncIterable()) {} } + }; new A()`).value + ).to.throw( + '[ASYNC-10013] Result of expression "implicitlyAsyncIterable()" cannot be iterated in this context' + ); + }); + }); }); }); @@ -1040,7 +1135,7 @@ describe('AsyncWriter', function () { runTranspiledCode( 'globalThis.abcdefghijklmnopqrstuvwxyz = {}; abcdefghijklmnopqrstuvwxyz()' ) - ).to.throw('abcdefghijklm ... uvwxyz is not a function'); + ).to.throw('abcdefghijklmn ... uvwxyz is not a function'); }); }); diff --git a/packages/async-rewriter2/src/error-codes.ts b/packages/async-rewriter2/src/error-codes.ts index 218b647e4..c3abd6714 100644 --- a/packages/async-rewriter2/src/error-codes.ts +++ b/packages/async-rewriter2/src/error-codes.ts @@ -24,6 +24,27 @@ enum AsyncRewriterErrors { * **Solution: Do not use calls directly in such functions. If necessary, place these calls in an inner 'async' function.** */ SyntheticPromiseInAlwaysSyncContext = 'ASYNC-10012', + /** + * Signals the iteration of a Mongosh API object in a place where it is not supported. + * This occurs inside of constructors and (non-async) generator functions. + * + * Examples causing error: + * ```javascript + * class SomeClass { + * constructor() { + * for (const item of db.coll.find()) { ... } + * } + * } + * + * function*() { + * for (const item of db.coll.find()) yield item; + * yield* db.coll.find(); + * } + * ``` + * + * **Solution: Do not use calls directly in such functions. If necessary, place these calls in an inner 'async' function.** + */ + SyntheticAsyncIterableInAlwaysSyncContext = 'ASYNC-10013', } export { AsyncRewriterErrors }; diff --git a/packages/async-rewriter2/src/stages/transform-maybe-await.ts b/packages/async-rewriter2/src/stages/transform-maybe-await.ts index 3ede65f21..b03c5d182 100644 --- a/packages/async-rewriter2/src/stages/transform-maybe-await.ts +++ b/packages/async-rewriter2/src/stages/transform-maybe-await.ts @@ -56,8 +56,9 @@ export default ({ // of helpers which are available inside the function. const identifierGroupKey = '@@mongosh.identifierGroup'; - const syntheticPromiseSymbolTemplate = babel.template.statement(` + const syntheticPromiseSymbolTemplate = babel.template.statements(` const SP_IDENTIFIER = Symbol.for("@@mongosh.syntheticPromise"); + const SAI_IDENTIFIER = Symbol.for("@@mongosh.syntheticAsyncIterable"); `); const markSyntheticPromiseTemplate = babel.template.statement(` @@ -75,19 +76,23 @@ export default ({ `); const assertNotSyntheticPromiseTemplate = babel.template.statement(` - function ANSP_IDENTIFIER(p, s) { + function ANSP_IDENTIFIER(p, s, i = false) { if (p && p[SP_IDENTIFIER]) { throw new CUSTOM_ERROR_BUILDER( 'Result of expression "' + s + '" cannot be used in this context', 'SyntheticPromiseInAlwaysSyncContext'); } + if (i && p && p[SAI_IDENTIFIER]) { + throw new CUSTOM_ERROR_BUILDER( + 'Result of expression "' + s + '" cannot be iterated in this context', + 'SyntheticAsyncIterableInAlwaysSyncContext'); + } return p; } `); const adaptAsyncIterableToSyncIterableTemplate = babel.template.statement(` function AAITSI_IDENTIFIER(original) { - const SAI_IDENTIFIER = Symbol.for("@@mongosh.syntheticAsyncIterable"); if (!original || !original[SAI_IDENTIFIER]) { return { iterable: original, isSyntheticAsyncIterable: false }; } @@ -169,10 +174,6 @@ export default ({ } ); - const assertNotSyntheticExpressionTemplate = babel.template.expression(` - ANSP_IDENTIFIER(NODE, ORIGINAL_SOURCE) - `); - const rethrowTemplate = babel.template.statement(` try { ORIGINAL_CODE; @@ -248,7 +249,7 @@ export default ({ node.start !== undefined ? file.code.slice(node.start ?? undefined, node.end ?? undefined) : '', - 24 + 25 ); if (!wrap) return t.stringLiteral(prettyOriginalString); @@ -349,11 +350,11 @@ export default ({ const commonHelpers = existingIdentifiers ? [] : [ - Object.assign( - syntheticPromiseSymbolTemplate({ - SP_IDENTIFIER: syntheticPromiseSymbol, - }), - { [isGeneratedHelper]: true } + ...syntheticPromiseSymbolTemplate({ + SP_IDENTIFIER: syntheticPromiseSymbol, + SAI_IDENTIFIER: syntheticAsyncIterableSymbol, + }).map((helper) => + Object.assign(helper, { [isGeneratedHelper]: true }) ), Object.assign( expressionHolderVariableTemplate({ @@ -400,6 +401,7 @@ export default ({ assertNotSyntheticPromiseTemplate({ ANSP_IDENTIFIER: assertNotSyntheticPromise, SP_IDENTIFIER: syntheticPromiseSymbol, + SAI_IDENTIFIER: syntheticAsyncIterableSymbol, CUSTOM_ERROR_BUILDER: (this as any).opts.customErrorBuilder ?? t.identifier('Error'), }), @@ -666,17 +668,23 @@ export default ({ if (!functionParent.node.async) { // Transform expression `foo` into `assertNotSyntheticPromise(foo, 'foo')`. + const args = [ + path.node, + getOriginalSourceString(this, path.node, { + wrap: false, + }), + ]; + if ( + (path.parent.type === 'ForOfStatement' && + path.node === path.parent.right) || + (path.parent.type === 'YieldExpression' && path.parent.delegate) + ) { + args.push(t.booleanLiteral(true)); + } path.replaceWith( - Object.assign( - assertNotSyntheticExpressionTemplate({ - ORIGINAL_SOURCE: getOriginalSourceString(this, path.node, { - wrap: false, - }), - NODE: path.node, - ANSP_IDENTIFIER: assertNotSyntheticPromise, - }), - { [isGeneratedHelper]: true } - ) + Object.assign(t.callExpression(assertNotSyntheticPromise, args), { + [isGeneratedHelper]: true, + }) ); return; } @@ -732,7 +740,7 @@ export default ({ }, }, ForOfStatement(path) { - if (path.node.await) return; + if (path.node.await || !path.getFunctionParent()?.node.async) return; if ( path.find( diff --git a/packages/e2e-tests/test/e2e.spec.ts b/packages/e2e-tests/test/e2e.spec.ts index 48c477324..589c9616d 100644 --- a/packages/e2e-tests/test/e2e.spec.ts +++ b/packages/e2e-tests/test/e2e.spec.ts @@ -794,6 +794,19 @@ describe('e2e', function () { const result = await shell.executeLine('out[1]'); expect(result).to.include('i: 1'); }); + + it('works with for-of iteration', async function () { + await shell.executeLine('out = [];'); + const before = + await shell.executeLine(`for (const doc of db.coll.find()) { + print('enter for-of'); + out.push(db.coll.findOne({_id:doc._id})); + print('leave for-of'); + } print('after');`); + expect(before).to.match(/(enter for-of\r?\nleave for-of\r?\n){3}after/); + const result = await shell.executeLine('out[1]'); + expect(result).to.include('i: 1'); + }); }); });