diff --git a/src/library_analyzer/processing/api/__init__.py b/src/library_analyzer/processing/api/__init__.py index 7393d895..6e9c9cd4 100644 --- a/src/library_analyzer/processing/api/__init__.py +++ b/src/library_analyzer/processing/api/__init__.py @@ -1,6 +1,23 @@ from ._get_api import get_api from ._get_instance_attributes import get_instance_attributes from ._get_parameter_list import get_parameter_list +from ._infer_purity import ( + DefinitelyImpure, + DefinitelyPure, + ImpurityIndicator, + MaybeImpure, + OpenMode, + PurityInformation, + PurityResult, + calc_function_id, + determine_open_mode, + determine_purity, + extract_impurity_reasons, + generate_purity_information, + get_function_defs, + get_purity_result_str, + infer_purity, +) from ._package_metadata import ( distribution, distribution_version, diff --git a/src/library_analyzer/processing/api/_infer_purity.py b/src/library_analyzer/processing/api/_infer_purity.py new file mode 100644 index 00000000..eee7c65b --- /dev/null +++ b/src/library_analyzer/processing/api/_infer_purity.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass +from enum import Enum, auto +from typing import Optional + +import astroid +from library_analyzer.processing.api.model import ( + BuiltInFunction, + Call, + ConcreteImpurityIndicator, + FileRead, + FileWrite, + ImpurityCertainty, + ImpurityIndicator, + Reference, + StringLiteral, + SystemInteraction, + VariableRead, + VariableWrite, +) +from library_analyzer.utils import ASTWalker + +BUILTIN_FUNCTIONS = { + "open": BuiltInFunction(Reference("open"), ConcreteImpurityIndicator(), ImpurityCertainty.DEFINITELY_IMPURE), + # TODO: how to replace the ... with the correct type? + "print": BuiltInFunction(Reference("print"), SystemInteraction(), ImpurityCertainty.DEFINITELY_IMPURE), + "read": BuiltInFunction(Reference("read"), ConcreteImpurityIndicator(), ImpurityCertainty.DEFINITELY_IMPURE), + "write": BuiltInFunction(Reference("write"), ConcreteImpurityIndicator(), ImpurityCertainty.DEFINITELY_IMPURE), + "readline": BuiltInFunction( + Reference("readline"), ConcreteImpurityIndicator(), ImpurityCertainty.DEFINITELY_IMPURE + ), + "readlines": BuiltInFunction( + Reference("readlines"), ConcreteImpurityIndicator(), ImpurityCertainty.DEFINITELY_IMPURE + ), + "writelines": BuiltInFunction( + Reference("writelines"), ConcreteImpurityIndicator(), ImpurityCertainty.DEFINITELY_IMPURE + ), + "close": BuiltInFunction(Reference("close"), ConcreteImpurityIndicator(), ImpurityCertainty.DEFINITELY_PURE), +} + + +@dataclass +class FunctionID: + module: str + name: str + line: int + col: int + + def __str__(self) -> str: + return f"{self.module}.{self.name}.{self.line}.{self.col}" + + +class PurityResult(ABC): + def __init__(self) -> None: + self.reasons: list[ImpurityIndicator] = [] + + +@dataclass +class DefinitelyPure(PurityResult): + reasons = [] + + +@dataclass +class MaybeImpure(PurityResult): + reasons: list[ImpurityIndicator] + + # def __hash__(self) -> int: + # return hash(tuple(self.reasons)) + + +@dataclass +class DefinitelyImpure(PurityResult): + reasons: list[ImpurityIndicator] + + # def __hash__(self) -> int: + # return hash(tuple(self.reasons)) + + +@dataclass +class PurityInformation: + id: FunctionID + # purity: PurityResult + reasons: list[ImpurityIndicator] + + # def __hash__(self) -> int: + # return hash((self.id, self.reasons)) + + # def __eq__(self, other: object) -> bool: + # if not isinstance(other, PurityInformation): + # return NotImplemented + # return self.id == other.id and self.reasons == other.reasons + + +class PurityHandler: + def __init__(self) -> None: + self.purity_reason: list[ImpurityIndicator] = [] + + def append_reason(self, reason: list[ImpurityIndicator]) -> None: + for r in reason: + self.purity_reason.append(r) + + def enter_functiondef(self, node: astroid.FunctionDef) -> None: + # print(f"Enter functionDef node: {node.as_string()}") + # Handle the FunctionDef node here + pass # Are we analyzing function defs within function defs? Yes, we are. + + def enter_assign(self, node: astroid.Assign) -> None: + # print(f"Entering Assign node {node}") + # Handle the Assign node here + if isinstance(node.value, astroid.Call): + pass + if isinstance(node.value, astroid.Const): + self.append_reason([VariableWrite(Reference(node.as_string()))]) + else: # default case + self.append_reason([VariableWrite(Reference(node.as_string()))]) + # TODO: Assign node needs further analysis to determine if it is pure or impure + + def enter_assignattr(self, node: astroid.AssignAttr) -> None: + # print(f"Entering AssignAttr node {node.as_string()}") + # Handle the AssignAtr node here + self.append_reason([VariableWrite(Reference(node.as_string()))]) + # TODO: AssignAttr node needs further analysis to determine if it is pure or impure + + def enter_call(self, node: astroid.Call) -> None: + # print(f"Entering Call node {node.as_string()}") + # Handle the Call node here + if isinstance(node.func, astroid.Attribute): + pass + elif isinstance(node.func, astroid.Name): + if node.func.name in BUILTIN_FUNCTIONS: + value = node.args[0] + if isinstance(value, astroid.Name): + impurity_indicator = check_builtin_function(node, node.func.name, value.name, True) + self.append_reason(impurity_indicator) + else: + impurity_indicator = check_builtin_function(node, node.func.name, value.value) + self.append_reason(impurity_indicator) + + self.append_reason([Call(Reference(node.as_string()))]) + # TODO: Call node needs further analysis to determine if it is pure or impure + + def enter_attribute(self, node: astroid.Attribute) -> None: + # print(f"Entering Attribute node {node.as_string()}") + # Handle the Attribute node here + if isinstance(node.expr, astroid.Name): + if node.attrname in BUILTIN_FUNCTIONS: + impurity_indicator = check_builtin_function(node, node.attrname) + self.append_reason(impurity_indicator) + else: + self.append_reason([Call(Reference(node.as_string()))]) + + def enter_arguments(self, node: astroid.Arguments) -> None: + # print(f"Entering Arguments node {node.as_string()}") + # Handle the Arguments node here + pass + + def enter_expr(self, node: astroid.Expr) -> None: + # print(f"Entering Expr node {node.as_string()}") + # print(node.value) + # Handle the Expr node here + pass + + def enter_name(self, node: astroid.Name) -> None: + # print(f"Entering Name node {node.as_string()}") + # Handle the Name node here + pass + + def enter_const(self, node: astroid.Const) -> None: + # print(f"Entering Const node {node.as_string()}") + # Handle the Const node here + pass + + def enter_assignname(self, node: astroid.AssignName) -> None: + # print(f"Entering AssignName node {node.as_string()}") + # Handle the AssignName node here + pass + + def enter_with(self, node: astroid.With) -> None: + # print(f"Entering With node {node.as_string()}") + # Handle the With node here + pass + + +class OpenMode(Enum): + READ = auto() + WRITE = auto() + READ_WRITE = auto() + + +def determine_open_mode(args: list[str]) -> OpenMode: + write_mode = {"w", "wb", "a", "ab", "x", "xb", "wt", "at", "xt"} + read_mode = {"r", "rb", "rt"} + read_and_write_mode = { + "r+", + "rb+", + "w+", + "wb+", + "a+", + "ab+", + "x+", + "xb+", + "r+t", + "rb+t", + "w+t", + "wb+t", + "a+t", + "ab+t", + "x+t", + "xb+t", + "r+b", + "rb+b", + "w+b", + "wb+b", + "a+b", + "ab+b", + "x+b", + "xb+b", + } + if len(args) == 1: + return OpenMode.READ + + mode = args[1] + if isinstance(mode, astroid.Const): + mode = mode.value + + if mode in read_mode: + return OpenMode.READ + if mode in write_mode: + return OpenMode.WRITE + if mode in read_and_write_mode: + return OpenMode.READ_WRITE + + raise ValueError(f"{mode} is not a valid mode for the open function") + + +def check_builtin_function( + node: astroid.NodeNG, key: str, value: Optional[str] = None, is_var: bool = False +) -> list[ImpurityIndicator]: + if is_var: + if key == "open": + open_mode = determine_open_mode(node.args) + if open_mode == OpenMode.WRITE: + return [FileWrite(Reference(value))] + + if open_mode == OpenMode.READ: + return [FileRead(Reference(value))] + + if open_mode == OpenMode.READ_WRITE: + return [FileRead(Reference(value)), FileWrite(Reference(value))] + + elif isinstance(value, str): + if key == "open": + open_mode = determine_open_mode(node.args) + if open_mode == OpenMode.WRITE: # write mode + return [FileWrite(StringLiteral(value))] + + if open_mode == OpenMode.READ: # read mode + return [FileRead(StringLiteral(value))] + + if open_mode == OpenMode.READ_WRITE: # read and write mode + return [FileRead(StringLiteral(value)), FileWrite(StringLiteral(value))] + + raise TypeError(f"Unknown builtin function {key}") + + if key in ("read", "readline", "readlines"): + return [VariableRead(Reference(node.as_string()))] + if key in ("write", "writelines"): + return [VariableWrite(Reference(node.as_string()))] + + raise TypeError(f"Unknown builtin function {key}") + + +def infer_purity(code: str) -> list[PurityInformation]: + purity_handler: PurityHandler = PurityHandler() + walker = ASTWalker(purity_handler) + functions = get_function_defs(code) + result = [] + for function in functions: + # print(function) + # print(f"Analyse {function.name}:") + walker.walk(function) + purity_result = determine_purity(purity_handler.purity_reason) + # print(f"Result: {purity_result.__class__.__name__}") + # if not isinstance(purity_result, DefinitelyPure): + # print(f"Reasons: {purity_result.reasons}") + # print(f"Function {function.name} is done. \n") + result.append(generate_purity_information(function, purity_result)) + purity_handler.purity_reason = [] + return result + + +def determine_purity(indicators: list[ImpurityIndicator]) -> PurityResult: + if len(indicators) == 0: + return DefinitelyPure() + if any(indicator.certainty == ImpurityCertainty.DEFINITELY_IMPURE for indicator in indicators): + return DefinitelyImpure(reasons=indicators) + + return MaybeImpure(reasons=indicators) + + # print(f"Maybe check {(any(purity_reason.is_reason_for_impurity() for purity_reason in purity_reasons))}") + # if any(reason.is_reason_for_impurity() for reason in purity_reasons): + # # print(f"Definitely check {any(isinstance(reason, Call) for reason in purity_reasons)}") + # result = MaybeImpure(reasons=purity_reasons) + # if any(isinstance(reason, Call) for reason in purity_reasons): + # return DefinitelyImpure(reasons=purity_reasons) + # return result + # else: + # return DefinitelyPure() + + +def get_function_defs(code: str) -> list[astroid.FunctionDef]: + try: + module = astroid.parse(code) + except SyntaxError as error: + raise ValueError("Invalid Python code") from error + + function_defs = list[astroid.FunctionDef]() + for node in module.body: + if isinstance(node, astroid.FunctionDef): + function_defs.append(node) + return function_defs + # TODO: This function should read from a python file (module) and return a list of FunctionDefs + + +def extract_impurity_reasons(purity: PurityResult) -> list[ImpurityIndicator]: + if isinstance(purity, DefinitelyPure): + return [] + return purity.reasons + + +def generate_purity_information(function: astroid.FunctionDef, purity_result: PurityResult) -> PurityInformation: + function_id = calc_function_id(function) + reasons = extract_impurity_reasons(purity_result) + purity_info = PurityInformation(function_id, reasons) + return purity_info + + +def calc_function_id(node: astroid.NodeNG) -> FunctionID: + if not isinstance(node, astroid.FunctionDef): + raise TypeError("Node is not a function") + module = node.root().name + # module = "_infer_purity.py" + # if module.endswith(".py"): + # module = module[:-3] + name = node.name + line = node.position.lineno + col = node.position.col_offset + return FunctionID(module, name, line, col) + + +# this function is only for visualization purposes +def get_purity_result_str(indicators: list[ImpurityIndicator]) -> str: + if len(indicators) == 0: + return "Definitely Pure" + if any(indicator.certainty == ImpurityCertainty.DEFINITELY_IMPURE for indicator in indicators): + return "Definitely Impure" + + return "Maybe Impure" diff --git a/src/library_analyzer/processing/api/model/__init__.py b/src/library_analyzer/processing/api/model/__init__.py index 6597d39c..fc5cb43e 100644 --- a/src/library_analyzer/processing/api/model/__init__.py +++ b/src/library_analyzer/processing/api/model/__init__.py @@ -16,6 +16,26 @@ ParameterDocumentation, ) from ._parameters import Parameter, ParameterAssignment +from ._purity import ( + AttributeAccess, + BuiltInFunction, + Call, + ConcreteImpurityIndicator, + Expression, + FileRead, + FileWrite, + GlobalAccess, + ImpurityCertainty, + ImpurityIndicator, + InstanceAccess, + ParameterAccess, + Reference, + StringLiteral, + SystemInteraction, + UnknownCallTarget, + VariableRead, + VariableWrite, +) from ._types import ( AbstractType, BoundaryType, diff --git a/src/library_analyzer/processing/api/model/_purity.py b/src/library_analyzer/processing/api/model/_purity.py new file mode 100644 index 00000000..d01d26fc --- /dev/null +++ b/src/library_analyzer/processing/api/model/_purity.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum, auto + +import astroid + + +# Type of access +class Expression(astroid.NodeNG, ABC): + # @abstractmethod + # def __hash__(self) -> int: + # pass + ... + + +@dataclass +class AttributeAccess(Expression): + """Class for class attribute access""" + + name: str + + # def __hash__(self) -> int: + # return hash(self.name) + + +@dataclass +class GlobalAccess(Expression): + """Class for global variable access""" + + name: str + module: str = "None" + + # def __hash__(self) -> int: + # return hash(self.name) + + +@dataclass +class ParameterAccess(Expression): + """Class for function parameter access""" + + name: str + function: str + + # def __hash__(self) -> int: + # return hash(self.name) + + +@dataclass +class InstanceAccess(Expression): + """Class for field access of an instance attribute (receiver.target)""" + + receiver: Expression + target: Expression + + # def __hash__(self) -> int: + # return hash((self.receiver, self.target)) + + +@dataclass +class StringLiteral(Expression): + value: str + + # def __hash__(self) -> int: + # return hash(self.value) + + +@dataclass +class Reference(Expression): + name: str + + # def __hash__(self) -> int: + # return hash(self.name) + + +class ImpurityCertainty(Enum): + DEFINITELY_PURE = auto() + MAYBE_IMPURE = auto() + DEFINITELY_IMPURE = auto() + + +# Reasons for impurity +class ImpurityIndicator(ABC): + certainty: ImpurityCertainty + + # @abstractmethod + # def __hash__(self) -> int: + # pass + + @abstractmethod + def is_side_effect(self) -> bool: + pass + + +@dataclass +class ConcreteImpurityIndicator(ImpurityIndicator): + # def __hash__(self) -> int: + # return hash(self.certainty) + + def is_side_effect(self) -> bool: + return False + + +@dataclass +class VariableRead(ImpurityIndicator): + expression: Expression + certainty = ImpurityCertainty.MAYBE_IMPURE + + # def __hash__(self) -> int: + # return hash(self.expression) + + def is_side_effect(self) -> bool: + return False + + +@dataclass +class VariableWrite(ImpurityIndicator): + expression: Expression + certainty = ImpurityCertainty.MAYBE_IMPURE + + # def __hash__(self) -> int: + # return hash(self.expression) + + def is_side_effect(self) -> bool: + return True + + +@dataclass +class FileRead(ImpurityIndicator): + source: Expression + certainty = ImpurityCertainty.DEFINITELY_IMPURE + + # def __hash__(self) -> int: + # return hash(self.source) + + def is_side_effect(self) -> bool: + return False + + +@dataclass +class FileWrite(ImpurityIndicator): + source: Expression + certainty = ImpurityCertainty.DEFINITELY_IMPURE + + # def __hash__(self) -> int: + # return hash(self.source) + + def is_side_effect(self) -> bool: + return True + + +@dataclass +class UnknownCallTarget(ImpurityIndicator): + expression: Expression + certainty = ImpurityCertainty.DEFINITELY_IMPURE + + # def __hash__(self) -> int: + # return hash(self.expression) + + def is_side_effect(self) -> bool: + return True # TODO: improve this to make analysis more precise + + +@dataclass +class Call(ImpurityIndicator): + expression: Expression + certainty = ImpurityCertainty.DEFINITELY_IMPURE + + # def __hash__(self) -> int: + # return hash(self.expression) + + def is_side_effect(self) -> bool: + return True # TODO: improve this to make analysis more precise + + +@dataclass +class SystemInteraction(ImpurityIndicator): + certainty = ImpurityCertainty.DEFINITELY_IMPURE + + # def __hash__(self) -> int: + # return hash("SystemInteraction") + + def is_side_effect(self) -> bool: + return True + + +@dataclass +class BuiltInFunction(ImpurityIndicator): + """Class for built-in functions""" + + expression: Expression + indicator: ImpurityIndicator # this should be a list to handle multiple reasons + certainty: ImpurityCertainty + + # def __hash__(self) -> int: + # return hash(self.indicator) + + def is_side_effect(self) -> bool: + return False diff --git a/src/library_analyzer/utils/_ASTWalker.py b/src/library_analyzer/utils/_ASTWalker.py index 61ddb452..689162b3 100644 --- a/src/library_analyzer/utils/_ASTWalker.py +++ b/src/library_analyzer/utils/_ASTWalker.py @@ -29,6 +29,7 @@ def __walk(self, node: astroid.NodeNG, visited_nodes: set[astroid.NodeNG]) -> No if node in visited_nodes: raise AssertionError("Node visited twice") visited_nodes.add(node) + # print(f"{node}\n") self.__enter(node) for child_node in node.get_children(): @@ -52,12 +53,8 @@ def __get_callbacks(self, node: astroid.NodeNG) -> _EnterAndLeaveFunctions: if methods is None: handler = self._handler class_name = klass.__name__.lower() - enter_method = getattr( - handler, f"enter_{class_name}", getattr(handler, "enter_default", None) - ) - leave_method = getattr( - handler, f"leave_{class_name}", getattr(handler, "leave_default", None) - ) + enter_method = getattr(handler, f"enter_{class_name}", getattr(handler, "enter_default", None)) + leave_method = getattr(handler, f"leave_{class_name}", getattr(handler, "leave_default", None)) self._cache[klass] = (enter_method, leave_method) else: enter_method, leave_method = methods diff --git a/tests/library_analyzer/processing/api/test_infer_purity.py b/tests/library_analyzer/processing/api/test_infer_purity.py new file mode 100644 index 00000000..fda605ca --- /dev/null +++ b/tests/library_analyzer/processing/api/test_infer_purity.py @@ -0,0 +1,475 @@ +import astroid +import pytest +from library_analyzer.processing.api import ( + DefinitelyImpure, + DefinitelyPure, + ImpurityIndicator, + MaybeImpure, + OpenMode, + PurityInformation, + PurityResult, + calc_function_id, + determine_open_mode, + determine_purity, + extract_impurity_reasons, + infer_purity, +) +from library_analyzer.processing.api.model import ( + AttributeAccess, + Call, + FileRead, + FileWrite, + Reference, + StringLiteral, + VariableRead, + VariableWrite, +) + + +@pytest.mark.parametrize( + "code, expected", + [ + ( + """ + def fun1(a): + h(a) + return a + """, + ".fun1.2.0", + ), + ( + """ + + def fun2(a): + a = 1 + return a + """, + ".fun2.3.0", + ), + ( + """ + a += 1 # not a function => TypeError + """, + None, + ), + ], +) +def test_calc_function_id(code: str, expected: str) -> None: + module = astroid.parse(code) + function_node = module.body[0] + if expected is None: + with pytest.raises(TypeError): + calc_function_id(function_node) + + else: + result = calc_function_id(function_node) + assert str(result) == expected + + +# since we only look at FunctionDefs we can not use other types of CodeSnippets +@pytest.mark.parametrize( + "purity_result, expected", + [ + (DefinitelyPure(), []), + ( + DefinitelyImpure(reasons=[Call(expression=AttributeAccess(name="impure_call"))]), + [Call(expression=AttributeAccess(name="impure_call"))], + ), + ( + MaybeImpure(reasons=[FileRead(source=StringLiteral(value="read_path"))]), + [FileRead(source=StringLiteral(value="read_path"))], + ), + ( + MaybeImpure(reasons=[FileWrite(source=StringLiteral(value="write_path"))]), + [FileWrite(source=StringLiteral(value="write_path"))], + ), + ( + MaybeImpure(reasons=[VariableRead(StringLiteral(value="var_read"))]), + [VariableRead(StringLiteral(value="var_read"))], + ), + ( + MaybeImpure(reasons=[VariableWrite(StringLiteral(value="var_write"))]), + [VariableWrite(StringLiteral(value="var_write"))], + ), + ], +) +def test_generate_purity_information(purity_result: PurityResult, expected: list[ImpurityIndicator]) -> None: + purity_info = extract_impurity_reasons(purity_result) + + assert purity_info == expected + + +@pytest.mark.parametrize( + "purity_reasons, expected", + [ + ([], DefinitelyPure()), + ( + [Call(expression=AttributeAccess(name="impure_call"))], + DefinitelyImpure(reasons=[Call(expression=AttributeAccess(name="impure_call"))]), + ), + # TODO: improve analysis so this test does not fail: + # ( + # [Call(expression=AttributeAccess(name="pure_call"))], + # DefinitelyPure() + # ), + ( + [FileRead(source=StringLiteral(value="read_path"))], + DefinitelyImpure(reasons=[FileRead(source=StringLiteral(value="read_path"))]), + ), + ( + [FileWrite(source=StringLiteral(value="write_path"))], + DefinitelyImpure(reasons=[FileWrite(source=StringLiteral(value="write_path"))]), + ), + ( + [VariableRead(StringLiteral(value="var_read"))], + MaybeImpure(reasons=[VariableRead(StringLiteral(value="var_read"))]), + ), + ( + [VariableWrite(StringLiteral(value="var_write"))], + MaybeImpure(reasons=[VariableWrite(StringLiteral(value="var_write"))]), + ), + ], +) +def test_determine_purity(purity_reasons: list[ImpurityIndicator], expected: PurityResult) -> None: + result = determine_purity(purity_reasons) + assert result == expected + + +@pytest.mark.parametrize( + "args, expected", + [ + (["test"], OpenMode.READ), + (["test", "r"], OpenMode.READ), + (["test", "rb"], OpenMode.READ), + (["test", "rt"], OpenMode.READ), + (["test", "r+"], OpenMode.READ_WRITE), + (["test", "w"], OpenMode.WRITE), + (["test", "wb"], OpenMode.WRITE), + (["test", "wt"], OpenMode.WRITE), + (["test", "w+"], OpenMode.READ_WRITE), + (["test", "x"], OpenMode.WRITE), + (["test", "xb"], OpenMode.WRITE), + (["test", "xt"], OpenMode.WRITE), + (["test", "x+"], OpenMode.READ_WRITE), + (["test", "a"], OpenMode.WRITE), + (["test", "ab"], OpenMode.WRITE), + (["test", "at"], OpenMode.WRITE), + (["test", "a+"], OpenMode.READ_WRITE), + (["test", "r+b"], OpenMode.READ_WRITE), + (["test", "w+b"], OpenMode.READ_WRITE), + (["test", "x+b"], OpenMode.READ_WRITE), + (["test", "a+b"], OpenMode.READ_WRITE), + (["test", "r+t"], OpenMode.READ_WRITE), + (["test", "w+t"], OpenMode.READ_WRITE), + (["test", "x+t"], OpenMode.READ_WRITE), + (["test", "a+t"], OpenMode.READ_WRITE), + (["test", "error"], ValueError), + ], +) +def test_determine_open_mode(args: list[str], expected: OpenMode) -> None: + if expected is ValueError: + with pytest.raises(ValueError): + determine_open_mode(args) + else: + result = determine_open_mode(args) + assert result == expected + + +@pytest.mark.parametrize( + "code, expected", + [ + ( + """ + def fun1(): + open("test1.txt") # default mode: read only + """, + [FileRead(source=StringLiteral(value="test1.txt")), Call(expression=Reference(name="open('test1.txt')"))], + ), + ( + """ + def fun2(): + open("test2.txt", "r") # read only + """, + [ + FileRead(source=StringLiteral(value="test2.txt")), + Call(expression=Reference(name="open('test2.txt', 'r')")), + ], + ), + ( + """ + def fun3(): + open("test3.txt", "w") # write only + """, + [ + FileWrite(source=StringLiteral(value="test3.txt")), + Call(expression=Reference(name="open('test3.txt', 'w')")), + ], + ), + ( + """ + def fun4(): + open("test4.txt", "a") # append + """, + [ + FileWrite(source=StringLiteral(value="test4.txt")), + Call(expression=Reference(name="open('test4.txt', 'a')")), + ], + ), + ( + """ + def fun5(): + open("test5.txt", "r+") # read and write + """, + [ + FileRead(source=StringLiteral(value="test5.txt")), + FileWrite(source=StringLiteral(value="test5.txt")), + Call(expression=Reference(name="open('test5.txt', 'r+')")), + ], + ), + ( + """ + def fun6(): + f = open("test6.txt") # default mode: read only + f.read() + """, + [ + VariableWrite(expression=Reference(name="f = open('test6.txt')")), + FileRead(source=StringLiteral(value="test6.txt")), + Call(expression=Reference(name="open('test6.txt')")), + Call(expression=Reference(name="f.read()")), + VariableRead(expression=Reference(name="f.read")), + ], + ), + ( + """ + def fun7(): + f = open("test7.txt") # default mode: read only + f.readline([2]) + """, + [ + VariableWrite(expression=Reference(name="f = open('test7.txt')")), + FileRead(source=StringLiteral(value="test7.txt")), + Call(expression=Reference(name="open('test7.txt')")), + Call(expression=Reference(name="f.readline([2])")), + VariableRead(expression=Reference(name="f.readline")), + ], + ), + ( + """ + def fun8(): + f = open("test8.txt", "w") # write only + f.write("message") + """, + [ + VariableWrite(expression=Reference(name="f = open('test8.txt', 'w')")), + FileWrite(source=StringLiteral(value="test8.txt")), + Call(expression=Reference(name="open('test8.txt', 'w')")), + Call(expression=Reference(name="f.write('message')")), + VariableWrite(expression=Reference(name="f.write")), + ], + ), + ( + """ + def fun9(): + f = open("test9.txt", "w") # write only + f.writelines(["message1", "message2"]) + """, + [ + VariableWrite(expression=Reference(name="f = open('test9.txt', 'w')")), + FileWrite(source=StringLiteral(value="test9.txt")), + Call(expression=Reference(name="open('test9.txt', 'w')")), + Call(expression=Reference(name="f.writelines(['message1', 'message2'])")), + VariableWrite(expression=Reference(name="f.writelines")), + ], + ), + ( + """ + def fun10(): + with open("test10.txt") as f: # default mode: read only + f.read() + """, + [ + FileRead(source=StringLiteral(value="test10.txt")), + Call(expression=Reference(name="open('test10.txt')")), + Call(expression=Reference(name="f.read()")), + VariableRead(expression=Reference(name="f.read")), + ], + ), + ( + """ + def fun11(path11): # open with variable + open(path11) + """, + [FileRead(source=Reference("path11")), Call(expression=Reference(name="open(path11)"))], # ?? + ), + ( + """ + def fun12(path12): # open with variable write mode + open(path12, "w") + """, + [FileWrite(source=Reference(name="path12")), Call(expression=Reference(name="open(path12, 'w')"))], # ?? + ), + ( + """ + def fun13(path13): # open with variable write mode + open(path13, "wb+") + """, + [ + FileRead(source=Reference(name="path13")), + FileWrite(source=Reference(name="path13")), + Call(expression=Reference(name="open(path13, 'wb+')")), + ], # ?? + ), + ( + """ + def fun14(path14): + with open(path14) as f: + f.read() + """, + [ + FileRead(source=Reference("path14")), + Call(expression=Reference(name="open(path14)")), + Call(expression=Reference(name="f.read()")), + VariableRead(expression=Reference(name="f.read")), + ], # ?? + ), + ( + """ + def fun14(path14): + with open(path14) as f: + f.read() + """, + [ + FileRead(source=Reference("path14")), + Call(expression=Reference(name="open(path14)")), + Call(expression=Reference(name="f.read()")), + VariableRead(expression=Reference(name="f.read")), + ], # ?? + ), + ( + """ + def fun15(path15): # open with variable and wrong mode + open(path15, "test") + """, + ValueError, + ), + ( + """ + def fun16(): # this does not belong here but is needed for code coverage + print("test") + """, + TypeError, + ), + ], +) +# TODO: test for wrong arguments and Errors +def test_file_interaction(code: str, expected: list[ImpurityIndicator]) -> None: + if expected is ValueError: + with pytest.raises(ValueError): + infer_purity(code) + elif expected is TypeError: + with pytest.raises(TypeError): + infer_purity(code) + else: + purity_info: list[PurityInformation] = infer_purity(code) + assert purity_info[0].reasons == expected + + +# @pytest.mark.parametrize( +# "code, expected", +# [ +# ( +# """ +# def impure_fun(a): +# impure_call(a) # call => impure +# impure_call(a) # call => impure - check if the analysis is correct for multiple calls - done +# return a +# """, +# [Call(expression=Reference(name='impure_call(a)')), +# Call(expression=Reference(name='impure_call(a)'))], +# ), +# ( +# """ +# def pure_fun(a): +# a += 1 +# return a +# """, +# [], +# ), +# ( +# """ +# class A: +# def __init__(self): +# self.value = 42 +# +# a = A() +# +# def instance(a): +# res = a.value # InstanceAccess => pure?? +# return res +# """, +# [VariableWrite(expression=InstanceAccess( +# receiver=Reference(name='a'), +# target=Reference(name='a.value') +# ))], # TODO: is this correct? +# ), +# ( +# """ +# class B: +# name = "test" +# +# b = B() +# +# def attribute(b): +# res = b.name # AttributeAccess => maybe impure +# return res +# """, +# [VariableWrite(expression=AttributeAccess(name='res = b.name'))], # TODO: is this correct? +# ), +# ( +# """ +# global_var = 17 +# def global_access(): +# res = global_var # GlobalAccess => impure +# return res +# """, +# [VariableWrite(expression=GlobalAccess(name='res = global_var'))], # TODO: is this correct? +# ), +# ( +# """ +# def parameter_access(a): +# res = a # ParameterAccess => pure +# return res +# """, +# [Call(expression=ParameterAccess( +# name="a", +# function="parameter_access"), +# )], # TODO: is this correct? +# ), +# ( +# """ +# glob = g(1) # TODO: This will get filtered out because it is not a function call, but a variable assignment with a +# # function call and therefore further analysis is needed +# """, +# [VariableWrite(expression=Reference(name='b = g(a)')), +# Call(expression=Reference(name="g(1)"))], # TODO: is this correct? +# ), +# ( +# """ +# def fun(a): +# h(a) +# b = g(a) # call => impure +# b += 1 +# return b +# """, +# [Call(expression=Reference(name='h(a)')), +# VariableWrite(expression=Reference(name='b = g(a)')), +# Call(expression=Reference(name='g(a)'))], # TODO: is this correct? +# ), +# +# ] +# +# ) +# def test_infer_purity_basics(code: str, expected: list[ImpurityIndicator]) -> None: +# result_list = infer_purity(code) +# assert result_list[0].reasons == expected