Improve user messaging for arrowload command #28

Merged 6 commits on Jun 16, 2023
Changes from 5 commits
3 changes: 2 additions & 1 deletion libraries/arrowload.ado
@@ -5,7 +5,8 @@ prog define arrowload
syntax anything(name=filename) ///
[, ///
Configfile(string) ///
Verbosity(integer 3) ///
]
if ( mi("`configfile'") ) local configfile none
python script /python_scripts/load_arrow.py, args(`filename' `configfile')
python script /python_scripts/load_arrow.py, args(`filename' `configfile' `verbosity')
end
26 changes: 20 additions & 6 deletions libraries/arrowload.sthlp
@@ -9,17 +9,26 @@ Syntax

arrowload filename [, options]

options                          Description
---------------------------------------------------------------------------------
Main
  configfile(path/to/csv/file)   A CSV file containing configuration for the import
  verbosity(#)                   An integer to control the level of output;
                                   default is verbosity(3)
---------------------------------------------------------------------------------


Description
-----------

arrowload loads an arrow file
arrowload loads an arrow file into memory as a stata dataset.

Columns in the arrow file are converted to the most appropriate stata variable type.

NOTE: Integer columns will be loaded as byte, int or long, depending on the range of
values in the arrow data. If the arrow data contains integers larger than 32 bits,
which cannot be represented as an integer type in stata, the column will be loaded as a
string column, and its values will be string representations of the integers.


Options
@@ -35,10 +44,15 @@ Options
long for use in stata. However, it is preferable to fix the input file to use valid
names. Some or all variable names can be mapped.

verbosity determines the level of output
- 1: prints progress messages only
- 2: prints progress and warning messages
- 3: prints all messages (the default)


Examples
--------

. arrowload "/workspace/dataset.arrow"

. arrowload "/workspace/dataset.arrow", configfile("/workspace/config.csv")
. arrowload /workspace/dataset.arrow
. arrowload /workspace/dataset.arrow, verbosity(1)
. arrowload /workspace/dataset.arrow, configfile(/workspace/config.csv)
. arrowload /workspace/dataset.arrow, configfile(/workspace/config.csv) verbosity(2)
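
As a rough illustration of the NOTE in the help text above, the sketch below maps an integer column's (min, max) range to a stata storage type. It is not the code from load_arrow.py (whose exact thresholds sit in a collapsed part of this diff); the function name stata_type_for_range is made up, and the ranges used are stata's documented limits for byte, int and long.

# Illustrative sketch only: choose a stata storage type from an integer
# column's (min, max) range, falling back to "string" when the values do
# not fit in long. Thresholds are stata's documented limits, not
# necessarily the exact constants used in load_arrow.py.
STATA_INT_RANGES = {
    "byte": (-127, 100),
    "int": (-32_767, 32_740),
    "long": (-2_147_483_647, 2_147_483_620),
}

def stata_type_for_range(min_val, max_val):
    for stata_type, (lo, hi) in STATA_INT_RANGES.items():
        if lo <= min_val and max_val <= hi:
            return stata_type
    # e.g. a 64-bit arrow integer column with values outside long range:
    # loaded as string representations instead
    return "string"

# stata_type_for_range(0, 50)              -> "byte"
# stata_type_for_range(0, 5_000_000_000)   -> "string"
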
114 changes: 89 additions & 25 deletions python_scripts/load_arrow.py
@@ -16,13 +16,34 @@
)


def main(filename, configfile=None):
converter = ArrowConverter(filename, configfile)
def main(filename, configfile=None, verbosity=3):
if not Path(filename).exists():
print_error_and_exit(f"Arrow file not found at {filename}", 601)
if sfi.Data.getVarCount() > 0:
# exit with the standard "dataset in memory has changed" message
print_error_and_exit(None, 4)
converter = ArrowConverter(filename, configfile, verbosity)
converter.load_data()


def print_error_and_exit(message, error_code=7103):
"""
A helper function to log an error message and then exit.
This lets us catch expected errors and print a user-friendly
message in the stata output, rather than surfacing an unexpected
python error.
If a relevant stata error code exists, we error with it here.
Otherwise, default to 7103, which is the generic stata return code
used when stata cannot execute a python script.
"""
if message:
sfi.SFIToolkit.errprintln(message)
sfi.SFIToolkit.error(error_code)


class ArrowConverter:
def __init__(self, filename, configfile):
def __init__(self, filename, configfile, verbosity):
self.set_verbosity(verbosity)
# Read config file if there is one, and identify aliases to use for column naming
config = self.read_config(configfile)
self.aliases = self.get_aliases_from_config(config)
@@ -53,6 +74,35 @@ def __init__(self, filename, configfile):
self.value_labels = None
self.arrow_batches = self.iter_arrow_batches(filename)

def set_verbosity(self, verbosity):
# verbosity can take one of 3 levels:
# 3 = print all messages
# 2 = print warnings and forced messages only
# 1 = print forced messages (e.g. batch progress) only
if verbosity < 1 or verbosity > 3:
self.verbosity = 3
self.display(f"Unknown verbosity level {verbosity}, defaulting to 3.")
else:
self.verbosity = verbosity

def display(self, message, is_warning=False, force=False):
if force or self.verbosity == 3 or (self.verbosity == 2 and is_warning):
sfi.SFIToolkit.displayln(message)

def warn(self, message):
self.display(message, is_warning=True)

def run_stata_command(self, command):
"""
Run a stata command. If not using the highest verbosity,
suppress the stata output
"""
if self.verbosity < 3:
prefix = "quietly: "
else:
prefix = ""
sfi.SFIToolkit.stata(f"{prefix}{command}")

def read_config(self, configfile):
"""
Read an optional config CSV file.
@@ -63,14 +113,14 @@ def read_config(self, configfile):
return config
config_path = Path(configfile)
if not config_path.exists():
print(f"WARNING: Config file not found at {configfile}")
self.warn(f"WARNING: Config file not found at {configfile}")
return config

with open(config_path) as config_csv:
reader = csv.DictReader(config_csv)
config = [row for row in reader]
if not config:
print(
self.warn(
f"WARNING: No data found in configfile {configfile}; does it contain headers?"
)
return config
@@ -83,15 +133,15 @@ def get_aliases_from_config(self, config):
expected_headers = {"original_column", "aliased_column"}
first_row = config[0]
if expected_headers - set(first_row.keys()):
print(
self.warn(
"WARNING: file does not contain expected column headers for aliases "
"(original_column, aliased_column)"
)
return aliases
aliases = {row["original_column"]: row["aliased_column"] for row in config}
too_long_aliases = any(len(value) > 32 for value in aliases.values())
if too_long_aliases:
raise ValueError(
print_error_and_exit(
"Config file contains aliases longer than the allowed length (32)"
)
return aliases
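
For reference, the configfile that read_config and get_aliases_from_config expect is a plain CSV with original_column and aliased_column headers. The snippet below is one hypothetical way to produce such a file; only the two header names come from the code above, while the path and column names are invented for illustration.

import csv

# Hypothetical aliases: map over-long arrow column names to names that fit
# stata's 32-character limit. Only the header names (original_column,
# aliased_column) are dictated by the code above; everything else is made up.
rows = [
    {"original_column": "a_very_long_descriptive_arrow_column_name", "aliased_column": "short_name"},
    {"original_column": "another_awkwardly_long_column_name", "aliased_column": "other_name"},
]
with open("/workspace/config.csv", "w", newline="") as config_csv:
    writer = csv.DictWriter(config_csv, fieldnames=["original_column", "aliased_column"])
    writer.writeheader()
    writer.writerows(rows)

The resulting file could then be passed via configfile(/workspace/config.csv), as in the help file examples.
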
@@ -121,9 +171,11 @@ def get_range_for_column(self, batch, column_name, column_type="int"):
col_range = compute.min_max(batch[column_name]).as_py()
return col_range["min"], col_range["max"]

assert (
column_type == "category"
), f"Can only get range for int or category types, got {column_type}"
if column_type != "category":
print_error_and_exit(
f"Attempted to determine value range for unexpected type: {column_type}. "
"Expected one of byte, int, long or category."
)
return (0, len(batch[column_name].dictionary))

def get_stata_type_from_range(self, batch, min_val, max_val, column_name):
@@ -149,7 +201,9 @@ def get_stata_type_from_range(self, batch, min_val, max_val, column_name):

if batch_type is None:
# range is too big for stata integer types
print(f"Column {column_name} is out of integer range; converting to string")
self.display(
f"Column '{column_name}' is out of integer range; converting to string"
)
batch = self.convert_int64_column_to_string(batch, column_name)
batch_type = "string"

@@ -234,14 +288,16 @@ def clean_names(self, batch):
self.aliases.get(varname, varname), prefix=False
)
if cleaned_name != varname:
print(f"{varname} aliased to {cleaned_name}")
self.display(f"'{varname}' aliased to '{cleaned_name}'")
self.column_names.append(cleaned_name)

if too_long_names:
raise ValueError(
f"Invalid variable names found ({','.join(too_long_names)})\n"
f"To fix this, rename variables in arrow file to <32 characters.\n"
f"Alternatively, a CSV file of original to alias names can be provided."
quoted_names = [f"'{name}'" for name in too_long_names]
print_error_and_exit(
f"Invalid variable names found ({','.join(quoted_names)})\n"
f"To fix this, rename variables in the arrow file to <32 characters.\n"
f"Alternatively, a CSV configfile containing of original to aliased names "
"can be provided."
)

# Generate a new batch with the cleaned/aliased column names.
@@ -393,7 +449,9 @@ def make_vars(self, pre_processed_column_types, batch):
elif vartype == "float":
sfi.Data.addVarFloat(varname)
else:
assert False, f"Unhandled type: {vartype}"
print_error_and_exit(
f"Unhandled type: {vartype} for column '{varname}'"
)
else:
# If any columns have changed required type when we process a
# subsequent batch, recast the existing stata column
@@ -404,18 +462,20 @@
if self.column_types[col] != pre_processed_column_types[col]
}
for changed_col, changed_type in changed_cols.items():
print(f"Converting {changed_cols}")
self.display(f"Converting {changed_cols}")
if changed_type == "string":
# If the variable has changed in a subsequent batch to string type, it
# means it was previously considered integer type and is now too big to
# fit into `long`. recast doesn't work in this case; we need to change it
# to a new string column, drop the old column and rename it.
sfi.SFIToolkit.stata(f"tostring {changed_col}, gen({changed_col}1)")
self.run_stata_command(
f"tostring {changed_col}, gen({changed_col}1)"
)
sfi.Data.dropVar(changed_col)
sfi.SFIToolkit.stata(f"rename {changed_col}1 {changed_col}")
self.run_stata_command(f"rename {changed_col}1 {changed_col}")
else:
cast_to = column_type_mappings.get(changed_type, changed_type)
sfi.SFIToolkit.stata(f"recast {cast_to} {changed_col}")
self.run_stata_command(f"recast {cast_to} {changed_col}")

def define_value_labels(self):
"""
@@ -469,11 +529,11 @@ def replace_stata_missing_and_recast(self):
if column_type not in ["boolean", "byte", "int", "long", "date"]:
continue
column_type = column_type_mappings.get(column_type, column_type)
print(f"Finalising column {column_name} (type ({column_type})")
sfi.SFIToolkit.stata(
self.display(f"Finalising column '{column_name}' (type ({column_type})")
self.run_stata_command(
f"replace {column_name} = . if {column_name} == {self.MISSING_VALUES[column_type]}"
)
sfi.SFIToolkit.stata(f"recast {column_type} {column_name}")
self.run_stata_command(f"recast {column_type} {column_name}")

def process_batch(self, batch):
"""
@@ -509,7 +569,8 @@
def load_data(self):
next_obs = 0
for batch, batch_num, total in self.arrow_batches:
print(f"Reading batch {batch_num} of {total}")
# Always report batch progress, whatever the requested verbosity
self.display(f"Reading batch {batch_num} of {total}", force=True)
batch, pre_processed_column_types = self.process_batch(batch)
# make variables
self.make_vars(pre_processed_column_types, batch)
@@ -546,6 +607,9 @@ def parse_args():
# can ignore it
if sys.argv[2] != "none":
args.update(configfile=sys.argv[2])
verbosity = int(sys.argv[3])
args.update(verbosity=verbosity)
return args
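
For context on the positional arguments above: arrowload.ado passes the filename, the configfile (or the sentinel none) and the verbosity through args(), and the script reads them back from sys.argv. A minimal sketch of that contract, with illustrative values:

# Sketch (not part of the diff): what sys.argv roughly looks like when the
# ado wrapper runs this script. The file path and values are illustrative.
example_argv = ["load_arrow.py", "/workspace/dataset.arrow", "none", "3"]

filename = example_argv[1]                                           # always present
configfile = None if example_argv[2] == "none" else example_argv[2]  # "none" means no config file
verbosity = int(example_argv[3])                                     # added by this PR; defaults to 3
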


1 change: 1 addition & 0 deletions tests/analysis/arrowload/arrowload-verbosity-default.do
@@ -0,0 +1 @@
. arrowload fixtures/data.arrow, configfile(non_existent_file) verbosity(4)
1 change: 1 addition & 0 deletions tests/analysis/arrowload/arrowload-verbosity-none.do
@@ -0,0 +1 @@
. arrowload fixtures/data.arrow, configfile(non_existent_file) verbosity(1)
1 change: 1 addition & 0 deletions tests/analysis/arrowload/arrowload-verbosity-warning.do
@@ -0,0 +1 @@
. arrowload fixtures/data.arrow, configfile(non_existent_file) verbosity(2)
2 changes: 2 additions & 0 deletions tests/analysis/arrowload/arrowload-with-existing-data.do
@@ -0,0 +1,2 @@
. generate foo=1
. arrowload fixtures/data.arrow
60 changes: 55 additions & 5 deletions tests/test_arrowload.py
@@ -243,7 +243,7 @@ def test_arrowload_multiple_batch():

def test_arrowload_too_long_variable_names():
"""
Variable names that are too long for stata variables raise an error
Variable names that are too long for stata variables print an error message and exit
"""
return_code, output, _ = run_stata("analysis/arrowload/arrowload-too-long.do")
assert return_code == 1
@@ -261,11 +261,12 @@ def test_arrowload_aliased_long_variable_names():

def test_arrowload_bad_aliases():
"""
Aliased variable names that are too long for stata variables raise an error
Aliased variable names that are too long for stata variables print an error
message and exit
"""
return_code, output, _ = run_stata("analysis/arrowload/arrowload-bad-aliases.do")
assert return_code == 1
assert "aliases longer than the allowed length" in output
assert "Config file contains aliases longer than the allowed length" in output


def test_arrowload_config_file_not_found():
@@ -314,6 +315,55 @@ def test_arrowload_aliases_with_multiple_batches():
return_code, output, _ = run_stata(
"analysis/arrowload/arrowload-batches-aliased.do"
)
assert "i3a aliased to aliased_i3a" in output
assert "s1 aliased to aliased_s1" in output
assert "'i3a' aliased to 'aliased_i3a'" in output
assert "'s1' aliased to 'aliased_s1'" in output
assert return_code == 0


def test_arrowload_data_exists():
"""
Test that loading an arrow file when data already exists returns
the expected error message
"""
return_code, output, _ = run_stata(
"analysis/arrowload/arrowload-with-existing-data.do"
)
assert return_code == 1
assert "no; dataset in memory has changed since last saved" in output


def test_arrowload_verbosity():
# All .do files in this test attempt to load valid data with a
# non-existent configfile, which outputs a warning if the
# verbosity level allows it

# the progress message is always shown
progress_message = "Reading batch 1 of 1"
# warnings are shown with verbosity level 2 or 3
warning_message = "WARNING: Config file not found"
# other output messages are only shown with verbosity level 3
info_message = "Finalising column"

# Verbosity level 2
return_code, output, _ = run_stata(
"analysis/arrowload/arrowload-verbosity-warning.do"
)
assert return_code == 0
for message in [progress_message, warning_message]:
assert message in output
assert info_message not in output

# Verbosity level 1
return_code, output, _ = run_stata("analysis/arrowload/arrowload-verbosity-none.do")
assert return_code == 0
assert progress_message in output
for message in [warning_message, info_message]:
assert message not in output

# Verbosity level 4 (invalid, defaults to 3)
return_code, output, _ = run_stata(
"analysis/arrowload/arrowload-verbosity-default.do"
)
assert return_code == 0
for message in [progress_message, warning_message, info_message]:
assert message in output