diff --git a/.github/workflows/build-containers.yaml b/.github/workflows/build-containers.yaml index 2500ead3b..a24094bdf 100644 --- a/.github/workflows/build-containers.yaml +++ b/.github/workflows/build-containers.yaml @@ -21,7 +21,7 @@ jobs: contents: read steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Docker meta fedn id: meta1 @@ -63,15 +63,15 @@ jobs: type=sha,suffix=-mnist-pytorch - - name: Log in to GitHub Docker Registry - uses: docker/login-action@v1 + - name: Log in to GitHub Container Registry + uses: docker/login-action@v2 with: registry: docker.pkg.github.com username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - name: Build and push - uses: docker/build-push-action@v2 + uses: docker/build-push-action@v4 with: push: "${{ github.event_name != 'pull_request' }}" tags: ${{ steps.meta1.outputs.tags }} @@ -79,7 +79,7 @@ jobs: file: Dockerfile - name: Build and push (mnist-keras) - uses: docker/build-push-action@v2 + uses: docker/build-push-action@v4 with: push: "${{ github.event_name != 'pull_request' }}" tags: ${{ steps.meta2.outputs.tags }} @@ -89,7 +89,7 @@ jobs: REQUIREMENTS=examples/mnist-keras/requirements.txt - name: Build and push (mnist-pytorch) - uses: docker/build-push-action@v2 + uses: docker/build-push-action@v4 with: push: "${{ github.event_name != 'pull_request' }}" tags: ${{ steps.meta3.outputs.tags }} diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 311af03fa..c1ec38548 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-20.04 steps: - name: checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: init venv run: .devcontainer/bin/init_venv.sh diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 2bf711af3..d5c49a8f0 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -25,7 +25,7 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: diff --git a/examples/mnist-pytorch/requirements.txt b/examples/mnist-pytorch/requirements.txt index 72e57c0cb..381aaa4f3 100644 --- a/examples/mnist-pytorch/requirements.txt +++ b/examples/mnist-pytorch/requirements.txt @@ -1,4 +1,4 @@ torch==1.13.1 -torchvision==0.12.0 +torchvision==0.14.1 fire==0.3.1 docker==5.0.2 \ No newline at end of file diff --git a/fedn/fedn/client.py b/fedn/fedn/client.py index 37ba07be7..f1abfc3db 100644 --- a/fedn/fedn/client.py +++ b/fedn/fedn/client.py @@ -32,6 +32,14 @@ VALID_NAME_REGEX = '^[a-zA-Z0-9_-]*$' +class GrpcAuth(grpc.AuthMetadataPlugin): + def __init__(self, key): + self._key = key + + def __call__(self, context, callback): + callback((('authorization', f'Token {self._key}'),), None) + + class Client: """FEDn Client. Service running on client/datanodes in a federation, recieving and handling model update and model validation requests. @@ -250,7 +258,7 @@ def _connect(self, client_config): host = client_config['fqdn'] # assuming https if fqdn is used port = 443 - print(f"CLIENT: Connecting to combiner host: {host}", flush=True) + print(f"CLIENT: Connecting to combiner host: {host}:{port}", flush=True) if client_config['certificate']: print("CLIENT: using certificate from Reducer for GRPC channel") @@ -271,9 +279,16 @@ def _connect(self, client_config): cert = ssl.get_server_certificate((host, port)) credentials = grpc.ssl_channel_credentials(cert.encode('utf-8')) - channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) + if self.config['token']: + token = self.config['token'] + auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) + channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) + else: + channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) else: print("CLIENT: using insecure GRPC channel") + if port == 443: + port = 80 channel = grpc.insecure_channel("{}:{}".format( host, str(port)))