Skip to content

Commit

Permalink
feat: Add geospatial post processing operations (#9661)
Browse files Browse the repository at this point in the history
* feat: Add geospatial post processing operations

* Linting

* Refactor

* Add tests

* Improve docs

* Address comments

* fix latitude/longitude mixup

* fix: bad refactor by pycharm
  • Loading branch information
villebro authored Apr 28, 2020
1 parent c474ea8 commit a52cfcd
Show file tree
Hide file tree
Showing 5 changed files with 322 additions and 21 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ flask-talisman==0.7.0 # via apache-superset (setup.py)
flask-wtf==0.14.2 # via apache-superset (setup.py), flask-appbuilder
flask==1.1.1 # via apache-superset (setup.py), flask-appbuilder, flask-babel, flask-caching, flask-compress, flask-jwt-extended, flask-login, flask-migrate, flask-openid, flask-sqlalchemy, flask-wtf
geographiclib==1.50 # via geopy
geopy==1.20.0 # via apache-superset (setup.py)
geopy==1.21.0 # via apache-superset (setup.py)
gunicorn==20.0.4 # via apache-superset (setup.py)
humanize==0.5.1 # via apache-superset (setup.py)
importlib-metadata==1.4.0 # via jsonschema, kombu
Expand Down
84 changes: 82 additions & 2 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,23 @@ class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema
columns = fields.List(
fields.String(),
description="Columns which to select from the input data, in the desired "
"order. If columns are renamed, the old column name should be "
"order. If columns are renamed, the original column name should be "
"referenced here.",
example=["country", "gender", "age"],
required=False,
)
exclude = fields.List(
fields.String(),
description="Columns to exclude from selection.",
example=["my_temp_column"],
required=False,
)
rename = fields.List(
fields.Dict(),
description="columns which to rename, mapping source column to target column. "
"For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
example=[{"age": "average_age"}],
required=False,
)


Expand Down Expand Up @@ -335,12 +343,81 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
aggregates = ChartDataAggregateConfigField()


class ChartDataGeohashDecodeOptionsSchema(
    ChartDataPostProcessingOperationOptionsSchema
):
    """Options for the geohash decode post-processing operation."""

    # Source column holding the geohash strings to be decoded.
    geohash = fields.String(
        required=True,
        description="Name of source column containing geohash string",
    )
    # Target column that will receive the decoded latitude values.
    latitude = fields.String(
        required=True,
        description="Name of target column for decoded latitude",
    )
    # Target column that will receive the decoded longitude values.
    longitude = fields.String(
        required=True,
        description="Name of target column for decoded longitude",
    )


class ChartDataGeohashEncodeOptionsSchema(
    ChartDataPostProcessingOperationOptionsSchema
):
    """Options for the geohash encode post-processing operation."""

    # Source column holding latitude values to be encoded.
    latitude = fields.String(
        required=True,
        description="Name of source latitude column",
    )
    # Source column holding longitude values to be encoded.
    longitude = fields.String(
        required=True,
        description="Name of source longitude column",
    )
    # Target column that will receive the encoded geohash strings.
    geohash = fields.String(
        required=True,
        description="Name of target column for encoded geohash string",
    )


class ChartDataGeodeticParseOptionsSchema(
    ChartDataPostProcessingOperationOptionsSchema
):
    """Options for the geodetic point string parsing operation."""

    # Source column holding geodetic point strings (see geopy.point.Point).
    geodetic = fields.String(
        required=True,
        description="Name of source column containing geodetic point strings",
    )
    # Target column that will receive the parsed latitude values.
    latitude = fields.String(
        required=True,
        description="Name of target column for decoded latitude",
    )
    # Target column that will receive the parsed longitude values.
    longitude = fields.String(
        required=True,
        description="Name of target column for decoded longitude",
    )
    # Optional target column for altitude; when omitted, any altitude
    # component in the geodetic string is discarded.
    altitude = fields.String(
        required=False,
        description="Name of target column for decoded altitude. If omitted, "
        "altitude information in geodetic string is ignored.",
    )


class ChartDataPostProcessingOperationSchema(Schema):
operation = fields.String(
description="Post processing operation type",
required=True,
validate=validate.OneOf(
choices=("aggregate", "pivot", "rolling", "select", "sort")
choices=(
"aggregate",
"geodetic_parse",
"geohash_decode",
"geohash_encode",
"pivot",
"rolling",
"select",
"sort",
)
),
example="aggregate",
)
Expand Down Expand Up @@ -638,4 +715,7 @@ class ChartDataResponseSchema(Schema):
ChartDataRollingOptionsSchema,
ChartDataSelectOptionsSchema,
ChartDataSortOptionsSchema,
ChartDataGeohashDecodeOptionsSchema,
ChartDataGeohashEncodeOptionsSchema,
ChartDataGeodeticParseOptionsSchema,
)
121 changes: 110 additions & 11 deletions superset/utils/pandas_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import geohash as geohash_lib
import numpy as np
from flask_babel import gettext as _
from geopy.point import Point
from pandas import DataFrame, NamedAgg

from superset.exceptions import QueryObjectValidationError
Expand Down Expand Up @@ -144,10 +146,7 @@ def _append_columns(
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
return base_df.assign(
**{
target: append_df[append_df.columns[idx]]
for idx, target in enumerate(columns.values())
}
**{target: append_df[source] for source, target in columns.items()}
)


Expand Down Expand Up @@ -323,25 +322,34 @@ def rolling( # pylint: disable=too-many-arguments
return df


@validate_column_args("columns", "exclude", "rename")
def select(
    df: DataFrame,
    columns: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
    rename: Optional[Dict[str, str]] = None,
) -> DataFrame:
    """
    Only select a subset of columns in the original dataset. Can be useful for
    removing unnecessary intermediate results, renaming and reordering columns.

    :param df: DataFrame on which the selection will be based.
    :param columns: Columns which to select from the DataFrame, in the desired order.
                    If left undefined, all columns will be selected. If columns are
                    renamed, the original column name should be referenced here.
    :param exclude: columns to exclude from selection. If columns are renamed, the new
                    column name should be referenced here.
    :param rename: columns which to rename, mapping source column to target column.
                   For instance, `{'y': 'y2'}` will rename the column `y` to
                   `y2`.
    :return: Subset of columns in original DataFrame
    :raises QueryObjectValidationError: If the request is incorrect
    """
    # NOTE: the decorator previously validated a nonexistent "drop" argument
    # instead of "exclude", so excluded columns were never validated.
    # Shallow copy so the operations below never mutate the caller's frame.
    df_select = df.copy(deep=False)
    if columns:
        df_select = df_select[columns]
    if exclude:
        df_select = df_select.drop(exclude, axis=1)
    if rename is not None:
        df_select = df_select.rename(columns=rename)
    return df_select
Expand All @@ -350,6 +358,7 @@ def select(
@validate_column_args("columns")
def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
"""
Calculate row-by-row difference for select columns.
:param df: DataFrame on which the diff will be based.
:param columns: columns on which to perform diff, mapping source column to
Expand All @@ -369,6 +378,7 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame
@validate_column_args("columns")
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.
:param df: DataFrame on which the cumulative operation will be based.
:param columns: columns on which to perform a cumulative operation, mapping source
Expand All @@ -377,7 +387,7 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
`y2` based on cumulative values calculated from `y`, leaving the original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
:return:
:return: DataFrame with cumulated columns
"""
df_cum = df[columns.keys()]
operation = "cum" + operator
Expand All @@ -388,3 +398,92 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
return _append_columns(df, getattr(df_cum, operation)(), columns)


def geohash_decode(
    df: DataFrame, geohash: str, longitude: str, latitude: str
) -> DataFrame:
    """
    Decode a geohash column into longitude and latitude columns.

    :param df: DataFrame containing geohash data
    :param geohash: Name of source column containing geohash location.
    :param longitude: Name of new column to be created containing longitude.
    :param latitude: Name of new column to be created containing latitude.
    :return: DataFrame with decoded longitudes and latitudes
    :raises QueryObjectValidationError: If a geohash string cannot be decoded.
    """
    try:
        lonlat_df = DataFrame()
        # geohash_lib.decode yields (lat, lon) pairs; split them into columns.
        lonlat_df["latitude"], lonlat_df["longitude"] = zip(
            *df[geohash].apply(geohash_lib.decode)
        )
        return _append_columns(
            df, lonlat_df, {"latitude": latitude, "longitude": longitude}
        )
    except ValueError as ex:
        # Chain the original error so the root cause stays in the traceback.
        raise QueryObjectValidationError(_("Invalid geohash string")) from ex


def geohash_encode(
    df: DataFrame, geohash: str, longitude: str, latitude: str,
) -> DataFrame:
    """
    Encode longitude and latitude columns into a geohash column.

    :param df: DataFrame containing longitude and latitude data
    :param geohash: Name of new column to be created containing geohash location.
    :param longitude: Name of source column containing longitude.
    :param latitude: Name of source column containing latitude.
    :return: DataFrame with encoded geohashes
    :raises QueryObjectValidationError: If latitude/longitude values are invalid.
    """
    try:
        # Copy the slice so adding the geohash column below never writes into
        # (a view of) the caller's DataFrame (SettingWithCopy hazard).
        encode_df = df[[latitude, longitude]].copy()
        encode_df.columns = ["latitude", "longitude"]
        encode_df["geohash"] = encode_df.apply(
            lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), axis=1,
        )
        return _append_columns(df, encode_df, {"geohash": geohash})
    except ValueError as ex:
        # Bug fix: the exception was previously constructed but never raised,
        # making the function silently return None on invalid input.
        raise QueryObjectValidationError(_("Invalid longitude/latitude")) from ex


def geodetic_parse(
    df: DataFrame,
    geodetic: str,
    longitude: str,
    latitude: str,
    altitude: Optional[str] = None,
) -> DataFrame:
    """
    Parse a column containing a geodetic point string
    [Geopy](https://geopy.readthedocs.io/en/stable/#geopy.point.Point).

    :param df: DataFrame containing geodetic point data
    :param geodetic: Name of source column containing geodetic point string.
    :param longitude: Name of new column to be created containing longitude.
    :param latitude: Name of new column to be created containing latitude.
    :param altitude: Name of new column to be created containing altitude. If
                     omitted, altitude information in the geodetic string is ignored.
    :return: DataFrame with decoded longitudes and latitudes
    :raises QueryObjectValidationError: If a geodetic string cannot be parsed.
    """

    def _parse_location(location: str) -> Tuple[float, float, float]:
        """
        Parse a string containing a geodetic point and return latitude, longitude
        and altitude
        """
        point = Point(location)  # type: ignore
        return point[0], point[1], point[2]

    try:
        geodetic_df = DataFrame()
        # Point yields (lat, lon, alt); always unpack all three columns and
        # only append altitude to the result if a target column was requested.
        (
            geodetic_df["latitude"],
            geodetic_df["longitude"],
            geodetic_df["altitude"],
        ) = zip(*df[geodetic].apply(_parse_location))
        columns = {"latitude": latitude, "longitude": longitude}
        if altitude:
            columns["altitude"] = altitude
        return _append_columns(df, geodetic_df, columns)
    except ValueError as ex:
        # Chain the original error so the root cause stays in the traceback.
        raise QueryObjectValidationError(_("Invalid geodetic string")) from ex
14 changes: 14 additions & 0 deletions tests/fixtures/dataframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,17 @@
index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]),
data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]},
)

# Fixture with two cities whose coordinates are expressed three equivalent
# ways: as a geohash, as numeric latitude/longitude/altitude columns, and as
# a geodetic point string.
_lonlat_records = [
    (
        "New York City",
        "dr5regw3pg6f",
        40.71277496,
        -74.00597306,
        5.5,
        "40.71277496, -74.00597306, 5.5km",
    ),
    (
        "Sydney",
        "r3gx2u9qdevk",
        -33.85598011,
        151.20666526,
        0.012,
        "-33.85598011, 151.20666526, 12m",
    ),
]
lonlat_df = DataFrame(
    _lonlat_records,
    columns=["city", "geohash", "latitude", "longitude", "altitude", "geodetic"],
)
Loading

0 comments on commit a52cfcd

Please sign in to comment.