diff --git a/tests/test_terraform.py b/tests/test_terraform.py index 3958768..9a344a5 100644 --- a/tests/test_terraform.py +++ b/tests/test_terraform.py @@ -1,5 +1,4 @@ import json -import os import pytest from mock import ANY, MagicMock @@ -66,7 +65,7 @@ def terraform(tmp_path, mocker): @pytest.fixture def resource(tmp_path): name = "test-resource" - path = tmp_path / name / "terraform.tfstate" + path = tmp_path / "terraform.tfstate" path.parent.mkdir(parents=True, exist_ok=True) path.write_text(TEST_RESOURCE_STATE, encoding="utf-8") yield name @@ -92,12 +91,6 @@ def test_make_pemfile(): assert fobj.read() == content -def test_make_tf(terraform): - name = "foo" - with terraform.make_tf(name) as tf: - assert tf.working_dir == os.path.join(terraform.tmp_dir, name) - - def test_create(tmp_path, terraform): name = "foo" terraform.create(name=name, cloud="aws") @@ -105,8 +98,8 @@ def test_create(tmp_path, terraform): terraform.cmd_mock.assert_any_call( "apply", capture_output=False, auto_approve=IsFlagged ) - assert (tmp_path / name / "main.tf.json").exists() - with open(tmp_path / name / "main.tf.json") as fobj: + assert (tmp_path / "main.tf.json").exists() + with open(tmp_path / "main.tf.json") as fobj: data = json.load(fobj) assert data["resource"]["iterative_machine"] == { name: {"name": name, "cloud": "aws"}, diff --git a/tpi/terraform.py b/tpi/terraform.py index 994a399..d630358 100644 --- a/tpi/terraform.py +++ b/tpi/terraform.py @@ -1,7 +1,6 @@ import asyncio import os import sys -from contextlib import contextmanager from itertools import repeat from typing import TYPE_CHECKING, Iterator, Optional, Union @@ -18,22 +17,14 @@ class TPIException(Exception): class TerraformBackend: - def __init__(self, tmp_dir: StrPath, **kwargs): - self.tmp_dir = tmp_dir - os.makedirs(self.tmp_dir, exist_ok=True) + """Class for managing a named TPI iterative-machine resource.""" - @contextmanager - def make_tf(self, name: str): - from tpi import TerraformProviderIterative, TPIError + def __init__(self, working_dir: StrPath, **kwargs): + from tpi import TerraformProviderIterative - try: - working_dir = os.path.join(self.tmp_dir, name) - os.makedirs(working_dir, exist_ok=True) - yield TerraformProviderIterative(working_dir=working_dir) - except TPIError: - raise - except Exception as exc: - raise TPIError("terraform failed") from exc + self.working_dir = working_dir + os.makedirs(self.working_dir, exist_ok=True) + self.tf = TerraformProviderIterative(working_dir=self.working_dir) def create(self, name: Optional[str] = None, **config): """Create and start an instance of the specified machine.""" @@ -42,12 +33,11 @@ def create(self, name: Optional[str] = None, **config): from tpi import render_json assert name and "cloud" in config - with self.make_tf(name) as tf: - tf_file = os.path.join(tf.working_dir, "main.tf.json") - with open(tf_file, "w", encoding="utf-8") as fobj: - fobj.write(render_json(name=name, **config, indent=2)) - tf.cmd("init") - tf.cmd("apply", auto_approve=IsFlagged) + tf_file = os.path.join(self.working_dir, "main.tf.json") + with open(tf_file, "w", encoding="utf-8") as fobj: + fobj.write(render_json(name=name, **config, indent=2)) + self.tf.cmd("init") + self.tf.cmd("apply", auto_approve=IsFlagged) def destroy(self, name: Optional[str] = None, **config): """Stop and destroy all instances of the specified machine.""" @@ -55,16 +45,14 @@ def destroy(self, name: Optional[str] = None, **config): assert name - with self.make_tf(name) as tf: - if first(tf.iter_instances(name)): - tf.cmd("destroy", auto_approve=IsFlagged) + if first(self.tf.iter_instances(name)): + self.tf.cmd("destroy", auto_approve=IsFlagged) def instances(self, name: Optional[str] = None, **config) -> Iterator[dict]: """Iterate over status of all instances of the specified machine.""" assert name - with self.make_tf(name) as tf: - yield from tf.iter_instances(name) + yield from self.tf.iter_instances(name) def close(self): pass @@ -73,7 +61,7 @@ def run_shell(self, name: Optional[str] = None, **config): """Spawn an interactive SSH shell for the specified machine.""" from tpi import TerraformProviderIterative - resource = self._default_resource(name) + resource = self.default_resource(name) with TerraformProviderIterative.pemfile(resource) as pem: self._shell( host=resource["instance_ip"], @@ -82,7 +70,7 @@ def run_shell(self, name: Optional[str] = None, **config): known_hosts=None, ) - def _default_resource(self, name): + def default_resource(self, name): from tpi import TPIError resource = first(self.instances(name))