Skip to content

Commit

Permalink
#197 fix zenodo downloads + es points + csv
Browse files Browse the repository at this point in the history
  • Loading branch information
2320sharon committed Oct 24, 2023
1 parent eb3e55b commit 44a9b2b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 20 deletions.
63 changes: 48 additions & 15 deletions src/coastseg/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ipywidgets import ToggleButton, HBox, VBox, Layout, HTML
from requests.exceptions import SSLError
from shapely.geometry import Polygon
from shapely.geometry import MultiPoint, LineString

# Specific classes/functions from modules
from typing import Callable, List, Optional, Union, Dict, Set
Expand Down Expand Up @@ -1207,21 +1208,24 @@ def download_url(url: str, save_path: str, filename: str = None, chunk_size: int
content_length = r.headers.get("Content-Length")
if content_length:
total_length = int(content_length)
with open(save_path, "wb") as fd:
with tqdm(
total=total_length,
unit="B",
unit_scale=True,
unit_divisor=1024,
desc=f"Downloading {filename}",
initial=0,
ascii=True,
) as pbar:
for chunk in r.iter_content(chunk_size=chunk_size):
fd.write(chunk)
pbar.update(len(chunk))
else:
logger.warning("Content length not found in response headers")
print(
"Content length not found in response headers. Downloading without progress bar."
)
total_length = None
with open(save_path, "wb") as fd:
with tqdm(
total=total_length,
unit="B",
unit_scale=True,
unit_divisor=1024,
desc=f"Downloading {filename}",
initial=0,
ascii=True,
) as pbar:
for chunk in r.iter_content(chunk_size=chunk_size):
fd.write(chunk)
pbar.update(len(chunk))


def get_center_point(coords: list) -> tuple:
Expand Down Expand Up @@ -1593,6 +1597,33 @@ def save_extracted_shoreline_figures(
)


def convert_linestrings_to_multipoints(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
"""
Convert LineString geometries in a GeoDataFrame to MultiPoint geometries.
Args:
- gdf (gpd.GeoDataFrame): The input GeoDataFrame.
Returns:
- gpd.GeoDataFrame: A new GeoDataFrame with MultiPoint geometries. If the input GeoDataFrame
already contains MultiPoints, the original GeoDataFrame is returned.
"""

# Check if the gdf already contains MultiPoints
if any(gdf.geometry.type == "MultiPoint"):
return gdf

def linestring_to_multipoint(linestring):
if isinstance(linestring, LineString):
return MultiPoint(linestring.coords)
return linestring

# Convert each LineString to a MultiPoint
gdf["geometry"] = gdf["geometry"].apply(linestring_to_multipoint)

return gdf


def save_extracted_shorelines(
extracted_shorelines: "Extracted_Shoreline", save_path: str
):
Expand All @@ -1619,9 +1650,11 @@ def save_extracted_shorelines(
save_path, "extracted_shorelines_lines.geojson", extracted_shorelines_gdf_lines
)

points_gdf = convert_linestrings_to_multipoints(extracted_shorelines.gdf)
projected_gdf = stringify_datetime_columns(points_gdf)
# Save extracted shorelines as a GeoJSON file
extracted_shorelines.to_file(
save_path, "extracted_shorelines_points.geojson", extracted_shorelines.gdf
save_path, "extracted_shorelines_points.geojson", projected_gdf
)

# Save shoreline settings as a JSON file
Expand Down
4 changes: 2 additions & 2 deletions src/coastseg/tide_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def compute_tidal_corrections(
def save_csv_per_id(
df: pd.DataFrame,
save_location: str,
filename: str = "timeseries_tidally_corrected",
filename: str = "timeseries_tidally_corrected.csv",
id_column_name: str = "transect_id",
):
new_df = pd.DataFrame()
Expand Down Expand Up @@ -220,7 +220,7 @@ def correct_tides(
save_csv_per_id(
tide_corrected_timeseries_df,
session_path,
filename="timeseries_tidally_corrected",
filename="timeseries_tidally_corrected.csv",
)
update(f"{roi_id} was tidally corrected")
return tide_corrected_timeseries_df
Expand Down
7 changes: 4 additions & 3 deletions src/coastseg/zoo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_files_to_download(
response = next((f for f in available_files if f["filename"] == filename), None)
if response is None:
raise ValueError(f"Cannot find {filename} at {model_id}")
link = response["links"]["self"]
link = response["links"]["download"]
file_path = os.path.join(model_path, filename)
url_dict[file_path] = link
return url_dict
Expand Down Expand Up @@ -798,6 +798,7 @@ def extract_shorelines_with_unet(

# save extracted shorelines, detection jpgs, configs, model settings files to the session directory
common.save_extracted_shorelines(extracted_shorelines, new_session_path)

# common.save_extracted_shoreline_figures(extracted_shorelines, new_session_path)
print(f"Saved extracted shorelines to {new_session_path}")

Expand Down Expand Up @@ -1320,7 +1321,7 @@ def download_best(
# if best BEST_MODEL.txt file not exist then download it
if not os.path.isfile(BEST_MODEL_txt_path):
common.download_url(
best_model_json["links"]["self"],
best_model_json["links"]["download"],
BEST_MODEL_txt_path,
"Downloading best_model.txt",
)
Expand Down Expand Up @@ -1405,7 +1406,7 @@ def download_ensemble(
logger.info(f"all_json_reponses : {all_json_reponses }")
for response in all_models_reponses + all_json_reponses:
# get the link of the best model
link = response["links"]["self"]
link = response["links"]["download"]
filename = response["key"]
filepath = os.path.join(model_path, filename)
download_dict[filepath] = link
Expand Down

0 comments on commit 44a9b2b

Please sign in to comment.