From 9aa16a9b07651e58f20b3bf329b3b4dbbd59bc20 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Mon, 15 Jan 2024 17:09:11 -0800 Subject: [PATCH 1/3] Fix the default folds and assertion for checking valid folds in PASTIS --- torchgeo/datasets/pastis.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 4c1c5445e80..4e38eb4af30 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -129,7 +129,7 @@ class PASTIS(NonGeoDataset): def __init__( self, root: str = "data", - folds: Sequence[int] = (0, 1, 2, 3, 4), + folds: Sequence[int] = (1, 2, 3, 4, 5), bands: str = "s2", mode: str = "semantic", transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, @@ -153,7 +153,8 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert set(folds) <= set(range(6)) + for fold in folds: + assert 1 <= fold <= 5 assert bands in ["s1a", "s1d", "s2"] assert mode in ["semantic", "instance"] self.root = root From 92cc28e0941bfedf5afa5a4d963644be5f18f776 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Mon, 15 Jan 2024 17:19:28 -0800 Subject: [PATCH 2/3] Maybe fix PASTIS tests --- tests/data/pastis/data.py | 2 +- tests/datasets/test_pastis.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/data/pastis/data.py b/tests/data/pastis/data.py index 36742e8a434..652b0de139c 100644 --- a/tests/data/pastis/data.py +++ b/tests/data/pastis/data.py @@ -78,7 +78,7 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]], }, "id": str(i), - "properties": {"Fold": i % 5, "ID_PATCH": i}, + "properties": {"Fold": (i % 5) + 1, "ID_PATCH": i}, } ) diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py index 1decc20e0c8..fa6f021178c 100644 --- a/tests/datasets/test_pastis.py +++ b/tests/datasets/test_pastis.py @@ -24,9 +24,9 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: class TestPASTIS: @pytest.fixture( params=[ - {"folds": (0, 1), "bands": "s2", "mode": "semantic"}, - {"folds": (0, 1), "bands": "s1a", "mode": "semantic"}, - {"folds": (0, 1), "bands": "s1d", "mode": "instance"}, + {"folds": (1, 2), "bands": "s2", "mode": "semantic"}, + {"folds": (1, 2), "bands": "s1a", "mode": "semantic"}, + {"folds": (1, 2), "bands": "s1d", "mode": "instance"}, ] ) def dataset( @@ -91,7 +91,7 @@ def test_corrupted(self, tmp_path: Path) -> None: def test_invalid_fold(self) -> None: with pytest.raises(AssertionError): - PASTIS(folds=(6,)) + PASTIS(folds=(0,)) def test_invalid_mode(self) -> None: with pytest.raises(AssertionError): From dc3a86d45c39ee8cad4563a513f174754fb67ac1 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Tue, 16 Jan 2024 03:40:27 +0000 Subject: [PATCH 3/3] Update data --- tests/data/pastis/PASTIS-R.zip | Bin 577316 -> 577314 bytes tests/datasets/test_pastis.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/pastis/PASTIS-R.zip b/tests/data/pastis/PASTIS-R.zip index 6c40f2b2afeec13ab80e9bbcfc87cb12b17e73a4..a44be5035e05f2393893539924f3c4f7f391490f 100644 GIT binary patch delta 1664 zcmZuwYfMvT81D7-RT$8U+*+W830pxabeL`kRE2Fwa9q?Fz&hiy2tNwOWoU#!fo&*) z5Vm~Aa#84Xf=tQ4qq{IHYv~1u$R@)8Ru`rV$uh>GV3x4c?`y}#=%4SL=Xsy!eb3wH zGqi@cw}yLVrgM3|Eaz`ZqS)W)70$zFd<;2{M3eKJ>$F6X_h+hg&bo=X;u}*qze)3LcIqwxLEANzir%v(JCVR_M#Ql7IQGUx6-J7?XJ8}O@M<@M^_b$x|X z854W9o_OQ7OqN}cWqO`;H>BKga@yoQlXL2_sLj^hYYXxCbA$0ts!w7HkkuH;W7 z<0X4My7;0MwP^%mB((yCimbq=ZyWjJF_?Z(aB#2%4OH@FbpIdxtI0&JjiZDgU_QNh zZtKJyOPWA3(p0K|+7!pr*=*{Jci6Kq5dl)(f)gu#;B3?aNP{ zox~{A4ixPv0Yfr&(%*bK5GiGx6vXL%9&D9j&ESs@kPYC$17CHzxY}@yR3+!!TKr&sP9)?|yQ5`T(5medC`>b97ns*xL>T!(($?f1tMtaT8;P=Epv|2j|L!DP{C8{iz*u5?0H0TVd)$X2Z z*eT|f0xHx)Fk+I+M;^oA`LFbjF)QHSt8!KB9d^x4QjE#ZYF(tkO_DTY@_3VrTyT?b zY5HW12$Si0Oom0T{#bK!O>ig0+%qYQwfR58i#yPRQIMmS5fDn9-&vfvFGIg(sZS>< zU1Unhs$Ljn{_3aK37j+Pa=;g6O^3%cDP``_{}ZW<+%bM`P1z`#83i@fV*+OMzX3Bd BN%H^z delta 1652 zcmZuxeNa?Y6yJTj`}QD6-~h&#vaCg+xqM`hORzGUI7lpFfRF^?Bu?Z=;U5Jk3rsj5 z#8-1QzycXSlompGU=FwvEW68s3_~)Ld}-Dc5xS&+In{gj?mETr{<-)4&hPxrx#ym9 z>n=#^FGy{b;R2xpkNeh_KaDA~l?w5?Lq^urZDc)XUI(|5|D`}Hju)BT7d}-H$ba#4 zWKo`vUH`(0L-O!}FMmtcwZuL7<>0ZU&dNO*4w{|i4PyJ=OzqXfuVu(GFAGBxr-rIl z9#LtxJyWed*grp>wEtxB`IaV?Lh=6g-yfM8oA-@?D*p#%(Ll`dgUOJ;V#iXw{cBQc zyALe{w8gY5wFUR~3O_0pG&(IsrOcsa$J2ku&8na77=OL@=d}5lX`RV`s-*nhmddcB zR~r@sBAWitmxYGDdBoJbZ$7(TzUX}P;R}(P=u3?U)OvM>-d+5sOn6_~WaYUnE5Bc8 zAG%pm?GgHuboWAz$6N2&1z3s%xJc>hlq;Qqd|u8BpXbaLPPr@kU|u}M#8dq7KwGXX zP90?%ZU|XMX@!zP<_U|8O`>9Ka+eBx@ltH)C_lN9O>2j z#<~V#4{C!B0=^O75nN5cSGwS`+38LX%#tnA^n*WELd*R?%829x5Ja%dXS?qPKu)kH z0}zG-;t4?ES_meFVlB8(J1MkF&UZq+TCm^z|Artf!_6D_)vLVFf)-vPzOIAdBG{y$ z&_gXaqKHATmvGK)jV&vsV{A9wqHw0C=awhd40jUA^x#c=59`^4KbxUZ?F{Xo5pfU9x4N?(U1U0 zlK^h$ngMJGuIF$%3y>rWJdxE9*!-vcM-1R7n8?FMcf`(t2*p_#WEL=^vB=C|heG7I z`ix{gF!9q0LUi%pe*DBItBH|`9g;G{5&02FVg&LGknyn_W>I%^!f5I6RwH{S8=6tN w5xmi3qve>^!z{eqieVSOwW5&8!g{|V)v!fR!#UpSA$sBwaHQM&fxV}H0aih8*#H0l diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py index fa6f021178c..d0284688cf4 100644 --- a/tests/datasets/test_pastis.py +++ b/tests/datasets/test_pastis.py @@ -34,7 +34,7 @@ def dataset( ) -> PASTIS: monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url) - md5 = "9b11ae132623a0d13f7f0775d2003703" + md5 = "135a29fb8221241dde14f31579c07f45" monkeypatch.setattr(PASTIS, "md5", md5) url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip") monkeypatch.setattr(PASTIS, "url", url)