Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clearer code for PodGenerator.deserialize_model_file #26641

Merged
merged 2 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions airflow/kubernetes/pod_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import copy
import datetime
import hashlib
import logging
import os
import re
import uuid
Expand All @@ -40,6 +41,8 @@
from airflow.utils import yaml
from airflow.version import version as airflow_version

log = logging.getLogger(__name__)

MAX_LABEL_LEN = 63


Expand Down Expand Up @@ -412,16 +415,13 @@ def deserialize_model_file(path: str) -> k8s.V1Pod:
"""
:param path: Path to the file
:return: a kubernetes.client.models.V1Pod

Unfortunately we need access to the private method
XD-DENG marked this conversation as resolved.
Show resolved Hide resolved
``_ApiClient__deserialize_model`` from the kubernetes client.
This issue is tracked here; https://github.com/kubernetes-client/python/issues/977.
"""
if os.path.exists(path):
with open(path) as stream:
pod = yaml.safe_load(stream)
else:
pod = yaml.safe_load(path)
pod = None
log.warning("Model file %s does not exist", path)

return PodGenerator.deserialize_model_dict(pod)
uranusjr marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -430,6 +430,10 @@ def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod:
"""
Deserializes python dictionary to k8s.V1Pod

Unfortunately we need access to the private method
``_ApiClient__deserialize_model`` from the kubernetes client.
This issue is tracked here; https://github.com/kubernetes-client/python/issues/977.

:param pod_dict: Serialized dict of k8s.V1Pod object
:return: De-serialized k8s.V1Pod
"""
Expand Down
29 changes: 23 additions & 6 deletions tests/executors/test_kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import random
import re
import string
import sys
import unittest
from datetime import datetime, timedelta
from unittest import mock

import pytest
import yaml
from kubernetes.client import models as k8s
from kubernetes.client.rest import ApiException
from urllib3 import HTTPResponse
Expand Down Expand Up @@ -100,14 +102,33 @@ def test_create_pod_id(self):
@mock.patch("airflow.kubernetes.pod_generator.PodGenerator")
@mock.patch("airflow.executors.kubernetes_executor.KubeConfig")
def test_get_base_pod_from_template(self, mock_kubeconfig, mock_generator):
# Provide non-existent file path,
# so None will be passed to deserialize_model_dict().
pod_template_file_path = "/bar/biz"
get_base_pod_from_template(pod_template_file_path, None)
assert "deserialize_model_dict" == mock_generator.mock_calls[0][0]
assert pod_template_file_path == mock_generator.mock_calls[0][1][0]
assert mock_generator.mock_calls[0][1][0] is None

mock_kubeconfig.pod_template_file = "/foo/bar"
get_base_pod_from_template(None, mock_kubeconfig)
assert "deserialize_model_dict" == mock_generator.mock_calls[1][0]
assert "/foo/bar" == mock_generator.mock_calls[1][1][0]
assert mock_generator.mock_calls[1][1][0] is None

# Provide existent file path,
# so loaded YAML file content should be used to call deserialize_model_dict(), rather than None.
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
with open(path) as stream:
expected_pod_dict = yaml.safe_load(stream)

pod_template_file_path = path
get_base_pod_from_template(pod_template_file_path, None)
assert "deserialize_model_dict" == mock_generator.mock_calls[2][0]
assert mock_generator.mock_calls[2][1][0] == expected_pod_dict

mock_kubeconfig.pod_template_file = path
get_base_pod_from_template(None, mock_kubeconfig)
assert "deserialize_model_dict" == mock_generator.mock_calls[3][0]
assert mock_generator.mock_calls[3][1][0] == expected_pod_dict

def test_make_safe_label_value(self):
for dag_id, task_id in self._cases():
Expand Down Expand Up @@ -228,8 +249,6 @@ def test_run_next_exception_requeue(
- 400 BadRequest is returned when your parameters are invalid e.g. asking for cpu=100ABC123.

"""
import sys

path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'

response = HTTPResponse(body='{"message": "any message"}', status=status)
Expand Down Expand Up @@ -283,8 +302,6 @@ def test_run_next_pod_reconciliation_error(self, mock_get_kube_client, mock_kube
"""
When construct_pod raises PodReconciliationError, we should fail the task.
"""
import sys

path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'

mock_kube_client = mock.patch('kubernetes.client.CoreV1Api', autospec=True)
Expand Down
34 changes: 10 additions & 24 deletions tests/kubernetes/test_pod_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,11 +712,20 @@ def test_reconcile_specs_init_containers(self):
res = PodGenerator.reconcile_specs(base_spec, client_spec)
assert res.init_containers == base_spec.init_containers + client_spec.init_containers

def test_deserialize_model_file(self):
def test_deserialize_model_file(self, caplog):
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
result = PodGenerator.deserialize_model_file(path)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
assert sanitized_res == self.deserialize_result
assert len(caplog.records) == 0

def test_deserialize_non_existent_model_file(self, caplog):
path = sys.path[0] + '/tests/kubernetes/non_existent.yaml'
result = PodGenerator.deserialize_model_file(path)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
assert sanitized_res == {}
assert len(caplog.records) == 1
assert 'does not exist' in caplog.text

@parameterized.expand(
(
Expand Down Expand Up @@ -761,29 +770,6 @@ def test_pod_name_is_valid(self, pod_id, expected_starts_with):

assert name.rsplit("-", 1)[0] == expected_starts_with

def test_deserialize_model_string(self):
fixture = """
apiVersion: v1
kind: Pod
metadata:
name: memory-demo
namespace: mem-example
spec:
containers:
- name: memory-demo-ctr
image: ghcr.io/apache/airflow-stress:1.0.4-2021.07.04
resources:
limits:
memory: "200Mi"
requests:
memory: "100Mi"
command: ["stress"]
args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"]
"""
result = PodGenerator.deserialize_model_file(fixture)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
assert sanitized_res == self.deserialize_result

def test_validate_pod_generator(self):
with pytest.raises(AirflowConfigException):
PodGenerator(pod=k8s.V1Pod(), pod_template_file='k')
Expand Down