Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust ESQLRuleData to Inherit QueryRuleData Dataclass #3297

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion detection_rules/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,9 @@ def get_summary_rule_info(r: TOMLRule):
r = r.contents
rule_str = f'{r.name:<{longest_name}} (v:{r.autobumped_version} t:{r.data.type}'
if isinstance(rule.contents.data, QueryRuleData):
index = rule.contents.data.get("index") or []
rule_str += f'-{r.data.language}'
rule_str += f'(indexes:{"".join(index_map[idx] for idx in rule.contents.data.index) or "none"}'
rule_str += f'(indexes:{"".join(index_map[idx] for idx in index) or "none"}'

return rule_str

Expand Down
17 changes: 9 additions & 8 deletions detection_rules/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ def validator(self) -> Optional[QueryValidator]:
return KQLValidator(self.query)
elif self.language == "eql":
return EQLValidator(self.query)
elif self.language == "esql":
return ESQLValidator(self.query)

def validate_query(self, meta: RuleMeta) -> None:
validator = self.validator
Expand All @@ -594,7 +596,7 @@ def get_required_fields(self, index: str) -> List[dict]:
return validator.get_required_fields(index or [])

@validates_schema
def validate_exceptions(self, data, **kwargs):
def validates_query_data(self, data, **kwargs):
"""Custom validation for query rule type and subclasses."""

# alert suppression is only valid for query rule type and not any of its subclasses
Expand All @@ -603,18 +605,17 @@ def validate_exceptions(self, data, **kwargs):


@dataclass(frozen=True)
class ESQLRuleData(BaseRuleData):
class ESQLRuleData(QueryRuleData):
"""ESQL rules are a special case of query rules."""
type: Literal["esql"]
language: Literal["esql"]
query: str

@cached_property
def validator(self) -> Optional[QueryValidator]:
return ESQLValidator(self.query)

def validate_query(self, meta: RuleMeta) -> None:
return self.validator.validate(self, meta)
@validates_schema
def validate_esql_data(self, data, **kwargs):
"""Custom validation for esql rule type."""
if data.get('index'):
raise ValidationError("Index is not valid for esql rule type.")


@dataclass(frozen=True)
Expand Down
6 changes: 3 additions & 3 deletions detection_rules/rule_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,13 @@ def ast(self):
@cached_property
def unique_fields(self) -> List[str]:
"""Return a list of unique fields in the query."""
# return empty list for ES|QL rules until ast is available
# return empty list for ES|QL rules until ast is available (friendlier than raising error)
# raise NotImplementedError('ES|QL query parsing not yet supported')
return []

def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:
"""Validate an ESQL query while checking TOMLRule."""
print("Warning: ESQL queries are not validated at this time.")
return None
# temporarily override to NOP until ES|QL query parsing is supported


def extract_error_field(exc: Union[eql.EqlParseError, kql.KqlParseError]) -> Optional[str]:
Expand Down
15 changes: 10 additions & 5 deletions tests/test_all_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def test_required_tags(self):
missing_required_tags = set()

if isinstance(rule.contents.data, QueryRuleData):
for index in rule.contents.data.index:
for index in rule.contents.data.get('index') or []:
expected_tags = required_tags_map.get(index, {})
expected_all = expected_tags.get('all', [])
expected_any = expected_tags.get('any', [])
Expand Down Expand Up @@ -611,6 +611,9 @@ def test_integration_tag(self):
valid_integration_folders = [p.name for p in list(Path(INTEGRATION_RULE_DIR).glob("*")) if p.name != 'endpoint']

for rule in self.production_rules:
# TODO: temp bypass for esql rules; once parsed, we should be able to look for indexes via `FROM`
if not rule.contents.data.get('index'):
continue
if isinstance(rule.contents.data, QueryRuleData) and rule.contents.data.language != 'lucene':
rule_integrations = rule.contents.metadata.get('integration') or []
rule_integrations = [rule_integrations] if isinstance(rule_integrations, str) else rule_integrations
Expand All @@ -619,7 +622,7 @@ def test_integration_tag(self):
meta = rule.contents.metadata
package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest)
package_integrations_list = list(set([integration["package"] for integration in package_integrations]))
indices = data.get('index')
indices = data.get('index') or []
for rule_integration in rule_integrations:
if ("even.dataset" in rule.contents.data.query and not package_integrations and # noqa: W504
not rule_promotion and rule_integration not in definitions.NON_DATASET_PACKAGES): # noqa: W504
Expand Down Expand Up @@ -812,12 +815,14 @@ def build_rule(query: str, query_language: str):

def test_event_dataset(self):
for rule in self.all_rules:
if(isinstance(rule.contents.data, QueryRuleData)):
if isinstance(rule.contents.data, QueryRuleData):
# Need to pick validator based on language
if rule.contents.data.language == "kuery":
test_validator = KQLValidator(rule.contents.data.query)
if rule.contents.data.language == "eql":
elif rule.contents.data.language == "eql":
test_validator = EQLValidator(rule.contents.data.query)
else:
continue
data = rule.contents.data
meta = rule.contents.metadata
if meta.query_schema_validation is not False or meta.maturity != "deprecated":
Expand All @@ -833,7 +838,7 @@ def test_event_dataset(self):
meta,
pkg_integrations)

if(validation_integrations_check and "event.dataset" in rule.contents.data.query):
if validation_integrations_check and "event.dataset" in rule.contents.data.query:
raise validation_integrations_check


Expand Down
Loading