diff --git a/scripts/copy_weaviate_db.py b/scripts/copy_weaviate_db.py index 4b8dd2c..0db2a23 100644 --- a/scripts/copy_weaviate_db.py +++ b/scripts/copy_weaviate_db.py @@ -1,5 +1,6 @@ import argparse import logging +from itertools import islice from tempo_embeddings.embeddings.weaviate_database import WeaviateDatabaseManager @@ -13,6 +14,12 @@ action="store_true", help="Overwrite existing corpus in target database.", ) + parser.add_argument( + "--limit", + type=int, + required=False, + help="Maximum number of objects to copy per collection.", + ) source_args = parser.add_argument_group("Weaviate export database arguments") source_args.add_argument( @@ -89,11 +96,14 @@ target_db.delete_collection(corpus) config = source_db.collection_config(corpus) + + if args.limit: + config["total_count"] = min(config["total_count"], args.limit) + objects = islice(source_db.collection_objects(corpus), args.limit) + target_db.import_config(config) target_db.import_objects( - source_db.collection_objects(corpus), - config["corpus"], - total_count=config["total_count"], + objects, config["corpus"], total_count=config["total_count"] ) target_db.validate_config()