From a4b3e6e254c77ffe95e74ce16083580051650718 Mon Sep 17 00:00:00 2001 From: Sujan Adhikari Date: Tue, 17 Dec 2024 18:06:09 +0545 Subject: [PATCH] fix(test): added db param to split sq test functions --- fmtm_splitter/splitter.py | 6 ++++-- tests/test_splitter.py | 44 +++++++++++++++++++++++++-------------- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/fmtm_splitter/splitter.py b/fmtm_splitter/splitter.py index 25979fe..bfb7c06 100755 --- a/fmtm_splitter/splitter.py +++ b/fmtm_splitter/splitter.py @@ -228,13 +228,15 @@ def splitBySquare( # noqa: N802 with create_connection(db) as conn: with conn.cursor() as cur: + # Drop the table if it exists + cur.execute("DROP TABLE IF EXISTS temp_polygons;") # Create temporary table cur.execute(""" CREATE TEMP TABLE temp_polygons ( id SERIAL PRIMARY KEY, geom GEOMETRY(GEOMETRY, 4326), area DOUBLE PRECISION - ) + ); """) extract_geoms = [] @@ -287,6 +289,7 @@ def splitBySquare( # noqa: N802 small_polygon RECORD; nearest_neighbor RECORD; BEGIN + DROP TABLE IF EXISTS small_polygons; CREATE TEMP TABLE small_polygons As SELECT id, geom, area FROM temp_polygons @@ -302,7 +305,6 @@ def splitBySquare( # noqa: N802 FROM temp_polygons lp WHERE id NOT IN (SELECT id FROM small_polygons) AND ST_Touches(small_polygon.geom, lp.geom) - AND ST_Touches(small_polygon.geom, lp.geom) AND ST_GEOMETRYTYPE( ST_INTERSECTION(small_polygon.geom, geom) ) != 'ST_Point' diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 7b1a7f3..eaf90d9 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -61,42 +61,50 @@ def test_init_splitter_types(aoi_json): assert str(error.value) == "The input AOI cannot contain multiple geometries." -def test_split_by_square_with_dict(aoi_json, extract_json): +def test_split_by_square_with_dict(db, aoi_json, extract_json): """Test divide by square from geojson dict types.""" features = split_by_square( - aoi_json.get("features")[0], meters=50, osm_extract=extract_json + aoi_json.get("features")[0], db, meters=50, osm_extract=extract_json ) - assert len(features.get("features")) == 60 + assert len(features.get("features")) == 66 features = split_by_square( - aoi_json.get("features")[0].get("geometry"), meters=50, osm_extract=extract_json + aoi_json.get("features")[0].get("geometry"), + db, + meters=50, + osm_extract=extract_json, ) - assert len(features.get("features")) == 60 + assert len(features.get("features")) == 66 -def test_split_by_square_with_str(aoi_json, extract_json): +def test_split_by_square_with_str(db, aoi_json, extract_json): """Test divide by square from geojson str and file.""" # GeoJSON Dumps features = split_by_square( - geojson.dumps(aoi_json.get("features")[0]), meters=50, osm_extract=extract_json + geojson.dumps(aoi_json.get("features")[0]), + db, + meters=50, + osm_extract=extract_json, ) - assert len(features.get("features")) == 60 + assert len(features.get("features")) == 66 # JSON Dumps features = split_by_square( json.dumps(aoi_json.get("features")[0].get("geometry")), + db, meters=50, osm_extract=extract_json, ) - assert len(features.get("features")) == 60 + assert len(features.get("features")) == 66 # File features = split_by_square( "tests/testdata/kathmandu.geojson", + db, meters=100, osm_extract="tests/testdata/kathmandu_extract.geojson", ) - assert len(features.get("features")) == 20 + assert len(features.get("features")) == 19 -def test_split_by_square_with_file_output(): +def test_split_by_square_with_file_output(db): """Test divide by square from geojson file. Also write output to file. @@ -104,28 +112,30 @@ def test_split_by_square_with_file_output(): outfile = Path(__file__).parent.parent / f"{uuid4()}.geojson" features = split_by_square( "tests/testdata/kathmandu.geojson", + db, osm_extract="tests/testdata/kathmandu_extract.geojson", meters=50, outfile=str(outfile), ) - assert len(features.get("features")) == 60 + assert len(features.get("features")) == 66 # Also check output file with open(outfile, "r") as jsonfile: output_geojson = geojson.load(jsonfile) - assert len(output_geojson.get("features")) == 60 + assert len(output_geojson.get("features")) == 66 -def test_split_by_square_with_multigeom_input(aoi_multi_json, extract_json): +def test_split_by_square_with_multigeom_input(db, aoi_multi_json, extract_json): """Test divide by square from geojson dict types.""" file_name = uuid4() outfile = Path(__file__).parent.parent / f"{file_name}.geojson" features = split_by_square( aoi_multi_json, + db, meters=50, osm_extract=extract_json, outfile=str(outfile), ) - assert len(features.get("features", [])) == 80 + assert len(features.get("features", [])) == 76 for index in [0, 1, 2, 3]: assert Path(f"{Path(outfile).stem}_{index}.geojson)").exists() @@ -219,6 +229,8 @@ def test_split_by_square_cli(): [ "--boundary", str(infile), + "--dburl", + "postgresql://fmtm:dummycipassword@db:5432/splitter", "--meters", "100", "--extract", @@ -233,7 +245,7 @@ def test_split_by_square_cli(): with open(outfile, "r") as jsonfile: output_geojson = geojson.load(jsonfile) - assert len(output_geojson.get("features")) == 20 + assert len(output_geojson.get("features")) == 19 def test_split_by_features_cli():