Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: Throw error for invalid ** usage outside folder segments (e.g. /tmp/**.csv) #3100

Merged
merged 7 commits into from
Oct 31, 2024
69 changes: 69 additions & 0 deletions src/daft-io/src/object_store_glob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,37 @@ fn _should_return(fm: &FileMetadata) -> bool {
}
}

/// Validates the glob pattern before compiling it. The `globset` crate which we use for globbing is
/// very permissive and does not check for invalid usage of the '**' wildcard. This function ensures
/// that the glob pattern does not contain invalid usage of '**'.
fn verify_glob(glob: &str) -> super::Result<()> {
let re = regex::Regex::new(r"(?P<before>.*?[^\\])\*\*(?P<after>[^/\n].*)").unwrap();

if let Some(captures) = re.captures(glob) {
let before = captures.name("before").map_or("", |m| m.as_str());
let after = captures.name("after").map_or("", |m| m.as_str());

// Ensure the 'before' part ends with a delimiter
let corrected_before = if !before.ends_with('/') {
format!("{}/", before)
} else {
before.to_string()
};

let corrected_pattern = format!("{corrected_before}**/*{after}");

return Err(super::Error::InvalidArgument {
msg: format!(
"Invalid usage of '**' in glob pattern. Found '{before}**{after}'. \
The '**' wildcard should be used to match directories and must be surrounded by delimiters. \
Did you perhaps mean: '{corrected_pattern}'?"
),
});
}

Ok(())
}

/// Globs an ObjectSource for Files
///
/// Uses the `globset` crate for matching, and thus supports all the syntax enabled by that crate.
Expand Down Expand Up @@ -404,6 +435,10 @@ pub async fn glob(
};
let glob = glob.as_str();

// Validate the glob pattern, this is necessary since the `globset` crate is overly permissive and happily compiles patterns
// like "/foo/bar/**.txt" which don't make sense.
verify_glob(glob)?;

let glob_fragments = to_glob_fragments(glob)?;
let full_glob_matcher = GlobBuilder::new(glob)
.literal_separator(true)
Expand Down Expand Up @@ -662,3 +697,37 @@ pub async fn glob(

Ok(to_rtn_stream.boxed())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_verify_glob() {
// Test valid glob patterns
assert!(verify_glob("valid/pattern.txt").is_ok()); // Normal globbing works ok
assert!(verify_glob("another/valid/pattern/**/blah.txt").is_ok()); // No error if ** used as a segment
assert!(verify_glob("**").is_ok()); // ** by itself is ok
assert!(verify_glob("another/valid/pattern/**").is_ok()); // No trailing slash is ok
assert!(verify_glob("another/valid/pattern/**/").is_ok()); // Trailing slash is ok (should be interpreted as **/*)
assert!(verify_glob("another/valid/pattern/**/\\**.txt").is_ok()); // Escaped ** is ok
assert!(verify_glob("**/wildcard/*.txt").is_ok()); // Wildcard matching not affected

// Test invalid glob patterns and check error messages
// The '**' wildcard should be used to match directories and must be surrounded by delimiters.
let err = verify_glob("invalid/**.txt").unwrap_err();
assert!(err.to_string().contains("invalid/**/*.txt")); // Suggests adding a delimiter after '**'

// '**' should be surrounded by delimiters to match directories, not used directly with file names.
let err = verify_glob("invalid/blahblah**.txt").unwrap_err();
assert!(err.to_string().contains("invalid/blahblah/**/*.txt")); // Suggests adding a delimiter before '**'

// Backslash should only escape the first '*', leading to non-escaped '**'.
let err = verify_glob("invalid/\\***.txt").unwrap_err();
assert!(err.to_string().contains("invalid/\\\\*/**/*.txt")); // Suggests correcting the escape sequence (NOTE: double backslash)

// Non-escaped '**' should trigger even when there is an escaped '**'.
let err = verify_glob("invalid/\\**blahblah**.txt").unwrap_err();
assert!(err.to_string().contains("invalid/\\\\**blahblah/**/*.txt")); // Suggests adding delimiters around '**'
}
}
6 changes: 5 additions & 1 deletion tests/integration/io/parquet/test_reads_s3_minio.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def test_minio_parquet_read_no_files(minio_io_config):
fs.touch("s3://data-engineering-prod/foo/file.txt")

with pytest.raises(FileNotFoundError, match="Glob path had no matches:"):
daft.read_parquet("s3://data-engineering-prod/foo/**.parquet", io_config=minio_io_config)
# Need to have a special character within the test path to trigger the matching logic
daft.read_parquet(
"s3://data-engineering-prod/foo/this-should-not-match-anything-and-this-file-should-not-exist-*.parquet",
io_config=minio_io_config,
)


@pytest.mark.integration()
Expand Down
43 changes: 43 additions & 0 deletions tests/io/test_list_files_local.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import os
import re

import pytest
from fsspec.implementations.local import LocalFileSystem

from daft.daft import io_glob
from daft.exceptions import DaftCoreException


def local_recursive_list(fs, path) -> list:
Expand Down Expand Up @@ -165,3 +167,44 @@ def test_missing_file_path(tmp_path, include_protocol):
p = "file://" + p
with pytest.raises(FileNotFoundError, match="/c/cc/ddd not found"):
io_glob(p)


@pytest.mark.parametrize("include_protocol", [False, True])
def test_invalid_double_asterisk_usage_local(tmp_path, include_protocol):
d = tmp_path / "dir"
d.mkdir()

path = str(d) + "/**.pq"
if include_protocol:
path = "file://" + path

expected_correct_path = str(d) + "/**/*.pq"
if include_protocol:
expected_correct_path = "file://" + expected_correct_path

# Need to escape these or the regex matcher will complain
expected_correct_path = re.escape(expected_correct_path)

with pytest.raises(DaftCoreException, match=expected_correct_path):
io_glob(path)


@pytest.mark.parametrize("include_protocol", [False, True])
def test_literal_double_asterisk_file(tmp_path, include_protocol):
d = tmp_path / "dir"
d.mkdir()
file_with_literal_name = d / "*.pq"
file_with_literal_name.touch()

path = str(d) + "/\**.pq"
if include_protocol:
path = "file://" + path

fs = LocalFileSystem()
fs_result = fs.ls(str(d), detail=True)
fs_result = [f for f in fs_result if f["name"] == str(file_with_literal_name)]

daft_ls_result = io_glob(path)

assert len(daft_ls_result) == 1
compare_local_result(daft_ls_result, fs_result)
Loading