From ba8a8f771a744143fe92974ca7d9b687c4fbe1df Mon Sep 17 00:00:00 2001 From: Rich Scott Date: Thu, 7 Sep 2023 10:55:00 -0600 Subject: [PATCH] Airflow: always set credentials from args in channel ctor (#2952) In the GrpcChannelArguments constructor, always set the credentials_callback_args member from what is given. Add a test to verify serialization round-tripping is complete, and a __eq__ implementation for GrpcChannelArguments. Signed-off-by: Rich Scott --- third_party/airflow/armada/operators/grpc.py | 9 +++++++ third_party/airflow/tests/unit/test_grpc.py | 26 ++++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 third_party/airflow/tests/unit/test_grpc.py diff --git a/third_party/airflow/armada/operators/grpc.py b/third_party/airflow/armada/operators/grpc.py index bebb0f98835..3e146ccce07 100644 --- a/third_party/airflow/armada/operators/grpc.py +++ b/third_party/airflow/armada/operators/grpc.py @@ -78,10 +78,19 @@ def __init__( self.options = options self.compression = compression self.credentials_callback = None + self.credentials_callback_args = credentials_callback_args if credentials_callback_args is not None: self.credentials_callback = CredentialsCallback(**credentials_callback_args) + def __eq__(self, o): + return ( + self.target == o.target + and self.options == o.options + and self.compression == o.compression + and self.credentials_callback_args == o.credentials_callback_args + ) + def channel(self) -> grpc.Channel: """ Create a grpc.Channel based on arguments supplied to this object. diff --git a/third_party/airflow/tests/unit/test_grpc.py b/third_party/airflow/tests/unit/test_grpc.py new file mode 100644 index 00000000000..1e12b566067 --- /dev/null +++ b/third_party/airflow/tests/unit/test_grpc.py @@ -0,0 +1,26 @@ +import armada.operators.grpc + + +def test_serialize_grpc_channel(): + src_chan_args = { + "target": "localhost:443", + "credentials_callback_args": { + "module_name": "channel_test", + "function_name": "get_credentials", + "function_kwargs": { + "example_arg": "test", + }, + }, + } + + source = armada.operators.grpc.GrpcChannelArguments(**src_chan_args) + + serialized = source.serialize() + assert serialized["target"] == src_chan_args["target"] + assert ( + serialized["credentials_callback_args"] + == src_chan_args["credentials_callback_args"] + ) + + reconstituted = armada.operators.grpc.GrpcChannelArguments(**serialized) + assert reconstituted == source