Skip to content

Commit

Permalink
feat(async-rewriter): allow cursor iteration with for-of MONGOSH-1527
Browse files Browse the repository at this point in the history
  • Loading branch information
addaleax committed Oct 4, 2023
1 parent 7467002 commit 42aad30
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 29 deletions.
18 changes: 14 additions & 4 deletions packages/async-rewriter2/src/async-writer-babel.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ describe('AsyncWriter', function () {
expect(implicitlyAsyncFn).to.have.callCount(10);
});

it('can use for loops as weird assignments', async function () {
it('can use for loops as weird assignments (sync)', async function () {
const obj = { foo: null };
implicitlyAsyncFn.resolves(obj);
await runTranspiledCode(
Expand All @@ -400,6 +400,16 @@ describe('AsyncWriter', function () {
expect(obj.foo).to.equal('bar');
});

it('can use for loops as weird assignments (async)', async function () {
const obj = { foo: null };
implicitlyAsyncFn.resolves(obj);
await runTranspiledCode(
'(async() => { for await (implicitlyAsyncFn().foo of ["foo", "bar"]); })()'
);
expect(implicitlyAsyncFn).to.have.callCount(2);
expect(obj.foo).to.equal('bar');
});

it('works with assignments to objects', async function () {
implicitlyAsyncFn.resolves({ foo: 'bar' });
const ret = runTranspiledCode(`
Expand Down Expand Up @@ -995,16 +1005,16 @@ describe('AsyncWriter', function () {
expect(() => runTranspiledCode('var db = {}; db.testx();')).to.throw(
'db.testx is not a function'
);
// (Note: The following ones would give better error messages in regular code)
// (Note: The following one would give better error messages in regular code)
expect(() =>
runTranspiledCode('var db = {}; new Promise(db.foo)')
).to.throw('Promise resolver undefined is not a function');
expect(() =>
runTranspiledCode('var db = {}; for (const a of db.foo) {}')
).to.throw(/undefined is not iterable/);
).to.throw(/db.foo is not iterable/);
expect(() =>
runTranspiledCode('var db = {}; for (const a of db[0]) {}')
).to.throw(/undefined is not iterable/);
).to.throw(/db\[0\] is not iterable/);
expect(() => runTranspiledCode('for (const a of 8) {}')).to.throw(
'8 is not iterable'
);
Expand Down
193 changes: 168 additions & 25 deletions packages/async-rewriter2/src/stages/transform-maybe-await.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ interface AsyncFunctionIdentifiers {
expressionHolder: babel.types.Identifier;
markSyntheticPromise: babel.types.Identifier;
isSyntheticPromise: babel.types.Identifier;
adaptAsyncIterableToSyncIterable: babel.types.Identifier;
syntheticPromiseSymbol: babel.types.Identifier;
syntheticAsyncIterableSymbol: babel.types.Identifier;
demangleError: babel.types.Identifier;
assertNotSyntheticPromise: babel.types.Identifier;
}
Expand All @@ -45,6 +47,7 @@ export default ({
const isGeneratedInnerFunction = asNodeKey(
Symbol('isGeneratedInnerFunction')
);
const isWrappedForOfLoop = asNodeKey(Symbol('isWrappedForOfLoop'));
const isGeneratedHelper = asNodeKey(Symbol('isGeneratedHelper'));
const isOriginalBody = asNodeKey(Symbol('isOriginalBody'));
const isAlwaysSyncFunction = asNodeKey(Symbol('isAlwaysSyncFunction'));
Expand Down Expand Up @@ -82,6 +85,47 @@ export default ({
}
`);

const adaptAsyncIterableToSyncIterableTemplate = babel.template.statement(`
function AAITSI_IDENTIFIER(original) {
const SAI_IDENTIFIER = Symbol.for("@@mongosh.syntheticAsyncIterable");
if (!original || !original[SAI_IDENTIFIER]) {
return { iterable: original, isSyntheticAsyncIterable: false };
}
const originalIterator = original[Symbol.asyncIterator]();
let next;
let returned;
return {
isSyntheticAsyncIterable: true,
iterable: {
[Symbol.iterator]() {
return this;
},
next() {
let _next = next;
next = undefined;
return _next;
},
return(value) {
returned = { value };
return {
value,
done: true
}
},
async expectNext() {
next ??= await originalIterator.next();
},
async syncReturn() {
if (returned) {
await originalIterator.return(returned.value);
}
}
}
}
}
`);

const asyncTryCatchWrapperTemplate = babel.template.expression(`
async () => {
try {
Expand Down Expand Up @@ -137,6 +181,26 @@ export default ({
}
`);

const forOfLoopTemplate = babel.template.statement(`{
const ITERABLE_INFO = AAITSI_IDENTIFIER(ORIGINAL_ITERABLE);
const ITERABLE_ISAI = (ITERABLE_INFO).isSyntheticAsyncIterable;
const ITERABLE = (ITERABLE_INFO).iterable;
try {
ITERABLE_ISAI && await (ITERABLE).expectNext();
for (const ITEM of (ORIGINAL_ITERABLE_SOURCE, ITERABLE)) {
ORIGINAL_DECLARATION;
try {
ORIGINAL_BODY;
} finally {
ITERABLE_ISAI && await (ITERABLE).expectNext();
}
}
} finally {
ITERABLE_ISAI && await (ITERABLE).syncReturn();
}
}`);

// If we encounter an error object, we fix up the error message from something
// like `("a" , foo(...)(...)) is not a function` to `a is not a function`.
// For that, we look for a) the U+FEFF markers we use to tag the original source
Expand Down Expand Up @@ -164,6 +228,34 @@ export default ({
FUNCTION_STATE_IDENTIFIER === 'async' ? SYNC_RETURN_VALUE_IDENTIFIER : null
)`);

// Transform expression `foo` into
// `('\uFEFFfoo\uFEFF', ex = foo, isSyntheticPromise(ex) ? await ex : ex)`
// The first part of the sequence expression is used to identify this
// expression for re-writing error messages, so that we can transform
// TypeError: ((intermediate value)(intermediate value) , (intermediate value)(intermediate value)(intermediate value)).findx is not a function
// back into
// TypeError: db.test.findx is not a function
// The U+FEFF markers are only used to rule out any practical chance of
// user code accidentally being recognized as the original source code.
// We limit the string length so that long expressions (e.g. those
// containing functions) are not included in full length.
function getOriginalSourceString(
{ file }: { file: babel.BabelFile },
node: babel.Node,
{ wrap = true } = {}
): babel.types.StringLiteral {
const prettyOriginalString = limitStringLength(
node.start !== undefined
? file.code.slice(node.start ?? undefined, node.end ?? undefined)
: '<unknown>',
24
);

if (!wrap) return t.stringLiteral(prettyOriginalString);

return t.stringLiteral('\ufeff' + prettyOriginalString + '\ufeff');
}

return {
pre(file: babel.BabelFile) {
this.file = file;
Expand Down Expand Up @@ -212,12 +304,18 @@ export default ({
const isSyntheticPromise =
existingIdentifiers?.isSyntheticPromise ??
path.scope.generateUidIdentifier('isp');
const adaptAsyncIterableToSyncIterable =
existingIdentifiers?.adaptAsyncIterableToSyncIterable ??
path.scope.generateUidIdentifier('aaitsi');
const assertNotSyntheticPromise =
existingIdentifiers?.assertNotSyntheticPromise ??
path.scope.generateUidIdentifier('ansp');
const syntheticPromiseSymbol =
existingIdentifiers?.syntheticPromiseSymbol ??
path.scope.generateUidIdentifier('sp');
const syntheticAsyncIterableSymbol =
existingIdentifiers?.syntheticAsyncIterableSymbol ??
path.scope.generateUidIdentifier('sai');
const demangleError =
existingIdentifiers?.demangleError ??
path.scope.generateUidIdentifier('de');
Expand All @@ -228,8 +326,10 @@ export default ({
expressionHolder,
markSyntheticPromise,
isSyntheticPromise,
adaptAsyncIterableToSyncIterable,
assertNotSyntheticPromise,
syntheticPromiseSymbol,
syntheticAsyncIterableSymbol,
demangleError,
};
path.parentPath.setData(identifierGroupKey, identifiersGroup);
Expand Down Expand Up @@ -273,6 +373,13 @@ export default ({
}),
{ [isGeneratedHelper]: true }
),
Object.assign(
adaptAsyncIterableToSyncIterableTemplate({
AAITSI_IDENTIFIER: adaptAsyncIterableToSyncIterable,
SAI_IDENTIFIER: syntheticAsyncIterableSymbol,
}),
{ [isGeneratedHelper]: true }
),
Object.assign(
isSyntheticPromiseTemplate({
ISP_IDENTIFIER: isSyntheticPromise,
Expand Down Expand Up @@ -556,22 +663,15 @@ export default ({
isSyntheticPromise,
assertNotSyntheticPromise,
} = identifierGroup;
const prettyOriginalString = limitStringLength(
path.node.start !== undefined
? this.file.code.slice(
path.node.start ?? undefined,
path.node.end ?? undefined
)
: '<unknown>',
24
);

if (!functionParent.node.async) {
// Transform expression `foo` into `assertNotSyntheticPromise(foo, 'foo')`.
path.replaceWith(
Object.assign(
assertNotSyntheticExpressionTemplate({
ORIGINAL_SOURCE: t.stringLiteral(prettyOriginalString),
ORIGINAL_SOURCE: getOriginalSourceString(this, path.node, {
wrap: false,
}),
NODE: path.node,
ANSP_IDENTIFIER: assertNotSyntheticPromise,
}),
Expand All @@ -581,24 +681,10 @@ export default ({
return;
}

// Transform expression `foo` into
// `('\uFEFFfoo\uFEFF', ex = foo, isSyntheticPromise(ex) ? await ex : ex)`
// The first part of the sequence expression is used to identify this
// expression for re-writing error messages, so that we can transform
// TypeError: ((intermediate value)(intermediate value) , (intermediate value)(intermediate value)(intermediate value)).findx is not a function
// back into
// TypeError: db.test.findx is not a function
// The U+FEFF markers are only used to rule out any practical chance of
// user code accidentally being recognized as the original source code.
// We limit the string length so that long expressions (e.g. those
// containing functions) are not included in full length.
const originalSource = t.stringLiteral(
'\ufeff' + prettyOriginalString + '\ufeff'
);
path.replaceWith(
Object.assign(
awaitSyntheticPromiseTemplate({
ORIGINAL_SOURCE: originalSource,
ORIGINAL_SOURCE: getOriginalSourceString(this, path.node),
EXPRESSION_HOLDER: expressionHolder,
ISP_IDENTIFIER: isSyntheticPromise,
NODE: path.node,
Expand Down Expand Up @@ -645,6 +731,63 @@ export default ({
);
},
},
ForOfStatement(path) {
if (path.node.await) return;

if (
path.find(
(path) => path.isFunction() || !!path.node[isGeneratedHelper]
)?.node?.[isGeneratedHelper]
) {
return path.skip();
}

if (
path.find(
(path) => path.isFunction() || !!path.node[isWrappedForOfLoop]
)?.node?.[isWrappedForOfLoop]
) {
return;
}

const identifierGroup: AsyncFunctionIdentifiers | null = path
.findParent((path) => !!path.getData(identifierGroupKey))
?.getData(identifierGroupKey);
if (!identifierGroup)
throw new Error('Missing identifier group for ForOfStatement');
const { adaptAsyncIterableToSyncIterable } = identifierGroup;
const item = path.scope.generateUidIdentifier('i');
path.replaceWith(
Object.assign(
forOfLoopTemplate({
ORIGINAL_ITERABLE: path.node.right,
ORIGINAL_ITERABLE_SOURCE: getOriginalSourceString(
this,
path.node.right
),
ORIGINAL_DECLARATION:
path.node.left.type === 'VariableDeclaration'
? t.variableDeclaration(
path.node.left.kind,
path.node.left.declarations.map((d) => ({
...d,
init: item,
}))
)
: t.expressionStatement(
t.assignmentExpression('=', path.node.left, item)
),
ORIGINAL_BODY: path.node.body,
ITERABLE_INFO: path.scope.generateUidIdentifier('ii'),
ITERABLE_ISAI: path.scope.generateUidIdentifier('isai'),
ITERABLE: path.scope.generateUidIdentifier('it'),
ITEM: item,
AAITSI_IDENTIFIER: adaptAsyncIterableToSyncIterable,
}),
{ [isWrappedForOfLoop]: true }
)
);
},
},
};
};
Expand Down
4 changes: 4 additions & 0 deletions packages/shell-api/src/abstract-cursor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ export abstract class AbstractCursor<
return result;
}

get [Symbol.for('@@mongosh.syntheticAsyncIterable')]() {
return true;
}

async *[Symbol.asyncIterator]() {
let doc;
// !== null should suffice, but some stubs in our tests return 'undefined'
Expand Down
4 changes: 4 additions & 0 deletions packages/shell-api/src/change-stream-cursor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ export default class ChangeStreamCursor extends ShellApiWithMongoClass {
return this._cursor.tryNext();
}

get [Symbol.for('@@mongosh.syntheticAsyncIterable')]() {
return true;
}

async *[Symbol.asyncIterator]() {
let doc;
while ((doc = await this.tryNext()) !== null) {
Expand Down

0 comments on commit 42aad30

Please sign in to comment.