Skip to content

Commit

Permalink
Merge pull request #1059 from guardrails-ai/upgrade-validators
Browse files Browse the repository at this point in the history
Hub cli flag to upgrade validators
  • Loading branch information
zsimjee authored Sep 10, 2024
2 parents 41221d0 + ffb6825 commit ba1919e
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 5 deletions.
4 changes: 4 additions & 0 deletions guardrails/cli/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def install(
"--quiet",
help="Run the command in quiet mode to reduce output verbosity.",
),
upgrade: bool = typer.Option(
False, "--upgrade", help="Upgrade the package to the latest version."
),
):
try:
trace_if_enabled("hub/install")
Expand All @@ -41,6 +44,7 @@ def confirm():
package_uri,
install_local_models=local_models,
quiet=quiet,
upgrade=upgrade,
install_local_models_confirm=confirm,
)
except Exception as e:
Expand Down
7 changes: 6 additions & 1 deletion guardrails/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def install(
package_uri: str,
install_local_models=None,
quiet: bool = True,
upgrade: bool = False,
install_local_models_confirm: Callable = default_local_models_confirm,
) -> ValidatorModuleType:
"""Install a validator package from a hub URI.
Expand Down Expand Up @@ -84,7 +85,11 @@ def install(
dl_deps_msg = "Downloading dependencies"
with loader(dl_deps_msg, spinner="bouncingBar"):
ValidatorPackageService.install_hub_module(
module_manifest, site_packages, quiet=quiet, logger=cli_logger
module_manifest,
site_packages,
quiet=quiet,
upgrade=upgrade,
logger=cli_logger,
)

use_remote_endpoint = False
Expand Down
5 changes: 5 additions & 0 deletions guardrails/hub/validator_package_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def install_hub_module(
module_manifest: Manifest,
site_packages: str,
quiet: bool = False,
upgrade: bool = False,
logger=guardrails_logger,
):
install_url = ValidatorPackageService.get_install_url(module_manifest)
Expand All @@ -268,6 +269,10 @@ def install_hub_module(
)

pip_flags = [f"--target={install_directory}", "--no-deps"]

if upgrade:
pip_flags.append("--upgrade")

if quiet:
pip_flags.append("-q")

Expand Down
21 changes: 21 additions & 0 deletions tests/unit_tests/cli/hub/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_install_local_models__false(self, mocker):
"hub://guardrails/test-validator",
install_local_models=False,
quiet=False,
upgrade=False,
install_local_models_confirm=ANY,
)

Expand All @@ -45,6 +46,7 @@ def test_install_local_models__true(self, mocker):
"hub://guardrails/test-validator",
install_local_models=True,
quiet=False,
upgrade=False,
install_local_models_confirm=ANY,
)

Expand All @@ -61,6 +63,7 @@ def test_install_local_models__none(self, mocker):
"hub://guardrails/test-validator",
install_local_models=None,
quiet=False,
upgrade=False,
install_local_models_confirm=ANY,
)

Expand All @@ -77,6 +80,7 @@ def test_install_quiet(self, mocker):
"hub://guardrails/test-validator",
install_local_models=None,
quiet=True,
upgrade=False,
install_local_models_confirm=ANY,
)

Expand Down Expand Up @@ -205,6 +209,23 @@ def test_other_exception(self, mocker):

sys_exit_spy.assert_called_once_with(1)

def test_install_with_upgrade_flag(self, mocker):
mock_install = mocker.patch("guardrails.hub.install.install")
runner = CliRunner()
result = runner.invoke(
hub_command, ["install", "--upgrade", "hub://guardrails/test-validator"]
)

mock_install.assert_called_once_with(
"hub://guardrails/test-validator",
install_local_models=None,
quiet=False,
install_local_models_confirm=ANY,
upgrade=True,
)

assert result.exit_code == 0


def test_get_site_packages_location(mocker):
mock_pip_process = mocker.patch("guardrails.cli.hub.utils.pip_process")
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/hub/test_hub_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_install_local_models__false(self, mocker, use_remote_inferencing):
)

mock_pip_install_hub_module.assert_called_once_with(
self.manifest, self.site_packages, quiet=ANY, logger=ANY
self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY
)
mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)

Expand Down Expand Up @@ -160,7 +160,7 @@ def test_install_local_models__true(self, mocker, use_remote_inferencing):
)

mock_pip_install_hub_module.assert_called_once_with(
self.manifest, self.site_packages, quiet=ANY, logger=ANY
self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY
)
mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)

Expand Down Expand Up @@ -221,7 +221,7 @@ def test_install_local_models__none(self, mocker, use_remote_inferencing):
)

mock_pip_install_hub_module.assert_called_once_with(
self.manifest, self.site_packages, quiet=ANY, logger=ANY
self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY
)
mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)

Expand Down Expand Up @@ -278,7 +278,7 @@ def test_happy_path(self, mocker, use_remote_inferencing):
)

mock_pip_install_hub_module.assert_called_once_with(
self.manifest, self.site_packages, quiet=ANY, logger=ANY
self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY
)
mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)

Expand Down

0 comments on commit ba1919e

Please sign in to comment.