Skip to content

Commit

Permalink
Fix type hints after adding Branca type checking (python-visualizatio…
Browse files Browse the repository at this point in the history
…n#2060)

* remove render() return types

* fix TypeBounds

* missing return statement

* split TypeBounds in input and return types

* deal with bounds from args to return

* fix VegaLite typing

* geojsondetail assert parent is geojson

* bin_edges in choropleth

* geojson/topojson in choropleth

* colormap type in ColorLine

* ruff check

* black

* fix circular import
  • Loading branch information
Conengmo authored Dec 29, 2024
1 parent 09c5905 commit ccb7509
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 45 deletions.
2 changes: 1 addition & 1 deletion folium/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class JSCSSMixin(Element):
default_js: List[Tuple[str, str]] = []
default_css: List[Tuple[str, str]] = []

def render(self, **kwargs) -> None:
def render(self, **kwargs):
figure = self.get_root()
assert isinstance(
figure, Figure
Expand Down
78 changes: 50 additions & 28 deletions folium/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,24 @@
import numpy as np
import requests
from branca.colormap import ColorMap, LinearColormap, StepColormap
from branca.element import Element, Figure, Html, IFrame, JavascriptLink, MacroElement
from branca.element import (
Div,
Element,
Figure,
Html,
IFrame,
JavascriptLink,
MacroElement,
)
from branca.utilities import color_brewer

from folium.elements import JSCSSMixin
from folium.folium import Map
from folium.map import FeatureGroup, Icon, Layer, Marker, Popup, Tooltip
from folium.template import Template
from folium.utilities import (
TypeBoundsReturn,
TypeContainer,
TypeJsonValue,
TypeLine,
TypePathOptions,
Expand Down Expand Up @@ -165,7 +175,7 @@ def __init__(
self.top = _parse_size(top)
self.position = position

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
super().render(**kwargs)

Expand Down Expand Up @@ -284,9 +294,15 @@ def __init__(
self.top = _parse_size(top)
self.position = position

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self._parent.html.add_child(
parent = self._parent
if not isinstance(parent, (Figure, Div, Popup)):
raise TypeError(
"VegaLite elements can only be added to a Figure, Div, or Popup"
)

parent.html.add_child(
Element(
Template(
"""
Expand Down Expand Up @@ -331,7 +347,7 @@ def render(self, **kwargs) -> None:
embed_vegalite = embed_mapping.get(
self.vegalite_major_version, self._embed_vegalite_v2
)
embed_vegalite(figure)
embed_vegalite(figure=figure, parent=parent)

@property
def vegalite_major_version(self) -> Optional[int]:
Expand All @@ -342,8 +358,8 @@ def vegalite_major_version(self) -> Optional[int]:

return int(schema.split("/")[-1].split(".")[0].lstrip("v"))

def _embed_vegalite_v5(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v5(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega"
Expand All @@ -356,8 +372,8 @@ def _embed_vegalite_v5(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v4(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v4(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm//vega@5"), name="vega"
Expand All @@ -370,8 +386,8 @@ def _embed_vegalite_v4(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v3(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v3(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm/vega@4"), name="vega"
Expand All @@ -384,8 +400,8 @@ def _embed_vegalite_v3(self, figure: Figure) -> None:
name="vega-embed",
)

def _embed_vegalite_v2(self, figure: Figure) -> None:
self._vega_embed()
def _embed_vegalite_v2(self, figure: Figure, parent: TypeContainer) -> None:
self._vega_embed(parent=parent)

figure.header.add_child(
JavascriptLink("https://cdn.jsdelivr.net/npm/vega@3"), name="vega"
Expand All @@ -398,8 +414,8 @@ def _embed_vegalite_v2(self, figure: Figure) -> None:
name="vega-embed",
)

def _vega_embed(self) -> None:
self._parent.script.add_child(
def _vega_embed(self, parent: TypeContainer) -> None:
parent.script.add_child(
Element(
Template(
"""
Expand All @@ -412,8 +428,8 @@ def _vega_embed(self) -> None:
name=self.get_name(),
)

def _embed_vegalite_v1(self, figure: Figure) -> None:
self._parent.script.add_child(
def _embed_vegalite_v1(self, figure: Figure, parent: TypeContainer) -> None:
parent.script.add_child(
Element(
Template(
"""
Expand All @@ -436,19 +452,19 @@ def _embed_vegalite_v1(self, figure: Figure) -> None:
figure.header.add_child(
JavascriptLink("https://cdnjs.cloudflare.com/ajax/libs/vega/2.6.5/vega.js"),
name="vega",
) # noqa
)
figure.header.add_child(
JavascriptLink(
"https://cdnjs.cloudflare.com/ajax/libs/vega-lite/1.3.1/vega-lite.js"
),
name="vega-lite",
) # noqa
)
figure.header.add_child(
JavascriptLink(
"https://cdnjs.cloudflare.com/ajax/libs/vega-embed/2.2.0/vega-embed.js"
),
name="vega-embed",
) # noqa
)


class GeoJson(Layer):
Expand Down Expand Up @@ -820,7 +836,7 @@ def _get_self_bounds(self) -> List[List[Optional[float]]]:
"""
return get_bounds(self.data, lonlat=True)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
self.parent_map = get_obj_in_upper_tree(self, Map)
# Need at least one feature, otherwise style mapping fails
if (self.style or self.highlight) and self.data["features"]:
Expand Down Expand Up @@ -1041,12 +1057,12 @@ def recursive_get(data, keys):
self.style_function(feature)
) # noqa

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self.style_data()
super().render(**kwargs)

def get_bounds(self) -> List[List[float]]:
def get_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]]
Expand Down Expand Up @@ -1146,6 +1162,7 @@ def __init__(

def warn_for_geometry_collections(self) -> None:
"""Checks for GeoJson GeometryCollection features to warn user about incompatibility."""
assert isinstance(self._parent, GeoJson)
geom_collections = [
feature.get("properties") if feature.get("properties") is not None else key
for key, feature in enumerate(self._parent.data["features"])
Expand All @@ -1160,7 +1177,7 @@ def warn_for_geometry_collections(self) -> None:
UserWarning,
)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
figure = self.get_root()
if isinstance(self._parent, GeoJson):
Expand Down Expand Up @@ -1565,7 +1582,7 @@ def __init__(
color_range = color_brewer(fill_color, n=nb_bins)
self.color_scale = StepColormap(
color_range,
index=bin_edges,
index=list(bin_edges),
vmin=bins_min,
vmax=bins_max,
caption=legend_name,
Expand Down Expand Up @@ -1625,7 +1642,7 @@ def highlight_function(x):
return {"weight": line_weight + 2, "fillOpacity": fill_opacity + 0.2}

if topojson:
self.geojson = TopoJson(
self.geojson: Union[TopoJson, GeoJson] = TopoJson(
geo_data,
topojson,
style_function=style_function,
Expand Down Expand Up @@ -1657,7 +1674,7 @@ def _get_by_key(cls, obj: Union[dict, list], key: str) -> Union[float, str, None
else:
return value

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Render the GeoJson/TopoJson and color scale objects."""
if self.color_scale:
# ColorMap needs Map as its parent
Expand Down Expand Up @@ -1963,8 +1980,13 @@ def __init__(
vmin=min(colors),
vmax=max(colors),
).to_step(nb_steps)
else:
elif isinstance(colormap, StepColormap):
cm = colormap
else:
raise TypeError(
f"Unexpected type for argument `colormap`: {type(colormap)}"
)

out: Dict[str, List[List[List[float]]]] = {}
for (lat1, lng1), (lat2, lng2), color in zip(coords[:-1], coords[1:], colors):
out.setdefault(cm(color), []).append([[lat1, lng1], [lat2, lng2]])
Expand Down
2 changes: 1 addition & 1 deletion folium/folium.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _repr_png_(self) -> Optional[bytes]:
return None
return self._to_png()

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
figure = self.get_root()
assert isinstance(
Expand Down
13 changes: 7 additions & 6 deletions folium/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Optional, Sequence, Union, cast

from branca.element import Element, Figure, Html, MacroElement

Expand All @@ -14,6 +14,7 @@
from folium.utilities import (
JsCode,
TypeBounds,
TypeBoundsReturn,
TypeJsonValue,
escape_backticks,
parse_options,
Expand Down Expand Up @@ -221,7 +222,7 @@ def reset(self) -> None:
self.base_layers = OrderedDict()
self.overlays = OrderedDict()

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
self.reset()
for item in self._parent._children.values():
Expand Down Expand Up @@ -396,15 +397,15 @@ def __init__(
tooltip if isinstance(tooltip, Tooltip) else Tooltip(str(tooltip))
)

def _get_self_bounds(self) -> List[List[float]]:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""Computes the bounds of the object itself.
Because a marker has only single coordinates, we repeat them.
"""
assert self.location is not None
return [self.location, self.location]
return cast(TypeBoundsReturn, [self.location, self.location])

def render(self) -> None:
def render(self):
if self.location is None:
raise ValueError(
f"{self._name} location must be assigned when added directly to map."
Expand Down Expand Up @@ -492,7 +493,7 @@ def __init__(
**kwargs,
)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
"""Renders the HTML representation of the element."""
for name, child in self._children.items():
child.render(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion folium/plugins/overlapping_marker_spiderfier.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def add_to(
) -> Element:
self._parent = parent
self.markers = self._get_all_markers(parent)
super().add_to(parent, name=name, index=index)
return super().add_to(parent, name=name, index=index)

def _get_all_markers(self, element: Element) -> list:
markers = []
Expand Down
16 changes: 9 additions & 7 deletions folium/raster_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from folium.template import Template
from folium.utilities import (
TypeBounds,
TypeBoundsReturn,
TypeJsonValue,
image_to_url,
mercator_transform,
normalize_bounds_type,
parse_options,
remove_empty,
)
Expand Down Expand Up @@ -246,7 +248,7 @@ class ImageOverlay(Layer):
* If string, it will be written directly in the output file.
* If file, it's content will be converted as embedded in the output file.
* If array-like, it will be converted to PNG base64 string and embedded in the output.
bounds: list
bounds: list/tuple of list/tuple of float
Image bounds on the map in the form
[[lat_min, lon_min], [lat_max, lon_max]]
opacity: float, default Leaflet's default (1.0)
Expand Down Expand Up @@ -319,7 +321,7 @@ def __init__(

self.url = image_to_url(image, origin=origin, colormap=colormap)

def render(self, **kwargs) -> None:
def render(self, **kwargs):
super().render()

figure = self.get_root()
Expand All @@ -344,13 +346,13 @@ def render(self, **kwargs) -> None:
Element(pixelated), name="leaflet-image-layer"
) # noqa

def _get_self_bounds(self) -> TypeBounds:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]].
"""
return self.bounds
return normalize_bounds_type(self.bounds)


class VideoOverlay(Layer):
Expand All @@ -361,7 +363,7 @@ class VideoOverlay(Layer):
----------
video_url: str
URL of the video
bounds: list
bounds: list/tuple of list/tuple of float
Video bounds on the map in the form
[[lat_min, lon_min], [lat_max, lon_max]]
autoplay: bool, default True
Expand Down Expand Up @@ -411,10 +413,10 @@ def __init__(
self.bounds = bounds
self.options = remove_empty(autoplay=autoplay, loop=loop, **kwargs)

def _get_self_bounds(self) -> TypeBounds:
def _get_self_bounds(self) -> TypeBoundsReturn:
"""
Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]]
"""
return self.bounds
return normalize_bounds_type(self.bounds)
Loading

0 comments on commit ccb7509

Please sign in to comment.