From 1662378090da47c2e644775b21b5413d4c7c329d Mon Sep 17 00:00:00 2001 From: terrancedejesus Date: Fri, 1 Dec 2023 10:23:45 -0500 Subject: [PATCH] adding changes from #3297 --- detection_rules/packaging.py | 3 ++- detection_rules/rule.py | 15 +++++++-------- detection_rules/rule_validators.py | 5 +++-- esql/esql_listener.py | 19 +++++++++++++++++-- tests/test_all_rules.py | 15 ++++++++++----- 5 files changed, 39 insertions(+), 18 deletions(-) diff --git a/detection_rules/packaging.py b/detection_rules/packaging.py index 6251c68b183..f145e41b8da 100644 --- a/detection_rules/packaging.py +++ b/detection_rules/packaging.py @@ -277,8 +277,9 @@ def get_summary_rule_info(r: TOMLRule): r = r.contents rule_str = f'{r.name:<{longest_name}} (v:{r.autobumped_version} t:{r.data.type}' if isinstance(rule.contents.data, QueryRuleData): + index = rule.contents.data.get('index') or [] rule_str += f'-{r.data.language}' - rule_str += f'(indexes:{"".join(index_map[idx] for idx in rule.contents.data.index) or "none"}' + rule_str += f'(indexes:{"".join(index_map[idx] for idx in index) or "none"}' return rule_str diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 2a1eb5366a5..a35031f8f28 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -596,7 +596,7 @@ def get_required_fields(self, index: str) -> List[dict]: return validator.get_required_fields(index or []) @validates_schema - def validate_exceptions(self, data, **kwargs): + def validate_query_data(self, data, **kwargs): """Custom validation for query rule type and subclasses.""" # alert suppression is only valid for query rule type and not any of its subclasses @@ -717,18 +717,17 @@ def interval_ratio(self) -> Optional[float]: @dataclass(frozen=True) -class ESQLRuleData(BaseRuleData): +class ESQLRuleData(QueryRuleData): """ESQL rules are a special case of query rules.""" type: Literal["esql"] language: Literal["esql"] query: str - @cached_property - def validator(self) -> Optional[QueryValidator]: - return ESQLValidator(self.query) - - def validate_query(self, meta: RuleMeta) -> None: - return self.validator.validate(self, meta) + @validates_schema + def validate_esql_data(self, data, **kwargs): + """Custom validation for esql rule type.""" + if data.get('index'): + raise ValidationError("Index is not valid for esql rule type.") @dataclass(frozen=True) diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py index a45f93749f8..1c561cf62a1 100644 --- a/detection_rules/rule_validators.py +++ b/detection_rules/rule_validators.py @@ -18,7 +18,7 @@ from esql.esql_listener import ESQLErrorListener, ESQLValidatorListener from esql.EsqlBaseLexer import EsqlBaseLexer from esql.EsqlBaseParser import EsqlBaseParser -from esql.utils import get_node +from esql.utils import get_node, pretty_print_tree from . import ecs, endgame from .integrations import (get_integration_schema_data, @@ -435,7 +435,8 @@ def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: self.run_walker(EsqlBaseParser.BooleanDefaultContext) # TODO: Walk entire tree? # TODO: Pass event dataset to related integrations workflow - # pretty_print_tree(tree) + tree = self.ast + pretty_print_tree(tree) # get event datasets self.event_datasets = self.listener.event_datasets diff --git a/esql/esql_listener.py b/esql/esql_listener.py index a48f54a9f46..6b3611bd44d 100644 --- a/esql/esql_listener.py +++ b/esql/esql_listener.py @@ -33,8 +33,10 @@ def __init__(self, schema: dict = {}): def enterQualifiedName(self, ctx: EsqlBaseParser.QualifiedNameContext): # noqa: N802 """Extract field from context (ctx).""" + # TODO: we need to check if a field can be set in any processing command and ignore these parents if not isinstance(ctx.parentCtx, EsqlBaseParser.EvalCommandContext) and \ - not isinstance(ctx.parentCtx, EsqlBaseParser.MetadataContext): + not isinstance(ctx.parentCtx, EsqlBaseParser.MetadataContext) and \ + not isinstance(ctx.parentCtx.parentCtx.parentCtx, EsqlBaseParser.StatsCommandContext): field = ctx.getText() self.field_list.append(field) @@ -46,7 +48,8 @@ def enterSourceIdentifier(self, ctx: EsqlBaseParser.SourceIdentifierContext): # # Check if the parent context is NOT 'FromCommandContext' if not isinstance(ctx.parentCtx, EsqlBaseParser.FromCommandContext) and \ - not isinstance(ctx.parentCtx, EsqlBaseParser.MetadataContext): + not isinstance(ctx.parentCtx, EsqlBaseParser.MetadataContext) and \ + not isinstance(ctx.parentCtx.parentCtx.parentCtx, EsqlBaseParser.StatsCommandContext): # Extract field from context (ctx) # The implementation depends on your parse tree structure # For example, if the field name is directly the text of this context: @@ -59,6 +62,18 @@ def enterSourceIdentifier(self, ctx: EsqlBaseParser.SourceIdentifierContext): # # check index against integrations? self.indices.append(ctx.getText()) + def enterFromCommand(self, ctx: EsqlBaseParser.FromCommandContext): # noqa: N802 + """Override entry method for FromCommandContext.""" + + # check if metadata is present for rule type + # metadata_node = get_node(ctx, EsqlBaseParser.MetadataContext) + # if not metadata_node: + # composite_ctx = ctx.parentCtx.parentCtx.parentCtx + # processing_ctx = get_node(composite_ctx, EsqlBaseParser.ProcessingCommandContext) + # stats_ctx = get_node(processing_ctx[0], EsqlBaseParser.StatsCommandContext) + # if not stats_ctx: + # raise ESQLSemanticError(f"Missing metadata for ES|QL query with no stats command") + def check_literal_type(self, ctx: ParserRuleContext): """Check the type of a literal against the schema.""" field, context_type = self.find_associated_field_and_context(ctx) diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index d6f28d68e1b..fda4aa0e422 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -296,7 +296,7 @@ def test_required_tags(self): missing_required_tags = set() if isinstance(rule.contents.data, QueryRuleData): - for index in rule.contents.data.index: + for index in rule.contents.data.get('index') or []: expected_tags = required_tags_map.get(index, {}) expected_all = expected_tags.get('all', []) expected_any = expected_tags.get('any', []) @@ -611,6 +611,9 @@ def test_integration_tag(self): valid_integration_folders = [p.name for p in list(Path(INTEGRATION_RULE_DIR).glob("*")) if p.name != 'endpoint'] for rule in self.production_rules: + # TODO: temp bypass for esql rules; once parsed, we should be able to look for indexes via `FROM` + if not rule.contents.data.get('index'): + continue if isinstance(rule.contents.data, QueryRuleData) and rule.contents.data.language != 'lucene': rule_integrations = rule.contents.metadata.get('integration') or [] rule_integrations = [rule_integrations] if isinstance(rule_integrations, str) else rule_integrations @@ -619,7 +622,7 @@ def test_integration_tag(self): meta = rule.contents.metadata package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) package_integrations_list = list(set([integration["package"] for integration in package_integrations])) - indices = data.get('index') + indices = data.get('index') or [] for rule_integration in rule_integrations: if ("even.dataset" in rule.contents.data.query and not package_integrations and # noqa: W504 not rule_promotion and rule_integration not in definitions.NON_DATASET_PACKAGES): # noqa: W504 @@ -812,12 +815,14 @@ def build_rule(query: str, query_language: str): def test_event_dataset(self): for rule in self.all_rules: - if(isinstance(rule.contents.data, QueryRuleData)): + if isinstance(rule.contents.data, QueryRuleData): # Need to pick validator based on language if rule.contents.data.language == "kuery": test_validator = KQLValidator(rule.contents.data.query) - if rule.contents.data.language == "eql": + elif rule.contents.data.language == "eql": test_validator = EQLValidator(rule.contents.data.query) + else: + continue data = rule.contents.data meta = rule.contents.metadata if meta.query_schema_validation is not False or meta.maturity != "deprecated": @@ -833,7 +838,7 @@ def test_event_dataset(self): meta, pkg_integrations) - if(validation_integrations_check and "event.dataset" in rule.contents.data.query): + if validation_integrations_check and "event.dataset" in rule.contents.data.query: raise validation_integrations_check