Skip to content

Commit

Permalink
For some reason, the part I didn't fix got fixed, so I put it back in.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryoji.nagata committed Nov 15, 2024
1 parent 09075ec commit dc507cd
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions sentence_transformers/models/Transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,12 @@ def _backend_should_export(
"""

export = model_args.pop("export", None)
if export is not None:
if export:
return export, model_args

file_name = model_args.get("file_name", target_file_name)
subfolder = model_args.get("subfolder", None)
primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else file_name
primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else Path(file_name).as_posix()
secondary_full_path = (
Path(subfolder, self.backend, file_name).as_posix()
if subfolder
Expand All @@ -322,17 +322,19 @@ def _backend_should_export(
# First check if the expected file exists in the root of the model directory
# If it doesn't, check if it exists in the backend subfolder.
# If it does, set the subfolder to include the backend
export = primary_full_path not in model_file_names
if export and "subfolder" not in model_args:
export = secondary_full_path not in model_file_names
if not export:
model_found = primary_full_path in model_file_names
if not model_found and "subfolder" not in model_args:
model_found = secondary_full_path in model_file_names
if model_found:
if len(model_file_names) > 1 and "file_name" not in model_args:
logger.warning(
f"Multiple {backend_name} files found in {load_path.as_posix()!r}: {model_file_names}, defaulting to {secondary_full_path!r}. "
f'Please specify the desired file name via `model_kwargs={{"file_name": "<file_name>"}}`.'
)
model_args["subfolder"] = self.backend
model_args["file_name"] = file_name
if export is None:
export = not model_found

# If the file_name contains subfolders, set it as the subfolder instead
file_name_parts = Path(file_name).parts
Expand Down

0 comments on commit dc507cd

Please sign in to comment.