diff --git a/CHANGELOG.md b/CHANGELOG.md index 583e6655..c003a891 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Fixed +- Improve module loading ([#61]) ## [1.4.3] - 2020-05-14 ### Fixed @@ -71,6 +73,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [1.0.1]: https://github.com/GoogleCloudPlatform/functions-framework-python/releases/tag/v1.0.1 [1.0.0]: https://github.com/GoogleCloudPlatform/functions-framework-python/releases/tag/v1.0.0 +[#61]: https://github.com/GoogleCloudPlatform/functions-framework-python/pull/61 [#49]: https://github.com/GoogleCloudPlatform/functions-framework-python/pull/49 [#44]: https://github.com/GoogleCloudPlatform/functions-framework-python/pull/44 [#38]: https://github.com/GoogleCloudPlatform/functions-framework-python/pull/38 diff --git a/src/functions_framework/__init__.py b/src/functions_framework/__init__.py index ce0e2fbf..641c715b 100644 --- a/src/functions_framework/__init__.py +++ b/src/functions_framework/__init__.py @@ -140,10 +140,24 @@ def create_app(target=None, source=None, signature_type=None): # Set the environment variable if it wasn't already os.environ["FUNCTION_SIGNATURE_TYPE"] = signature_type - # Load the source file - spec = importlib.util.spec_from_file_location("__main__", source) + # Load the source file: + # 1. Extract the module name from the source path + realpath = os.path.realpath(source) + directory, filename = os.path.split(realpath) + name, extension = os.path.splitext(filename) + + # 2. Create a new module + spec = importlib.util.spec_from_file_location(name, realpath) source_module = importlib.util.module_from_spec(spec) - sys.path.append(os.path.dirname(os.path.realpath(source))) + + # 3. Add the directory of the source to sys.path to allow the function to + # load modules relative to its location + sys.path.append(directory) + + # 4. Add the module to sys.modules + sys.modules[name] = source_module + + # 5. Execute the module spec.loader.exec_module(source_module) app = flask.Flask(target, template_folder=template_folder) diff --git a/tests/test_functions/module_is_correct/main.py b/tests/test_functions/module_is_correct/main.py index ed7a30cf..06d2f971 100644 --- a/tests/test_functions/module_is_correct/main.py +++ b/tests/test_functions/module_is_correct/main.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os.path import typing @@ -21,7 +22,9 @@ class TestClass: def function(request): # Ensure that the module for any object in this file is set correctly - assert TestClass.__mro__[0].__module__ == "__main__" + _, filename = os.path.split(__file__) + name, _ = os.path.splitext(filename) + assert TestClass.__mro__[0].__module__ == name # Ensure that calling `get_type_hints` on an object in this file succeeds assert typing.get_type_hints(TestClass) == {}