Skip to content

Commit

Permalink
adding changes from #3297
Browse files Browse the repository at this point in the history
  • Loading branch information
terrancedejesus committed Dec 1, 2023
1 parent 3b9128c commit 1662378
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 18 deletions.
3 changes: 2 additions & 1 deletion detection_rules/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,9 @@ def get_summary_rule_info(r: TOMLRule):
r = r.contents
rule_str = f'{r.name:<{longest_name}} (v:{r.autobumped_version} t:{r.data.type}'
if isinstance(rule.contents.data, QueryRuleData):
index = rule.contents.data.get('index') or []
rule_str += f'-{r.data.language}'
rule_str += f'(indexes:{"".join(index_map[idx] for idx in rule.contents.data.index) or "none"}'
rule_str += f'(indexes:{"".join(index_map[idx] for idx in index) or "none"}'

return rule_str

Expand Down
15 changes: 7 additions & 8 deletions detection_rules/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def get_required_fields(self, index: str) -> List[dict]:
return validator.get_required_fields(index or [])

@validates_schema
def validate_exceptions(self, data, **kwargs):
def validate_query_data(self, data, **kwargs):
"""Custom validation for query rule type and subclasses."""

# alert suppression is only valid for query rule type and not any of its subclasses
Expand Down Expand Up @@ -717,18 +717,17 @@ def interval_ratio(self) -> Optional[float]:


@dataclass(frozen=True)
class ESQLRuleData(BaseRuleData):
class ESQLRuleData(QueryRuleData):
"""ESQL rules are a special case of query rules."""
type: Literal["esql"]
language: Literal["esql"]
query: str

@cached_property
def validator(self) -> Optional[QueryValidator]:
return ESQLValidator(self.query)

def validate_query(self, meta: RuleMeta) -> None:
return self.validator.validate(self, meta)
@validates_schema
def validate_esql_data(self, data, **kwargs):
"""Custom validation for esql rule type."""
if data.get('index'):
raise ValidationError("Index is not valid for esql rule type.")


@dataclass(frozen=True)
Expand Down
5 changes: 3 additions & 2 deletions detection_rules/rule_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from esql.esql_listener import ESQLErrorListener, ESQLValidatorListener
from esql.EsqlBaseLexer import EsqlBaseLexer
from esql.EsqlBaseParser import EsqlBaseParser
from esql.utils import get_node
from esql.utils import get_node, pretty_print_tree

from . import ecs, endgame
from .integrations import (get_integration_schema_data,
Expand Down Expand Up @@ -435,7 +435,8 @@ def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None:

self.run_walker(EsqlBaseParser.BooleanDefaultContext) # TODO: Walk entire tree?
# TODO: Pass event dataset to related integrations workflow
# pretty_print_tree(tree)
tree = self.ast
pretty_print_tree(tree)

# get event datasets
self.event_datasets = self.listener.event_datasets
Expand Down
19 changes: 17 additions & 2 deletions esql/esql_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def __init__(self, schema: dict = {}):
def enterQualifiedName(self, ctx: EsqlBaseParser.QualifiedNameContext): # noqa: N802
"""Extract field from context (ctx)."""

# TODO: we need to check if a field can be set in any processing command and ignore these parents
if not isinstance(ctx.parentCtx, EsqlBaseParser.EvalCommandContext) and \
not isinstance(ctx.parentCtx, EsqlBaseParser.MetadataContext):
not isinstance(ctx.parentCtx, EsqlBaseParser.MetadataContext) and \
not isinstance(ctx.parentCtx.parentCtx.parentCtx, EsqlBaseParser.StatsCommandContext):
field = ctx.getText()
self.field_list.append(field)

Expand All @@ -46,7 +48,8 @@ def enterSourceIdentifier(self, ctx: EsqlBaseParser.SourceIdentifierContext): #

# Check if the parent context is NOT 'FromCommandContext'
if not isinstance(ctx.parentCtx, EsqlBaseParser.FromCommandContext) and \
not isinstance(ctx.parentCtx, EsqlBaseParser.MetadataContext):
not isinstance(ctx.parentCtx, EsqlBaseParser.MetadataContext) and \
not isinstance(ctx.parentCtx.parentCtx.parentCtx, EsqlBaseParser.StatsCommandContext):
# Extract field from context (ctx)
# The implementation depends on your parse tree structure
# For example, if the field name is directly the text of this context:
Expand All @@ -59,6 +62,18 @@ def enterSourceIdentifier(self, ctx: EsqlBaseParser.SourceIdentifierContext): #
# check index against integrations?
self.indices.append(ctx.getText())

def enterFromCommand(self, ctx: EsqlBaseParser.FromCommandContext): # noqa: N802
"""Override entry method for FromCommandContext."""

# check if metadata is present for rule type
# metadata_node = get_node(ctx, EsqlBaseParser.MetadataContext)
# if not metadata_node:
# composite_ctx = ctx.parentCtx.parentCtx.parentCtx
# processing_ctx = get_node(composite_ctx, EsqlBaseParser.ProcessingCommandContext)
# stats_ctx = get_node(processing_ctx[0], EsqlBaseParser.StatsCommandContext)
# if not stats_ctx:
# raise ESQLSemanticError(f"Missing metadata for ES|QL query with no stats command")

def check_literal_type(self, ctx: ParserRuleContext):
"""Check the type of a literal against the schema."""
field, context_type = self.find_associated_field_and_context(ctx)
Expand Down
15 changes: 10 additions & 5 deletions tests/test_all_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def test_required_tags(self):
missing_required_tags = set()

if isinstance(rule.contents.data, QueryRuleData):
for index in rule.contents.data.index:
for index in rule.contents.data.get('index') or []:
expected_tags = required_tags_map.get(index, {})
expected_all = expected_tags.get('all', [])
expected_any = expected_tags.get('any', [])
Expand Down Expand Up @@ -611,6 +611,9 @@ def test_integration_tag(self):
valid_integration_folders = [p.name for p in list(Path(INTEGRATION_RULE_DIR).glob("*")) if p.name != 'endpoint']

for rule in self.production_rules:
# TODO: temp bypass for esql rules; once parsed, we should be able to look for indexes via `FROM`
if not rule.contents.data.get('index'):
continue
if isinstance(rule.contents.data, QueryRuleData) and rule.contents.data.language != 'lucene':
rule_integrations = rule.contents.metadata.get('integration') or []
rule_integrations = [rule_integrations] if isinstance(rule_integrations, str) else rule_integrations
Expand All @@ -619,7 +622,7 @@ def test_integration_tag(self):
meta = rule.contents.metadata
package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest)
package_integrations_list = list(set([integration["package"] for integration in package_integrations]))
indices = data.get('index')
indices = data.get('index') or []
for rule_integration in rule_integrations:
if ("even.dataset" in rule.contents.data.query and not package_integrations and # noqa: W504
not rule_promotion and rule_integration not in definitions.NON_DATASET_PACKAGES): # noqa: W504
Expand Down Expand Up @@ -812,12 +815,14 @@ def build_rule(query: str, query_language: str):

def test_event_dataset(self):
for rule in self.all_rules:
if(isinstance(rule.contents.data, QueryRuleData)):
if isinstance(rule.contents.data, QueryRuleData):
# Need to pick validator based on language
if rule.contents.data.language == "kuery":
test_validator = KQLValidator(rule.contents.data.query)
if rule.contents.data.language == "eql":
elif rule.contents.data.language == "eql":
test_validator = EQLValidator(rule.contents.data.query)
else:
continue
data = rule.contents.data
meta = rule.contents.metadata
if meta.query_schema_validation is not False or meta.maturity != "deprecated":
Expand All @@ -833,7 +838,7 @@ def test_event_dataset(self):
meta,
pkg_integrations)

if(validation_integrations_check and "event.dataset" in rule.contents.data.query):
if validation_integrations_check and "event.dataset" in rule.contents.data.query:
raise validation_integrations_check


Expand Down

0 comments on commit 1662378

Please sign in to comment.