Skip to content

Commit

Permalink
fix: generation of Python imports (#979)
Browse files Browse the repository at this point in the history
Closes #974

### Summary of Changes

Fix the generation of Python imports. It now also works without explicit
imports or if stubs are declared in the same module.
  • Loading branch information
lars-reimann authored Apr 3, 2024
1 parent c7d006f commit f69d836
Show file tree
Hide file tree
Showing 91 changed files with 222 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import {
isSdsPipeline,
isSdsPlaceholder,
isSdsPrefixOperation,
isSdsQualifiedImport,
isSdsReference,
isSdsSegment,
isSdsTemplateString,
Expand All @@ -49,7 +48,6 @@ import {
isSdsTemplateStringPart,
isSdsTemplateStringStart,
isSdsWildcard,
isSdsWildcardImport,
isSdsYield,
SdsArgument,
SdsAssignee,
Expand All @@ -64,17 +62,16 @@ import {
SdsParameterList,
SdsPipeline,
SdsPlaceholder,
SdsReference,
SdsSegment,
SdsStatement,
} from '../generated/ast.js';
import { isInStubFile, isStubFile } from '../helpers/fileExtensions.js';
import { isStubFile } from '../helpers/fileExtensions.js';
import { IdManager } from '../helpers/idManager.js';
import {
getAbstractResults,
getArguments,
getAssignees,
getImportedDeclarations,
getImports,
getModuleMembers,
getParameters,
getPlaceholderByName,
Expand Down Expand Up @@ -905,11 +902,9 @@ export class SafeDsPythonGenerator {
}
} else if (isSdsReference(expression)) {
const declaration = expression.target.ref!;
const referenceImport =
this.getExternalReferenceNeededImport(expression, declaration) ||
this.getInternalReferenceNeededImport(expression, declaration);
const referenceImport = this.createImportDataForReference(expression);
frame.addImport(referenceImport);
return traceToNode(expression)(referenceImport?.alias || this.getPythonNameOrDefault(declaration));
return traceToNode(expression)(referenceImport?.alias ?? this.getPythonNameOrDefault(declaration));
}
/* c8 ignore next 2 */
throw new Error(`Unknown expression type: ${expression.$type}`);
Expand Down Expand Up @@ -1127,68 +1122,38 @@ export class SafeDsPythonGenerator {
}${this.generateExpression(argument.value, frame)}`;
}

private getExternalReferenceNeededImport(
expression: SdsExpression | undefined,
declaration: SdsDeclaration | undefined,
): ImportData | undefined {
if (!expression || !declaration) {
private createImportDataForReference(reference: SdsReference): ImportData | undefined {
const target = reference.target.ref;
if (!target) {
/* c8 ignore next 2 */
return undefined;
}

// Root Node is always a module.
const currentModule = <SdsModule>AstUtils.findRootNode(expression);
const targetModule = <SdsModule>AstUtils.findRootNode(declaration);
for (const value of getImports(currentModule)) {
// Verify same package
if (value.package !== targetModule.name) {
continue;
}
if (isSdsQualifiedImport(value)) {
const importedDeclarations = getImportedDeclarations(value);
for (const importedDeclaration of importedDeclarations) {
if (declaration === importedDeclaration.declaration?.ref) {
if (importedDeclaration.alias !== undefined) {
return {
importPath: this.getPythonModuleOrDefault(targetModule),
declarationName: importedDeclaration.declaration?.ref?.name,
alias: importedDeclaration.alias.alias,
};
} else {
return {
importPath: this.getPythonModuleOrDefault(targetModule),
declarationName: importedDeclaration.declaration?.ref?.name,
};
}
}
}
}
if (isSdsWildcardImport(value)) {
return {
importPath: this.getPythonModuleOrDefault(targetModule),
declarationName: declaration.name,
};
const sourceModule = <SdsModule>AstUtils.findRootNode(reference);
const targetModule = <SdsModule>AstUtils.findRootNode(target);

// Compute import path
let importPath: string | undefined = undefined;
if (isSdsPipeline(target) || isSdsSegment(target)) {
if (sourceModule !== targetModule) {
importPath = `${this.getPythonModuleOrDefault(targetModule)}.${this.formatGeneratedFileName(
this.getModuleFileBaseName(targetModule),
)}`;
}
} else if (isSdsModule(target.$container)) {
importPath = this.getPythonModuleOrDefault(targetModule);
}
return undefined;
}

private getInternalReferenceNeededImport(
expression: SdsExpression,
declaration: SdsDeclaration,
): ImportData | undefined {
// Root Node is always a module.
const currentModule = <SdsModule>AstUtils.findRootNode(expression);
const targetModule = <SdsModule>AstUtils.findRootNode(declaration);
if (currentModule !== targetModule && !isInStubFile(targetModule)) {
if (importPath) {
const refText = reference.target.$refText;
return {
importPath: `${this.getPythonModuleOrDefault(targetModule)}.${this.formatGeneratedFileName(
this.getModuleFileBaseName(targetModule),
)}`,
declarationName: this.getPythonNameOrDefault(declaration),
importPath,
declarationName: this.getPythonNameOrDefault(target),
alias: refText === target.name ? undefined : refText,
};
} else {
return undefined;
}
return undefined;
}

private getModuleFileBaseName(module: SdsModule): string {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.parameterWithPythonName import f1, f2

# Segments ---------------------------------------------------------------------

def test(param1, param_2, param_3=0):
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.pipelineWithPythonName import f

# Pipelines --------------------------------------------------------------------

def test_pipeline():
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.segmentWithPythonName import f

# Segments ---------------------------------------------------------------------

def test_segment():
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.twoPipelines import f

# Pipelines --------------------------------------------------------------------

def test1():
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.twoSegments import f

# Segments ---------------------------------------------------------------------

def test1(a, b=0):
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.blockLambdaResult import g, h

# Segments ---------------------------------------------------------------------

def f1(l):
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.blockLambda import f1, f2, f3, g, g2

# Pipelines --------------------------------------------------------------------

def test():
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Imports ----------------------------------------------------------------------

from tests.generator.call import f, g, h, i, j, k
from typing import Any, Callable, TypeVar

# Type variables ---------------------------------------------------------------
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.constant import f

# Pipelines --------------------------------------------------------------------

def test():
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.enumVariantCall import f, MyEnum

# Pipelines --------------------------------------------------------------------

def test():
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

from tests.generator.expressionLambda import f

# Pipelines --------------------------------------------------------------------

def test():
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Imports ----------------------------------------------------------------------

from tests.generator.indexedAccess import f
from typing import Any, TypeVar

# Type variables ---------------------------------------------------------------
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Imports ----------------------------------------------------------------------

from tests.generator.infixOperation import f, g, h, i
from typing import TypeVar

# Type variables ---------------------------------------------------------------
Expand Down
Loading

0 comments on commit f69d836

Please sign in to comment.