diff --git a/merlin/systems/dag/ops/faiss.py b/merlin/systems/dag/ops/faiss.py index 748d1df99..60aefc9cc 100644 --- a/merlin/systems/dag/ops/faiss.py +++ b/merlin/systems/dag/ops/faiss.py @@ -63,7 +63,10 @@ def __init__(self, index_path, topk=10): def load_artifacts(self, artifact_path): filename = Path(self.index_path).name - full_index_path = str(Path(artifact_path) / filename) + path_artifact = Path(artifact_path) + if path_artifact.is_file(): + path_artifact = path_artifact.parent + full_index_path = str(path_artifact / filename) index = faiss.read_index(full_index_path) if HAS_GPU: