Skip to content

Commit

Permalink
Allow 'airflow variables export' to print to stdout (#33279)
Browse files Browse the repository at this point in the history
Co-authored-by: vedantlodha <[email protected]>
  • Loading branch information
uranusjr and vedantlodha authored Aug 11, 2023
1 parent bfa09da commit 09d478e
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 66 deletions.
10 changes: 9 additions & 1 deletion airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,11 @@ def string_lower_type(val):
ARG_DESERIALIZE_JSON = Arg(("-j", "--json"), help="Deserialize JSON variable", action="store_true")
ARG_SERIALIZE_JSON = Arg(("-j", "--json"), help="Serialize JSON variable", action="store_true")
ARG_VAR_IMPORT = Arg(("file",), help="Import variables from JSON file")
ARG_VAR_EXPORT = Arg(("file",), help="Export all variables to JSON file")
ARG_VAR_EXPORT = Arg(
("file",),
help="Export all variables to JSON file",
type=argparse.FileType("w", encoding="UTF-8"),
)

# kerberos
ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs="?")
Expand Down Expand Up @@ -1521,6 +1525,10 @@ class GroupCommand(NamedTuple):
ActionCommand(
name="export",
help="Export all variables",
description=(
"All variables can be exported in STDOUT using the following command:\n"
"airflow variables export -\n"
),
func=lazy_load_command("airflow.cli.commands.variable_command.variables_export"),
args=(ARG_VAR_EXPORT, ARG_VERBOSE),
),
Expand Down
54 changes: 24 additions & 30 deletions airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""Connection sub-commands."""
from __future__ import annotations

import io
import json
import os
import sys
Expand All @@ -30,6 +29,7 @@
from sqlalchemy.orm import exc

from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import is_stdout
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.exceptions import AirflowNotFoundException
Expand Down Expand Up @@ -138,10 +138,6 @@ def _format_connections(conns: list[Connection], file_format: str, serialization
return json.dumps(connections_dict)


def _is_stdout(fileio: io.TextIOWrapper) -> bool:
return fileio.name == "<stdout>"


def _valid_uri(uri: str) -> bool:
"""Check if a URI is valid, by checking if scheme (conn_type) provided."""
return urlsplit(uri).scheme != ""
Expand Down Expand Up @@ -171,32 +167,30 @@ def connections_export(args):
if args.format or args.file_format:
provided_file_format = f".{(args.format or args.file_format).lower()}"

file_is_stdout = _is_stdout(args.file)
if file_is_stdout:
filetype = provided_file_format or default_format
elif provided_file_format:
filetype = provided_file_format
else:
filetype = Path(args.file.name).suffix
filetype = filetype.lower()
if filetype not in file_formats:
raise SystemExit(
f"Unsupported file format. The file must have the extension {', '.join(file_formats)}."
)

if args.serialization_format and not filetype == ".env":
raise SystemExit("Option `--serialization-format` may only be used with file type `env`.")

with create_session() as session:
connections = session.scalars(select(Connection).order_by(Connection.conn_id)).all()

msg = _format_connections(
conns=connections,
file_format=filetype,
serialization_format=args.serialization_format or "uri",
)

with args.file as f:
if file_is_stdout := is_stdout(f):
filetype = provided_file_format or default_format
elif provided_file_format:
filetype = provided_file_format
else:
filetype = Path(args.file.name).suffix.lower()
if filetype not in file_formats:
raise SystemExit(
f"Unsupported file format. The file must have the extension {', '.join(file_formats)}."
)

if args.serialization_format and not filetype == ".env":
raise SystemExit("Option `--serialization-format` may only be used with file type `env`.")

with create_session() as session:
connections = session.scalars(select(Connection).order_by(Connection.conn_id)).all()

msg = _format_connections(
conns=connections,
file_format=filetype,
serialization_format=args.serialization_format or "uri",
)

f.write(msg)

if file_is_stdout:
Expand Down
61 changes: 26 additions & 35 deletions airflow/cli/commands/variable_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@

import json
import os
import sys
from json import JSONDecodeError

from sqlalchemy import select

from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import is_stdout
from airflow.models import Variable
from airflow.utils import cli as cli_utils
from airflow.utils.cli import suppress_logs_and_warning
Expand Down Expand Up @@ -76,44 +78,30 @@ def variables_delete(args):
@providers_configuration_loaded
def variables_import(args):
"""Imports variables from a given file."""
if os.path.exists(args.file):
_import_helper(args.file)
else:
if not os.path.exists(args.file):
raise SystemExit("Missing variables file.")
with open(args.file) as varfile:
try:
var_json = json.load(varfile)
except JSONDecodeError:
raise SystemExit("Invalid variables file.")
suc_count = fail_count = 0
for k, v in var_json.items():
try:
Variable.set(k, v, serialize_json=not isinstance(v, str))
except Exception as e:
print(f"Variable import failed: {repr(e)}")
fail_count += 1
else:
suc_count += 1
print(f"{suc_count} of {len(var_json)} variables successfully updated.")
if fail_count:
print(f"{fail_count} variable(s) failed to be updated.")


@providers_configuration_loaded
def variables_export(args):
"""Exports all the variables to the file."""
_variable_export_helper(args.file)


def _import_helper(filepath):
"""Helps import variables from the file."""
with open(filepath) as varfile:
data = varfile.read()

try:
var_json = json.loads(data)
except JSONDecodeError:
raise SystemExit("Invalid variables file.")
else:
suc_count = fail_count = 0
for k, v in var_json.items():
try:
Variable.set(k, v, serialize_json=not isinstance(v, str))
except Exception as e:
print(f"Variable import failed: {repr(e)}")
fail_count += 1
else:
suc_count += 1
print(f"{suc_count} of {len(var_json)} variables successfully updated.")
if fail_count:
print(f"{fail_count} variable(s) failed to be updated.")


def _variable_export_helper(filepath):
"""Helps export all the variables to the file."""
var_dict = {}
with create_session() as session:
qry = session.scalars(select(Variable))
Expand All @@ -126,6 +114,9 @@ def _variable_export_helper(filepath):
val = var.val
var_dict[var.key] = val

with open(filepath, "w") as varfile:
varfile.write(json.dumps(var_dict, sort_keys=True, indent=4))
print(f"{len(var_dict)} variables successfully exported to {filepath}")
with args.file as varfile:
json.dump(var_dict, varfile, sort_keys=True, indent=4)
if is_stdout(varfile):
print("\nVariables successfully exported.", file=sys.stderr)
else:
print(f"Variables successfully exported to {varfile.name}.")
33 changes: 33 additions & 0 deletions airflow/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import io
import sys


def is_stdout(fileio: io.IOBase) -> bool:
"""Check whether a file IO is stdout.
The intended use case for this helper is to check whether an argument parsed
with argparse.FileType points to stdout (by setting the path to ``-``). This
is why there is no equivalent for stderr; argparse does not allow using it.
.. warning:: *fileio* must be open for this check to be successful.
"""
return fileio.fileno() == sys.stdout.fileno()

0 comments on commit 09d478e

Please sign in to comment.