Skip to content

Commit

Permalink
Airflow: always set credentials from args in channel ctor (#2952)
Browse files Browse the repository at this point in the history
In the GrpcChannelArguments constructor, always set the
credentials_callback_args member from what is given. Add a test to
verify serialization round-tripping is complete, and a __eq__
implementation for GrpcChannelArguments.

Signed-off-by: Rich Scott <[email protected]>
  • Loading branch information
richscott authored Sep 7, 2023
1 parent 99097de commit ba8a8f7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
9 changes: 9 additions & 0 deletions third_party/airflow/armada/operators/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,19 @@ def __init__(
self.options = options
self.compression = compression
self.credentials_callback = None
self.credentials_callback_args = credentials_callback_args

if credentials_callback_args is not None:
self.credentials_callback = CredentialsCallback(**credentials_callback_args)

def __eq__(self, o):
return (
self.target == o.target
and self.options == o.options
and self.compression == o.compression
and self.credentials_callback_args == o.credentials_callback_args
)

def channel(self) -> grpc.Channel:
"""
Create a grpc.Channel based on arguments supplied to this object.
Expand Down
26 changes: 26 additions & 0 deletions third_party/airflow/tests/unit/test_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import armada.operators.grpc


def test_serialize_grpc_channel():
src_chan_args = {
"target": "localhost:443",
"credentials_callback_args": {
"module_name": "channel_test",
"function_name": "get_credentials",
"function_kwargs": {
"example_arg": "test",
},
},
}

source = armada.operators.grpc.GrpcChannelArguments(**src_chan_args)

serialized = source.serialize()
assert serialized["target"] == src_chan_args["target"]
assert (
serialized["credentials_callback_args"]
== src_chan_args["credentials_callback_args"]
)

reconstituted = armada.operators.grpc.GrpcChannelArguments(**serialized)
assert reconstituted == source

0 comments on commit ba8a8f7

Please sign in to comment.