diff --git a/minari/cli.py b/minari/cli.py index 1573213e..1f1da582 100644 --- a/minari/cli.py +++ b/minari/cli.py @@ -36,16 +36,19 @@ def _show_dataset_table(datasets, table_title): table.add_column("Email", justify="left", style="magenta") for dst_metadata in datasets.values(): + author = dst_metadata.get("author", "Unknown") + author_email = dst_metadata.get("author_email", "Unknown") + assert isinstance(dst_metadata["dataset_id"], str) - assert isinstance(dst_metadata["author"], str) - assert isinstance(dst_metadata["author_email"], str) + assert isinstance(author, str) + assert isinstance(author_email, str) table.add_row( dst_metadata["dataset_id"], str(dst_metadata["total_episodes"]), str(dst_metadata["total_steps"]), "Coming soon ...", - dst_metadata["author"], - dst_metadata["author_email"], + author, + author_email, ) print(table) diff --git a/minari/storage/local.py b/minari/storage/local.py index e5ecb6d6..dd5bfdec 100644 --- a/minari/storage/local.py +++ b/minari/storage/local.py @@ -59,8 +59,9 @@ def list_local_datasets( main_file_path = os.path.join(datasets_path, dst_id, "data/main_data.hdf5") with h5py.File(main_file_path, "r") as f: metadata = dict(f.attrs.items()) - if compatible_minari_version and __version__ not in SpecifierSet( - metadata["minari_version"] + if ("minari_version" not in metadata) or ( + compatible_minari_version + and __version__ not in SpecifierSet(metadata["minari_version"]) ): continue env_name, dataset_name, version = parse_dataset_id(dst_id) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..70e1232b --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,54 @@ +import pytest +from typer.testing import CliRunner + +from minari.cli import app +from minari.storage.local import delete_dataset, list_local_datasets +from tests.dataset.test_dataset_download import get_latest_compatible_dataset_id + + +runner = CliRunner() + + +def test_list(): + # local test + result = runner.invoke(app, ["list", "local", "--all"]) + assert result.exit_code == 0 + + # some of the other columns may be cut off by Rich + assert "Name" in result.stdout + + result = runner.invoke(app, ["list", "remote"]) + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "dataset_id", + [get_latest_compatible_dataset_id(env_name="pen", dataset_name="human")], +) +def test_dataset_download_then_delete(dataset_id: str): + """test download dataset invocation from CLI. + the downloading functionality is already tested in test_dataset_download.py so this is primarily to assert that the CLI is working as expected. + """ + + # might have to clear up the local dataset first. + # ideally this seems like it could just be handled by the tests + if dataset_id in list_local_datasets(): + delete_dataset(dataset_id) + + result = runner.invoke(app, ["download", dataset_id]) + + assert result.exit_code == 0 + assert f"Downloading {dataset_id} from Farama servers..." in result.stdout + assert f"Dataset {dataset_id} downloaded to" in result.stdout + + result = runner.invoke(app, ["delete", dataset_id], input="n") + assert result.exit_code == 1 # aborted but no error + assert "Aborted" in result.stdout + + result = runner.invoke(app, ["delete", dataset_id], input="😳") + assert result.exit_code == 1 + assert "Error: invalid input" in result.stdout + + result = runner.invoke(app, ["delete", dataset_id], input="y") + assert result.exit_code == 0 + assert f"Dataset {dataset_id} deleted!" in result.stdout