diff --git a/arango/aql.py b/arango/aql.py index 4dccef33..941000e5 100644 --- a/arango/aql.py +++ b/arango/aql.py @@ -276,6 +276,7 @@ def execute( fill_block_cache: Optional[bool] = None, allow_dirty_read: bool = False, allow_retry: bool = False, + force_one_shard_attribute_value: Optional[str] = None, ) -> Result[Cursor]: """Execute the query and return the result cursor. @@ -373,6 +374,16 @@ def execute( :param allow_retry: Make it possible to retry fetching the latest batch from a cursor. :type allow_retry: bool + :param force_one_shard_attribute_value: (Enterprise Only) Explicitly set + a shard key value that will be used during query snippet distribution + to limit the query to a specific server in the cluster. This query option + can be used in complex queries in case the query optimizer cannot + automatically detect that the query can be limited to only a single + server (e.g. in a disjoint smart graph case). If the option is set + incorrectly, i.e. to a wrong shard key value, then the query may be + shipped to a wrong DB server and may not return results + (i.e. empty result set). Use at your own risk. + :param force_one_shard_attribute_value: str | None :return: Result cursor. :rtype: arango.cursor.Cursor :raise arango.exceptions.AQLQueryExecuteError: If execute fails. @@ -418,10 +429,10 @@ def execute( options["skipInaccessibleCollections"] = skip_inaccessible_cols if max_runtime is not None: options["maxRuntime"] = max_runtime - - # New in 3.11 if allow_retry is not None: options["allowRetry"] = allow_retry + if force_one_shard_attribute_value is not None: + options["forceOneShardAttributeValue"] = force_one_shard_attribute_value if options: data["options"] = options diff --git a/tests/test_aql.py b/tests/test_aql.py index 25f3f501..65b7365e 100644 --- a/tests/test_aql.py +++ b/tests/test_aql.py @@ -17,7 +17,7 @@ AQLQueryTrackingSetError, AQLQueryValidateError, ) -from tests.helpers import assert_raises, extract +from tests.helpers import assert_raises, extract, generate_col_name def test_aql_attributes(db, username): @@ -246,6 +246,36 @@ def test_aql_query_management(db_version, db, bad_db, col, docs): assert err.value.error_code in {11, 1228} +def test_aql_query_force_one_shard_attribute_value(db, db_version, enterprise, cluster): + if db_version < version.parse("3.10") or not enterprise or not cluster: + return + + name = generate_col_name() + col = db.create_collection(name, shard_fields=["foo"], shard_count=3) + + doc = {"foo": "bar"} + col.insert(doc) + + cursor = db.aql.execute( + "FOR d IN @@c RETURN d", + bind_vars={"@c": name}, + force_one_shard_attribute_value="bar", + ) + + results = [doc for doc in cursor] + assert len(results) == 1 + assert results[0]["foo"] == "bar" + + cursor = db.aql.execute( + "FOR d IN @@c RETURN d", + bind_vars={"@c": name}, + force_one_shard_attribute_value="ooo", + ) + + results = [doc for doc in cursor] + assert len(results) == 0 + + def test_aql_function_management(db, bad_db): fn_group = "functions::temperature" fn_name_1 = "functions::temperature::celsius_to_fahrenheit"