Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pre-commit hooks to use latest versions of ruff and mypy #236

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.5.5"
rev: "v0.6.4"
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -26,13 +26,12 @@ repos:
- id: prettier

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.0
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies: [
# Type stubs
types-setuptools,
types-pkg_resources,
# Dependencies that are typed
numpy,
xarray,
Expand Down
71 changes: 36 additions & 35 deletions doc/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"outputs": [],
"source": [
"import xarray as xr\n",
"\n",
"import xbatcher"
]
},
Expand All @@ -46,12 +47,12 @@
"metadata": {},
"outputs": [],
"source": [
"store = \"az://carbonplan-share/example_cmip6_data.zarr\"\n",
"store = 'az://carbonplan-share/example_cmip6_data.zarr'\n",
"ds = xr.open_dataset(\n",
" store,\n",
" engine=\"zarr\",\n",
" engine='zarr',\n",
" chunks={},\n",
" backend_kwargs={\"storage_options\": {\"account_name\": \"carbonplan\"}},\n",
" backend_kwargs={'storage_options': {'account_name': 'carbonplan'}},\n",
")\n",
"\n",
"# the attributes contain a lot of useful information, but clutter the print out when we inspect the outputs\n",
Expand Down Expand Up @@ -98,10 +99,10 @@
"\n",
"bgen = xbatcher.BatchGenerator(\n",
" ds=ds,\n",
" input_dims={\"time\": n_timepoint_in_each_sample},\n",
" input_dims={'time': n_timepoint_in_each_sample},\n",
")\n",
"\n",
"print(f\"{len(bgen)} batches\")"
"print(f'{len(bgen)} batches')"
]
},
{
Expand Down Expand Up @@ -133,7 +134,7 @@
"outputs": [],
"source": [
"expected_n_batch = len(ds.time) / n_timepoint_in_each_sample\n",
"print(f\"Expecting {expected_n_batch} batches, getting {len(bgen)} batches\")"
"print(f'Expecting {expected_n_batch} batches, getting {len(bgen)} batches')"
]
},
{
Expand All @@ -153,7 +154,7 @@
"source": [
"expected_batch_size = len(ds.lat) * len(ds.lon)\n",
"print(\n",
" f\"Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch\"\n",
" f'Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch'\n",
")"
]
},
Expand All @@ -179,12 +180,12 @@
"\n",
"bgen = xbatcher.BatchGenerator(\n",
" ds=ds,\n",
" input_dims={\"time\": n_timepoint_in_each_sample},\n",
" batch_dims={\"time\": n_timepoint_in_each_batch},\n",
" input_dims={'time': n_timepoint_in_each_sample},\n",
" batch_dims={'time': n_timepoint_in_each_batch},\n",
" concat_input_dims=True,\n",
")\n",
"\n",
"print(f\"{len(bgen)} batches\")"
"print(f'{len(bgen)} batches')"
]
},
{
Expand Down Expand Up @@ -217,11 +218,11 @@
"source": [
"n_timepoint_in_batch = 31\n",
"\n",
"bgen = xbatcher.BatchGenerator(ds=ds, input_dims={\"time\": n_timepoint_in_batch})\n",
"bgen = xbatcher.BatchGenerator(ds=ds, input_dims={'time': n_timepoint_in_batch})\n",
"\n",
"for batch in bgen:\n",
" print(f\"last time point in ds is {ds.time[-1].values}\")\n",
" print(f\"last time point in batch is {batch.time[-1].values}\")\n",
" print(f'last time point in ds is {ds.time[-1].values}')\n",
" print(f'last time point in batch is {batch.time[-1].values}')\n",
"batch"
]
},
Expand Down Expand Up @@ -249,15 +250,15 @@
"\n",
"bgen = xbatcher.BatchGenerator(\n",
" ds=ds,\n",
" input_dims={\"time\": n_timepoint_in_each_sample},\n",
" batch_dims={\"time\": n_timepoint_in_each_batch},\n",
" input_dims={'time': n_timepoint_in_each_sample},\n",
" batch_dims={'time': n_timepoint_in_each_batch},\n",
" concat_input_dims=True,\n",
" input_overlap={\"time\": input_overlap},\n",
" input_overlap={'time': input_overlap},\n",
")\n",
"\n",
"batch = bgen[0]\n",
"\n",
"print(f\"{len(bgen)} batches\")\n",
"print(f'{len(bgen)} batches')\n",
"batch"
]
},
Expand All @@ -283,10 +284,10 @@
"display(pixel)\n",
"\n",
"print(\n",
" f\"sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}\"\n",
" f'sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}'\n",
")\n",
"print(\n",
" f\"sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}\"\n",
" f'sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}'\n",
")"
]
},
Expand All @@ -310,28 +311,28 @@
"outputs": [],
"source": [
"bgen = xbatcher.BatchGenerator(\n",
" ds=ds[[\"tasmax\"]].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
" input_dims={\"lat\": 9, \"lon\": 9, \"time\": 10},\n",
" batch_dims={\"lat\": 18, \"lon\": 18, \"time\": 15},\n",
" ds=ds[['tasmax']].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
" input_dims={'lat': 9, 'lon': 9, 'time': 10},\n",
" batch_dims={'lat': 18, 'lon': 18, 'time': 15},\n",
" concat_input_dims=True,\n",
" input_overlap={\"lat\": 8, \"lon\": 8, \"time\": 9},\n",
" input_overlap={'lat': 8, 'lon': 8, 'time': 9},\n",
")\n",
"\n",
"for i, batch in enumerate(bgen):\n",
" print(f\"batch {i}\")\n",
" print(f'batch {i}')\n",
" # make sure the ordering of dimension is consistent\n",
" batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n",
" batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
"\n",
" # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
" features = batch.tasmax.isel(time_input=slice(0, 9))\n",
" # select the center pixel at the last time point to be the label to be predicted\n",
" # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
" labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
"\n",
" print(\"feature shape\", features.shape)\n",
" print(\"label shape\", labels.shape)\n",
" print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape)\n",
" print(\"\")"
" print('feature shape', features.shape)\n",
" print('label shape', labels.shape)\n",
" print('shape of lat of each sample', labels.coords['lat'].shape)\n",
" print('')"
]
},
{
Expand All @@ -350,21 +351,21 @@
"outputs": [],
"source": [
"for i, batch in enumerate(bgen):\n",
" print(f\"batch {i}\")\n",
" print(f'batch {i}')\n",
" # make sure the ordering of dimension is consistent\n",
" batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n",
" batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
"\n",
" # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
" features = batch.tasmax.isel(time_input=slice(0, 9))\n",
" features = features.stack(features=[\"lat_input\", \"lon_input\", \"time_input\"])\n",
" features = features.stack(features=['lat_input', 'lon_input', 'time_input'])\n",
"\n",
" # select the center pixel at the last time point to be the label to be predicted\n",
" # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
" labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
"\n",
" print(\"feature shape\", features.shape)\n",
" print(\"label shape\", labels.shape)\n",
" print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape, \"\\n\")"
" print('feature shape', features.shape)\n",
" print('label shape', labels.shape)\n",
" print('shape of lat of each sample', labels.coords['lat'].shape, '\\n')"
]
},
{
Expand Down
Loading