pytest tmp_dir fixture #706
First changed file: the MovieLens dataset loader (load_pandas_df, load_spark_df, and the _load_item_df helper).
@@ -136,12 +136,12 @@ def item_has_header(self):

 # Warning and error messages
 WARNING_MOVIE_LENS_HEADER = """MovieLens rating dataset has four columns
-(user id, movie id, rating, and timestamp), but more than four column headers are provided.
-Will only use the first four column headers."""
+(user id, movie id, rating, and timestamp), but more than four column names are provided.
+Will only use the first four column names."""
 WARNING_HAVE_SCHEMA_AND_HEADER = """Both schema and header are provided.
 The header argument will be ignored."""
 ERROR_MOVIE_LENS_SIZE = "Invalid data size. Should be one of {100k, 1m, 10m, or 20m}"
-ERROR_NO_HEADER = "No header (schema) information"
+ERROR_NO_HEADER = "No header (schema) information. At least user and movie column names should be provided"


 def load_pandas_df(
@@ -187,13 +187,14 @@ def load_pandas_df(
     size = size.lower()
     if size not in DATA_FORMAT:
         raise ValueError(ERROR_MOVIE_LENS_SIZE)
-    if header is None or len(header) == 0:
-        raise ValueError(ERROR_NO_HEADER)
-
-    if len(header) > 4:
+    if header is None or len(header) < 2:
+        raise ValueError(ERROR_NO_HEADER)
+    elif len(header) > 4:
         warnings.warn(WARNING_MOVIE_LENS_HEADER)
         header = header[:4]
-    movie_col = DEFAULT_ITEM_COL if len(header) < 2 else header[1]
+
+    movie_col = header[1]

     with download_path(local_cache_path) as path:
         filepath = os.path.join(path, "ml-{}.zip".format(size))
@@ -205,10 +206,6 @@ def load_pandas_df(
         )

-        # Load rating data
-        if len(header) == 1 and item_df is not None:
-            # MovieID should be loaded to merge rating df w/ item_df
-            header = [header[0], movie_col]
-
         df = pd.read_csv(
             datapath,
             sep=DATA_FORMAT[size].separator,
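To make the revised validation concrete, here is a minimal, self-contained sketch of the new header handling in load_pandas_df. The constant strings are paraphrased stand-ins for the module-level messages above, and validate_header is a hypothetical helper for illustration, not code from this PR:

import warnings

# Illustrative stand-ins for the module-level constants in the diff above.
ERROR_NO_HEADER = "No header (schema) information. At least user and movie column names should be provided"
WARNING_MOVIE_LENS_HEADER = "More than four column names are provided. Will only use the first four."


def validate_header(header):
    """Apply the same checks as the revised load_pandas_df: require at least
    the user and movie columns, and trim anything beyond four columns."""
    if header is None or len(header) < 2:
        raise ValueError(ERROR_NO_HEADER)
    elif len(header) > 4:
        warnings.warn(WARNING_MOVIE_LENS_HEADER)
        header = header[:4]
    # The movie/item column is now guaranteed to exist, so no DEFAULT_ITEM_COL fallback is needed.
    movie_col = header[1]
    return header, movie_col


# A five-column header triggers the warning and is trimmed to four.
print(validate_header(["UserId", "MovieId", "Rating", "Timestamp", "Extra"]))

This also explains why the "# Load rating data" special case above could be deleted: after validation the header always contains the movie column, so there is nothing to add before the merge with item_df.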
@@ -268,11 +265,11 @@ def load_item_df(
 def _load_item_df(size, item_datapath, movie_col, title_col, genres_col, year_col):
     """Loads Movie info"""
-    item_header = []
-    usecols = []
-    if movie_col is not None:
-        item_header.append(movie_col)
-        usecols.append(0)
+    if title_col is None and genres_col is None and year_col is None:
+        return None
+
+    item_header = [movie_col]
+    usecols = [0]

     # Year is parsed from title
     if title_col is not None or year_col is not None:

Review thread on the new early-return check:
Comment: Looks like the check now is different from before? The previous checks the …
Reply: yeah, made …
@@ -291,9 +288,6 @@ def _load_item_df(size, item_datapath, movie_col, title_col, genres_col, year_col):
         item_header.append(genres_col)
         usecols.append(2)  # genres column

-    if len(item_header) == 0:
-        return None
-
     item_df = pd.read_csv(
         item_datapath,
         sep=DATA_FORMAT[size].item_separator,
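A rough sketch of the new control flow in _load_item_df, using a hypothetical helper name, which also shows why the trailing len(item_header) == 0 guard removed in the hunk above became unnecessary:

def _select_item_columns(movie_col, title_col, genres_col, year_col):
    """Mirror of the new early-return logic: if no item feature column is
    requested, there is nothing to load, so bail out before touching the file."""
    if title_col is None and genres_col is None and year_col is None:
        return None
    # movie_col is always included, so item_header can never end up empty,
    # which is why the old `if len(item_header) == 0: return None` check was dropped.
    item_header, usecols = [movie_col], [0]
    if title_col is not None or year_col is not None:
        # Year is parsed from the title, so a title column is needed either way;
        # "title" is just an illustrative placeholder name here.
        item_header.append(title_col or "title")
        usecols.append(1)
    if genres_col is not None:
        item_header.append(genres_col)
        usecols.append(2)
    return item_header, usecols


print(_select_item_columns("MovieId", None, None, None))         # None
print(_select_item_columns("MovieId", "Title", "Genres", None))  # (['MovieId', 'Title', 'Genres'], [0, 1, 2])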
@@ -390,17 +384,17 @@ def load_spark_df(
     ... )

     On DataBricks, pass the dbutils argument as follows:
-    >>> spark_df = load_spark_df(spark, ..., dbutils=dbutils)
+    >>> spark_df = load_spark_df(spark, dbutils=dbutils)
     """
     size = size.lower()
     if size not in DATA_FORMAT:
         raise ValueError(ERROR_MOVIE_LENS_SIZE)

     schema = _get_schema(header, schema)
-    if schema is None:
+    if schema is None or len(schema) < 2:
         raise ValueError(ERROR_NO_HEADER)

-    movie_col = DEFAULT_ITEM_COL if len(schema) < 2 else schema[1].name
+    movie_col = schema[1].name

     with download_path(local_cache_path) as path:
         filepath = os.path.join(path, "ml-{}.zip".format(size))
@@ -410,11 +404,8 @@ def load_spark_df(
         # Load movie features such as title, genres, and release year.
         # Since the file size is small, we directly load as pd.DataFrame from the driver node
         # and then convert into spark.DataFrame
-        item_df = spark.createDataFrame(
-            _load_item_df(
-                size, item_datapath, movie_col, title_col, genres_col, year_col
-            )
-        )
+        item_pd_df = _load_item_df(size, item_datapath, movie_col, title_col, genres_col, year_col)
+        item_df = spark.createDataFrame(item_pd_df) if item_pd_df is not None else None

         if is_databricks():
             if dbutils is None:
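The point of splitting the call is that _load_item_df may now return None, and spark.createDataFrame would fail on that. A minimal sketch of the guard, assuming a local SparkSession and a hypothetical load_items stand-in rather than the actual module code:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("sketch").getOrCreate()


def load_items(want_items):
    # Hypothetical stand-in for _load_item_df: returns None when no item
    # feature columns were requested, otherwise a small pandas DataFrame.
    if not want_items:
        return None
    return pd.DataFrame({"MovieId": [1, 2], "Title": ["A", "B"]})


item_pd_df = load_items(want_items=True)
# The guard mirrors the change in this hunk: only convert to a Spark
# DataFrame when there is actually something to convert.
item_df = spark.createDataFrame(item_pd_df) if item_pd_df is not None else None
print(item_df)  # None when want_items=False, a Spark DataFrame otherwise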
@@ -430,11 +421,6 @@ def load_spark_df(
             dbutils.fs.mv(spark_datapath, dbfs_datapath)
             spark_datapath = dbfs_datapath

-        # Load rating data
-        if len(schema) == 1 and item_df is not None:
-            # MovieID should be loaded to merge rating df w/ item_df
-            schema.add(StructField(movie_col, IntegerType()))
-
         # pySpark's read csv currently doesn't support multi-character delimiter, thus we manually handle that
         separator = DATA_FORMAT[size].separator
         if len(separator) > 1:
Second changed file: the shared pytest conftest.py that provides the test fixtures.
@@ -4,16 +4,17 @@
 # NOTE: This file is used by pytest to inject fixtures automatically. As it is explained in the documentation
 # https://docs.pytest.org/en/latest/fixture.html:
 # "If during implementing your tests you realize that you want to use a fixture function from multiple test files
-# you can move it to a conftest.py file. You don't need to import the fixture you want to use in a test, it
-# automatically gets discovered by pytest."
+# you can move it to a conftest.py file. You don't need to import the module where your fixtures are defined to use them in a test;
+# it automatically gets discovered by pytest and thus you can simply receive fixture objects by naming them as
+# an input argument in the test."

 import calendar
 import datetime
 import os
 import numpy as np
 import pandas as pd
 import pytest
 from sklearn.model_selection import train_test_split
+from tempfile import TemporaryDirectory
 from tests.notebooks_common import path_notebooks
 from reco_utils.common.general_utils import get_number_processors, get_physical_memory
@@ -23,6 +24,12 @@
     pass  # so the environment without spark doesn't break


+@pytest.fixture
+def tmp(tmp_path_factory):
+    with TemporaryDirectory(dir=tmp_path_factory.getbasetemp()) as td:
+        yield td
+
+
 @pytest.fixture(scope="session")
 def spark(app_name="Sample", url="local[*]"):
     """Start Spark if not started.

Review thread on @pytest.fixture:
Comment: I also like the context manager option; maybe it can be implemented as we are discussing here: #701 (comment)
Reply: Let me try that out, it looks promising.
Reply: Checked, and it turns out a normal fixture does the same thing since newer versions of pytest.
Reply: Good, so if we use the fixture as you programmed it, using …
Reply: Yes. Even try/finally removes the dir too, but using the context manager made the code simpler and gives us peace of mind :-)

Review thread on def tmp(tmp_path_factory):
Comment: This looks nice and clean ❤️
Reply: Kudos to @miguelgfierro @anargyri!
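With the fixture in place, a test only needs to name tmp as an argument to receive a temporary directory that is created under pytest's base temp folder and cleaned up afterwards. A minimal usage sketch (the test itself is hypothetical, not part of this PR):

import os


def test_writes_to_tmp(tmp):
    # `tmp` is the path of the TemporaryDirectory yielded by the fixture;
    # it is removed automatically when the test finishes.
    target = os.path.join(tmp, "output.txt")
    with open(target, "w") as f:
        f.write("hello")
    assert os.path.exists(target)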
Review comment: Really like these constant strings for the error/warning messages. Do we want to later consolidate them somewhere like common/constants? Or, at least, adopt this practice in other utility code.
Reply: If we need to, sure, why not. We could even generalize further, e.g. use obj.__name__ or accept names as args.
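Purely as an illustration of the generalization mentioned in the reply (not part of this PR), message templates could live in a shared constants module and be filled in with obj.__name__ or explicit names:

# Hypothetical shared message templates, e.g. in a common constants module.
ERROR_INVALID_SIZE = "Invalid data size for {name}. Should be one of {sizes}"


def invalid_size_message(obj, sizes):
    # Using obj.__name__ as suggested in the review reply.
    return ERROR_INVALID_SIZE.format(name=obj.__name__, sizes=sizes)


class MovieLens:
    pass


print(invalid_size_message(MovieLens, "{100k, 1m, 10m, 20m}"))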