diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 89bdd32..5531ea7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
       - id: check-yaml

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: "v0.5.5"
+    rev: "v0.6.4"
     hooks:
       - id: ruff
         args: ["--fix"]
@@ -26,13 +26,12 @@ repos:
       - id: prettier

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.11.0
+    rev: v1.11.2
     hooks:
       - id: mypy
         additional_dependencies: [
             # Type stubs
             types-setuptools,
-            types-pkg_resources,
             # Dependencies that are typed
             numpy,
             xarray,
diff --git a/doc/demo.ipynb b/doc/demo.ipynb
index 3139be9..bb4f318 100644
--- a/doc/demo.ipynb
+++ b/doc/demo.ipynb
@@ -26,6 +26,7 @@
    "outputs": [],
    "source": [
     "import xarray as xr\n",
+    "\n",
     "import xbatcher"
    ]
   },
@@ -46,12 +47,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "store = \"az://carbonplan-share/example_cmip6_data.zarr\"\n",
+    "store = 'az://carbonplan-share/example_cmip6_data.zarr'\n",
     "ds = xr.open_dataset(\n",
     "    store,\n",
-    "    engine=\"zarr\",\n",
+    "    engine='zarr',\n",
     "    chunks={},\n",
-    "    backend_kwargs={\"storage_options\": {\"account_name\": \"carbonplan\"}},\n",
+    "    backend_kwargs={'storage_options': {'account_name': 'carbonplan'}},\n",
     ")\n",
     "\n",
     "# the attributes contain a lot of useful information, but clutter the print out when we inspect the outputs\n",
@@ -98,10 +99,10 @@
     "\n",
     "bgen = xbatcher.BatchGenerator(\n",
     "    ds=ds,\n",
-    "    input_dims={\"time\": n_timepoint_in_each_sample},\n",
+    "    input_dims={'time': n_timepoint_in_each_sample},\n",
     ")\n",
     "\n",
-    "print(f\"{len(bgen)} batches\")"
+    "print(f'{len(bgen)} batches')"
    ]
   },
   {
@@ -133,7 +134,7 @@
    "outputs": [],
    "source": [
     "expected_n_batch = len(ds.time) / n_timepoint_in_each_sample\n",
-    "print(f\"Expecting {expected_n_batch} batches, getting {len(bgen)} batches\")"
+    "print(f'Expecting {expected_n_batch} batches, getting {len(bgen)} batches')"
    ]
   },
   {
@@ -153,7 +154,7 @@
    "source": [
     "expected_batch_size = len(ds.lat) * len(ds.lon)\n",
     "print(\n",
-    "    f\"Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch\"\n",
+    "    f'Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch'\n",
     ")"
    ]
   },
@@ -179,12 +180,12 @@
     "\n",
     "bgen = xbatcher.BatchGenerator(\n",
     "    ds=ds,\n",
-    "    input_dims={\"time\": n_timepoint_in_each_sample},\n",
-    "    batch_dims={\"time\": n_timepoint_in_each_batch},\n",
+    "    input_dims={'time': n_timepoint_in_each_sample},\n",
+    "    batch_dims={'time': n_timepoint_in_each_batch},\n",
     "    concat_input_dims=True,\n",
     ")\n",
     "\n",
-    "print(f\"{len(bgen)} batches\")"
+    "print(f'{len(bgen)} batches')"
    ]
   },
   {
@@ -217,11 +218,11 @@
    "source": [
     "n_timepoint_in_batch = 31\n",
     "\n",
-    "bgen = xbatcher.BatchGenerator(ds=ds, input_dims={\"time\": n_timepoint_in_batch})\n",
+    "bgen = xbatcher.BatchGenerator(ds=ds, input_dims={'time': n_timepoint_in_batch})\n",
     "\n",
     "for batch in bgen:\n",
-    "    print(f\"last time point in ds is {ds.time[-1].values}\")\n",
-    "    print(f\"last time point in batch is {batch.time[-1].values}\")\n",
+    "    print(f'last time point in ds is {ds.time[-1].values}')\n",
+    "    print(f'last time point in batch is {batch.time[-1].values}')\n",
     "batch"
    ]
   },
@@ -249,15 +250,15 @@
     "\n",
     "bgen = xbatcher.BatchGenerator(\n",
     "    ds=ds,\n",
-    "    input_dims={\"time\": n_timepoint_in_each_sample},\n",
-    "    batch_dims={\"time\": n_timepoint_in_each_batch},\n",
+    "    input_dims={'time': n_timepoint_in_each_sample},\n",
+    "    batch_dims={'time': n_timepoint_in_each_batch},\n",
     "    concat_input_dims=True,\n",
-    "    input_overlap={\"time\": input_overlap},\n",
+    "    input_overlap={'time': input_overlap},\n",
     ")\n",
     "\n",
     "batch = bgen[0]\n",
     "\n",
-    "print(f\"{len(bgen)} batches\")\n",
+    "print(f'{len(bgen)} batches')\n",
     "batch"
    ]
   },
@@ -283,10 +284,10 @@
     "display(pixel)\n",
     "\n",
     "print(\n",
-    "    f\"sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}\"\n",
+    "    f'sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}'\n",
     ")\n",
     "print(\n",
-    "    f\"sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}\"\n",
+    "    f'sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}'\n",
     ")"
    ]
   },
@@ -310,17 +311,17 @@
    "outputs": [],
    "source": [
     "bgen = xbatcher.BatchGenerator(\n",
-    "    ds=ds[[\"tasmax\"]].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
-    "    input_dims={\"lat\": 9, \"lon\": 9, \"time\": 10},\n",
-    "    batch_dims={\"lat\": 18, \"lon\": 18, \"time\": 15},\n",
+    "    ds=ds[['tasmax']].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),\n",
+    "    input_dims={'lat': 9, 'lon': 9, 'time': 10},\n",
+    "    batch_dims={'lat': 18, 'lon': 18, 'time': 15},\n",
     "    concat_input_dims=True,\n",
-    "    input_overlap={\"lat\": 8, \"lon\": 8, \"time\": 9},\n",
+    "    input_overlap={'lat': 8, 'lon': 8, 'time': 9},\n",
     ")\n",
     "\n",
     "for i, batch in enumerate(bgen):\n",
-    "    print(f\"batch {i}\")\n",
+    "    print(f'batch {i}')\n",
     "    # make sure the ordering of dimension is consistent\n",
-    "    batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n",
+    "    batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
     "\n",
     "    # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
     "    features = batch.tasmax.isel(time_input=slice(0, 9))\n",
@@ -328,10 +329,10 @@
     "    # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
     "    labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
     "\n",
-    "    print(\"feature shape\", features.shape)\n",
-    "    print(\"label shape\", labels.shape)\n",
-    "    print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape)\n",
-    "    print(\"\")"
+    "    print('feature shape', features.shape)\n",
+    "    print('label shape', labels.shape)\n",
+    "    print('shape of lat of each sample', labels.coords['lat'].shape)\n",
+    "    print('')"
    ]
   },
   {
@@ -350,21 +351,21 @@
    "outputs": [],
    "source": [
     "for i, batch in enumerate(bgen):\n",
-    "    print(f\"batch {i}\")\n",
+    "    print(f'batch {i}')\n",
     "    # make sure the ordering of dimension is consistent\n",
-    "    batch = batch.transpose(\"input_batch\", \"lat_input\", \"lon_input\", \"time_input\")\n",
+    "    batch = batch.transpose('input_batch', 'lat_input', 'lon_input', 'time_input')\n",
     "\n",
     "    # only use the first 9 time points as features, since the last time point is the label to be predicted\n",
     "    features = batch.tasmax.isel(time_input=slice(0, 9))\n",
-    "    features = features.stack(features=[\"lat_input\", \"lon_input\", \"time_input\"])\n",
+    "    features = features.stack(features=['lat_input', 'lon_input', 'time_input'])\n",
     "\n",
     "    # select the center pixel at the last time point to be the label to be predicted\n",
     "    # the actual lat/lon/time for each of the sample can be accessed in labels.coords\n",
     "    labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)\n",
     "\n",
-    "    print(\"feature shape\", features.shape)\n",
-    "    print(\"label shape\", labels.shape)\n",
-    "    print(\"shape of lat of each sample\", labels.coords[\"lat\"].shape, \"\\n\")"
+    "    print('feature shape', features.shape)\n",
+    "    print('label shape', labels.shape)\n",
+    "    print('shape of lat of each sample', labels.coords['lat'].shape, '\\n')"
    ]
   },
   {