Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: Throw error for invalid ** usage outside folder segments (e.g. /tmp/**.csv) #3100

Merged
merged 7 commits into from
Oct 31, 2024
23 changes: 23 additions & 0 deletions src/daft-io/src/object_store_glob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,29 @@ pub async fn glob(
};
let glob = glob.as_str();

// We need to do some validation on the glob pattern before compiling it, since the globset crate is very permissive
// and will happily compile patterns that don't make sense without throwing an error.
fn verify_glob(glob: &str) -> super::Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we can move this out of the function, and just run Rust unit-tests!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, makes a lot of sense.

Wasn't familiar with how Rust unit tests worked, have moved the testing logic for the verify_glob function to object_store_glob.rs.

// Catch for cases like `s3://bucket/path/**.txt`
// NOTE: "\**" is a valid pattern that matches a literal `*`, followed by anything, so we need to only capture cases where `**` is not preceded by a backslash
let re = regex::Regex::new(r"(?:[^\\]|^)\*\*").unwrap();

for segment in glob.split(GLOB_DELIMITER) {
if re.is_match(segment) && segment != "**" {
return Err(super::Error::InvalidArgument {
msg: format!(
"Invalid usage of '**' in glob pattern. The '**' wildcard must occupy an entire path segment and be surrounded by '{}' characters. Found invalid usage in '{}'.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we be more helpful with the error message as well? Would love to add a suggestion here for the user to do this glob path instead: {re_group_1}/**/*{re_group_2}/{re_group_3}.

Copy link
Contributor Author

@conradsoon conradsoon Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I've slightly rewritten the regex to process the path fully instead of segmenting it delimited portions.
This should allow us to suggest a corrected glob path for the user as well.

Have rewritten it to give suggestions in this manner: is this the behaviour you expect?

  • Original: invalid/blahblah**.txtCorrected: invalid/blahblah/**/*.txt
  • Original: invalid/\***.txtCorrected: invalid/\*/**/*.txt
  • Original: invalid/\**blahblah**.txtCorrected: invalid/\**blahblah/**/*.txt

GLOB_DELIMITER, glob
),
});
}
}

Ok(())
}

verify_glob(glob)?;

let glob_fragments = to_glob_fragments(glob)?;
let full_glob_matcher = GlobBuilder::new(glob)
.literal_separator(true)
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/io/parquet/test_reads_s3_minio.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def test_minio_parquet_read_no_files(minio_io_config):
fs.touch("s3://data-engineering-prod/foo/file.txt")

with pytest.raises(FileNotFoundError, match="Glob path had no matches:"):
daft.read_parquet("s3://data-engineering-prod/foo/**.parquet", io_config=minio_io_config)
# Need to have a special character within the test path to trigger the matching logic
daft.read_parquet(
"s3://data-engineering-prod/foo/this-should-not-match-anything-and-this-file-should-not-exist-*.parquet",
io_config=minio_io_config,
)


@pytest.mark.integration()
Expand Down
8 changes: 8 additions & 0 deletions tests/integration/io/test_list_files_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from daft.daft import AzureConfig, IOConfig, io_glob
from daft.exceptions import DaftCoreException

STORAGE_ACCOUNT = "dafttestdata"
CONTAINER = "public-anonymous"
Expand Down Expand Up @@ -69,3 +70,10 @@ def test_az_notfound():
path = f"az://{CONTAINER}/test_"
with pytest.raises(FileNotFoundError, match=path):
io_glob(path, io_config=IOConfig(azure=DEFAULT_AZURE_CONFIG))


@pytest.mark.integration()
def test_invalid_double_asterisk_usage():
path = f"az://{CONTAINER}/**.pq"
with pytest.raises(DaftCoreException):
io_glob(path, io_config=IOConfig(azure=DEFAULT_AZURE_CONFIG))
8 changes: 8 additions & 0 deletions tests/integration/io/test_list_files_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from daft.daft import io_glob
from daft.exceptions import DaftCoreException

BUCKET = "daft-public-data-gs"

Expand Down Expand Up @@ -73,3 +74,10 @@ def test_gs_notfound(gcs_public_config):
path = f"gs://{BUCKET}/test_"
with pytest.raises(FileNotFoundError, match=path):
io_glob(path, io_config=gcs_public_config)


@pytest.mark.integration()
def test_invalid_double_asterisk_usage_gcs(gcs_public_config):
path = f"gs://{BUCKET}/**invalid"
with pytest.raises(DaftCoreException):
io_glob(path, io_config=gcs_public_config)
8 changes: 8 additions & 0 deletions tests/integration/io/test_list_files_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fsspec.implementations.http import HTTPFileSystem

from daft.daft import io_glob
from daft.exceptions import DaftCoreException
from tests.integration.io.conftest import mount_data_nginx


Expand Down Expand Up @@ -139,3 +140,10 @@ def test_http_listing_absolute_base_urls(nginx_config, tmpdir):
assert daft_ls_result == [
{"type": "File", "path": f"{nginx_http_url}/other.html", "size": None},
]


@pytest.mark.integration()
def test_invalid_double_asterisk_usage_http(nginx_http_url):
path = f"{nginx_http_url}/**.txt"
with pytest.raises(DaftCoreException):
io_glob(path)
10 changes: 10 additions & 0 deletions tests/integration/io/test_list_files_s3_minio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from daft.daft import io_glob
from daft.exceptions import DaftCoreException

from .conftest import minio_create_bucket

Expand Down Expand Up @@ -391,3 +392,12 @@ def test_limit(minio_io_config, limit):
fs.write_bytes(f"s3://{bucket_name}/{name}", b"")
daft_ls_result = io_glob(f"s3://{bucket_name}/**", io_config=minio_io_config, limit=limit)
assert len(daft_ls_result) == limit


@pytest.mark.integration()
def test_invalid_double_asterisk_usage_s3(minio_io_config):
bucket_name = "bucket"
path = f"s3://{bucket_name}/**.pq"
with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as _:
with pytest.raises(DaftCoreException):
io_glob(path, io_config=minio_io_config)
15 changes: 15 additions & 0 deletions tests/io/test_list_files_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fsspec.implementations.local import LocalFileSystem

from daft.daft import io_glob
from daft.exceptions import DaftCoreException


def local_recursive_list(fs, path) -> list:
Expand Down Expand Up @@ -165,3 +166,17 @@ def test_missing_file_path(tmp_path, include_protocol):
p = "file://" + p
with pytest.raises(FileNotFoundError, match="/c/cc/ddd not found"):
io_glob(p)


@pytest.mark.parametrize("include_protocol", [False, True])
@pytest.mark.integration()
def test_invalid_double_asterisk_usage_local(tmp_path, include_protocol):
d = tmp_path / "dir"
d.mkdir()

path = str(d) + "/**.pq"
if include_protocol:
path = "file://" + path

with pytest.raises(DaftCoreException):
io_glob(path)
Loading