Merge pull request #1235 from mandiant/fix/issue-1234
stricter mypy checking
williballenthin authored Dec 15, 2022
2 parents 655c45d + 505910e commit ad47ea3
Showing 35 changed files with 408 additions and 263 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -40,7 +40,7 @@ jobs:
- name: Lint with pycodestyle
run: pycodestyle --show-source capa/ scripts/ tests/
- name: Check types with mypy
run: mypy --config-file .github/mypy/mypy.ini capa/ scripts/ tests/
run: mypy --config-file .github/mypy/mypy.ini --check-untyped-defs capa/ scripts/ tests/

rule_linter:
runs-on: ubuntu-20.04
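By default mypy skips the bodies of functions that have no type annotations; the new --check-untyped-defs flag makes it check those bodies too, which is what drives the annotations and asserts added in the rest of this commit. A minimal sketch (hypothetical code, not from capa) of the kind of mistake the flag starts catching:

    def describe(features):
        # without --check-untyped-defs, mypy ignores this unannotated body entirely;
        # with the flag it checks the body and flags len(42) as an incompatible
        # argument type (an int is not Sized).
        return "found %d features" % len(42)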
15 changes: 11 additions & 4 deletions capa/engine.py
@@ -8,7 +8,7 @@

import copy
import collections
from typing import TYPE_CHECKING, Set, Dict, List, Tuple, Mapping, Iterable
from typing import TYPE_CHECKING, Set, Dict, List, Tuple, Union, Mapping, Iterable, Iterator, cast

import capa.perf
import capa.features.common
@@ -60,17 +60,24 @@ def evaluate(self, features: FeatureSet, short_circuit=True) -> Result:
"""
raise NotImplementedError()

def get_children(self):
def get_children(self) -> Iterator[Union["Statement", Feature]]:
if hasattr(self, "child"):
yield self.child
# this really confuses mypy because the property may not exist
# since it's defined in the subclasses.
child = self.child # type: ignore
assert isinstance(child, (Statement, Feature))
yield child

if hasattr(self, "children"):
for child in getattr(self, "children"):
assert isinstance(child, (Statement, Feature))
yield child

def replace_child(self, existing, new):
if hasattr(self, "child"):
if self.child is existing:
# this really confuses mypy because the property may not exist
# since it's defined in the subclasses.
if self.child is existing: # type: ignore
self.child = new

if hasattr(self, "children"):
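The type: ignore hints above are needed because child and children are defined only on Statement subclasses, so mypy cannot see them on the base class even behind a hasattr() guard. A self-contained sketch of that situation (simplified classes, loosely modeled on the statements above, not capa's actual hierarchy):

    from typing import Iterator, List


    class Node:
        """base class; subclasses may define child or children."""

        def get_children(self) -> Iterator["Node"]:
            if hasattr(self, "child"):
                # mypy does not narrow types based on hasattr(), so this access
                # errors ("Node" has no attribute "child") without the hint.
                yield self.child  # type: ignore
            if hasattr(self, "children"):
                for child in getattr(self, "children"):
                    assert isinstance(child, Node)
                    yield child


    class Negation(Node):
        def __init__(self, child: Node):
            self.child = child


    class Conjunction(Node):
        def __init__(self, children: List[Node]):
            self.children = children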
39 changes: 19 additions & 20 deletions capa/features/common.py
@@ -9,6 +9,7 @@
import re
import abc
import codecs
import typing
import logging
import collections
from typing import TYPE_CHECKING, Set, Dict, List, Union, Optional
@@ -200,8 +201,9 @@ def evaluate(self, ctx, short_circuit=True):

# mapping from string value to list of locations.
# will unique the locations later on.
matches = collections.defaultdict(list)
matches: typing.DefaultDict[str, Set[Address]] = collections.defaultdict(set)

assert isinstance(self.value, str)
for feature, locations in ctx.items():
if not isinstance(feature, (String,)):
continue
@@ -211,31 +213,27 @@
raise ValueError("unexpected feature value type")

if self.value in feature.value:
matches[feature.value].extend(locations)
matches[feature.value].update(locations)
if short_circuit:
# we found one matching string, that's sufficient to match.
# don't collect other matching strings in this mode.
break

if matches:
# finalize: defaultdict -> dict
# which makes json serialization easier
matches = dict(matches)

# collect all locations
locations = set()
for s in matches.keys():
matches[s] = list(set(matches[s]))
locations.update(matches[s])
for locs in matches.values():
locations.update(locs)

# unlike other features, we cannot put a reference to `self` directly in a `Result`.
# this is because `self` may match on many strings, so we can't stuff the matched value into it.
# instead, return a new instance that has a reference to both the substring and the matched values.
return Result(True, _MatchedSubstring(self, matches), [], locations=locations)
return Result(True, _MatchedSubstring(self, dict(matches)), [], locations=locations)
else:
return Result(False, _MatchedSubstring(self, {}), [])

def __str__(self):
assert isinstance(self.value, str)
return "substring(%s)" % self.value


@@ -261,6 +259,7 @@ def __init__(self, substring: Substring, matches: Dict[str, Set[Address]]):
self.matches = matches

def __str__(self):
assert isinstance(self.value, str)
return 'substring("%s", matches = %s)' % (
self.value,
", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())),
@@ -292,7 +291,7 @@ def evaluate(self, ctx, short_circuit=True):

# mapping from string value to list of locations.
# will unique the locations later on.
matches = collections.defaultdict(list)
matches: typing.DefaultDict[str, Set[Address]] = collections.defaultdict(set)

for feature, locations in ctx.items():
if not isinstance(feature, (String,)):
@@ -307,32 +306,28 @@
# using this mode is more convenient for rule authors,
# so that they don't have to prefix/suffix their terms like: /.*foo.*/.
if self.re.search(feature.value):
matches[feature.value].extend(locations)
matches[feature.value].update(locations)
if short_circuit:
# we found one matching string, that's sufficient to match.
# don't collect other matching strings in this mode.
break

if matches:
# finalize: defaultdict -> dict
# which makes json serialization easier
matches = dict(matches)

# collect all locations
locations = set()
for s in matches.keys():
matches[s] = list(set(matches[s]))
locations.update(matches[s])
for locs in matches.values():
locations.update(locs)

# unlike other features, we cannot put a reference to `self` directly in a `Result`.
# this is because `self` may match on many strings, so we can't stuff the matched value into it.
# instead, return a new instance that has a reference to both the regex and the matched values.
# see #262.
return Result(True, _MatchedRegex(self, matches), [], locations=locations)
return Result(True, _MatchedRegex(self, dict(matches)), [], locations=locations)
else:
return Result(False, _MatchedRegex(self, {}), [])

def __str__(self):
assert isinstance(self.value, str)
return "regex(string =~ %s)" % self.value


@@ -358,6 +353,7 @@ def __init__(self, regex: Regex, matches: Dict[str, Set[Address]]):
self.matches = matches

def __str__(self):
assert isinstance(self.value, str)
return "regex(string =~ %s, matches = %s)" % (
self.value,
", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())),
@@ -380,16 +376,19 @@ def evaluate(self, ctx, **kwargs):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.bytes"] += 1

assert isinstance(self.value, bytes)
for feature, locations in ctx.items():
if not isinstance(feature, (Bytes,)):
continue

assert isinstance(feature.value, bytes)
if feature.value.startswith(self.value):
return Result(True, self, [], locations=locations)

return Result(False, self, [])

def get_value_str(self):
assert isinstance(self.value, bytes)
return hex_string(bytes_to_str(self.value))


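The matches-related changes above follow one pattern: annotate the defaultdict when it is created, accumulate locations into sets, and convert to a plain dict only where the value is consumed rather than reassigning the same variable. A short sketch of that pattern (Address here is a stand-in alias, not capa's real address type):

    import collections
    from typing import DefaultDict, Dict, List, Set, Tuple

    Address = int  # hypothetical stand-in for capa's Address


    def collect_matches(hits: List[Tuple[str, Address]]) -> Dict[str, Set[Address]]:
        matches: DefaultDict[str, Set[Address]] = collections.defaultdict(set)
        for value, location in hits:
            matches[value].add(location)
        # reassigning `matches = dict(matches)` (as the old code did) is now rejected
        # by mypy: a Dict cannot be assigned to a variable annotated as DefaultDict.
        return dict(matches)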
24 changes: 24 additions & 0 deletions capa/features/extractors/dnfile/helpers.py
@@ -109,6 +109,9 @@ def format_name(module, method):

def resolve_dotnet_token(pe: dnfile.dnPE, token: Token) -> Any:
"""map generic token to string or table row"""
assert pe.net is not None
assert pe.net.mdtables is not None

if isinstance(token, StringToken):
user_string: Optional[str] = read_dotnet_user_string(pe, token)
if user_string is None:
@@ -143,6 +146,9 @@ def read_dotnet_method_body(pe: dnfile.dnPE, row: dnfile.mdtable.MethodDefRow) -

def read_dotnet_user_string(pe: dnfile.dnPE, token: StringToken) -> Optional[str]:
"""read user string from #US stream"""
assert pe.net is not None
assert pe.net.user_strings is not None

try:
user_string: Optional[dnfile.stream.UserString] = pe.net.user_strings.get_us(token.rid)
except UnicodeDecodeError as e:
@@ -169,6 +175,10 @@ def get_dotnet_managed_imports(pe: dnfile.dnPE) -> Iterator[DnType]:
TypeName (index into String heap)
TypeNamespace (index into String heap)
"""
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.MemberRef is not None

for (rid, row) in enumerate(iter_dotnet_table(pe, "MemberRef")):
if not isinstance(row.Class.row, dnfile.mdtable.TypeRefRow):
continue
@@ -258,6 +268,10 @@ def get_dotnet_properties(pe: dnfile.dnPE) -> Iterator[DnType]:

def get_dotnet_managed_method_bodies(pe: dnfile.dnPE) -> Iterator[Tuple[int, CilMethodBody]]:
"""get managed methods from MethodDef table"""
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.MethodDef is not None

if not hasattr(pe.net.mdtables, "MethodDef"):
return

@@ -307,15 +321,25 @@ def calculate_dotnet_token_value(table: int, rid: int) -> int:


def is_dotnet_table_valid(pe: dnfile.dnPE, table_name: str) -> bool:
assert pe.net is not None
assert pe.net.mdtables is not None

return bool(getattr(pe.net.mdtables, table_name, None))


def is_dotnet_mixed_mode(pe: dnfile.dnPE) -> bool:
assert pe.net is not None
assert pe.net.Flags is not None

return not bool(pe.net.Flags.CLR_ILONLY)


def iter_dotnet_table(pe: dnfile.dnPE, name: str) -> Iterator[Any]:
assert pe.net is not None
assert pe.net.mdtables is not None

if not is_dotnet_table_valid(pe, name):
return

for row in getattr(pe.net.mdtables, name):
yield row
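dnfile declares most of the attributes touched above as Optional, since parsing can fail partway, so each helper now narrows them with asserts before use. A minimal, self-contained sketch of the same narrowing (hypothetical classes, not dnfile's actual API):

    from typing import Optional


    class Metadata:
        version: bytes = b"v4.0.30319\x00"


    class ClrDirectory:
        metadata: Optional[Metadata] = None  # None when the stream failed to parse


    class ParsedPE:
        net: Optional[ClrDirectory] = None  # None for non-.NET binaries


    def get_version(pe: ParsedPE) -> bytes:
        # after these asserts mypy treats pe.net and pe.net.metadata as non-None,
        # so the attribute accesses type-check; at runtime a malformed file fails
        # loudly here instead of raising AttributeError somewhere deeper.
        assert pe.net is not None
        assert pe.net.metadata is not None
        return pe.net.metadata.version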
26 changes: 24 additions & 2 deletions capa/features/extractors/dnfile_.py
@@ -19,9 +19,12 @@ def extract_file_os(**kwargs) -> Iterator[Tuple[Feature, Address]]:
yield OS(OS_ANY), NO_ADDRESS


def extract_file_arch(pe, **kwargs) -> Iterator[Tuple[Feature, Address]]:
def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Feature, Address]]:
# to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020
# .NET 4.5 added option: any CPU, 32-bit preferred
assert pe.net is not None
assert pe.net.Flags is not None

if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE:
yield Arch(ARCH_I386), NO_ADDRESS
elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS:
@@ -71,6 +74,9 @@ def get_entry_point(self) -> int:
# self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT
# True: native EP: Token
# False: managed EP: RVA
assert self.pe.net is not None
assert self.pe.net.struct is not None

return self.pe.net.struct.EntryPointTokenOrRva

def extract_global_features(self):
@@ -83,13 +89,29 @@ def is_dotnet_file(self) -> bool:
return bool(self.pe.net)

def is_mixed_mode(self) -> bool:
assert self.pe is not None
assert self.pe.net is not None
assert self.pe.net.Flags is not None

return not bool(self.pe.net.Flags.CLR_ILONLY)

def get_runtime_version(self) -> Tuple[int, int]:
assert self.pe is not None
assert self.pe.net is not None
assert self.pe.net.struct is not None

return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion

def get_meta_version_string(self) -> str:
return self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8")
assert self.pe.net is not None
assert self.pe.net.metadata is not None
assert self.pe.net.metadata.struct is not None
assert self.pe.net.metadata.struct.Version is not None

vbuf = self.pe.net.metadata.struct.Version
assert isinstance(vbuf, bytes)

return vbuf.rstrip(b"\x00").decode("utf-8")

def get_functions(self):
raise NotImplementedError("DnfileFeatureExtractor can only be used to extract file features")
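get_meta_version_string now binds the raw struct field to a local and asserts its concrete type before decoding, because the parser exposes the field with a loose type and mypy cannot otherwise verify the bytes methods called on it. A sketch of that idiom (the object parameter stands in for whatever loose type the field actually has):

    def decode_version(raw: object) -> str:
        vbuf = raw
        # without the assert mypy rejects .rstrip() on the un-narrowed value;
        # isinstance() narrows it to bytes and also guards the runtime behavior.
        assert isinstance(vbuf, bytes)
        return vbuf.rstrip(b"\x00").decode("utf-8")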
26 changes: 25 additions & 1 deletion capa/features/extractors/dotnetfile.py
@@ -78,6 +78,11 @@ def extract_file_namespace_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple

def extract_file_class_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Class, Address]]:
"""emit class features from TypeRef and TypeDef tables"""
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.TypeDef is not None
assert pe.net.mdtables.TypeRef is not None

for (rid, row) in enumerate(iter_dotnet_table(pe, "TypeDef")):
token = calculate_dotnet_token_value(pe.net.mdtables.TypeDef.number, rid + 1)
yield Class(DnType.format_name(row.TypeName, namespace=row.TypeNamespace)), DNTokenAddress(token)
@@ -94,6 +99,9 @@ def extract_file_os(**kwargs) -> Iterator[Tuple[OS, Address]]:
def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Arch, Address]]:
# to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020
# .NET 4.5 added option: any CPU, 32-bit preferred
assert pe.net is not None
assert pe.net.Flags is not None

if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE:
yield Arch(ARCH_I386), NO_ADDRESS
elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS:
@@ -155,6 +163,9 @@ def get_entry_point(self) -> int:
# self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT
# True: native EP: Token
# False: managed EP: RVA
assert self.pe.net is not None
assert self.pe.net.struct is not None

return self.pe.net.struct.EntryPointTokenOrRva

def extract_global_features(self):
@@ -170,10 +181,23 @@ def is_mixed_mode(self) -> bool:
return is_dotnet_mixed_mode(self.pe)

def get_runtime_version(self) -> Tuple[int, int]:
assert self.pe.net is not None
assert self.pe.net.struct is not None
assert self.pe.net.struct.MajorRuntimeVersion is not None
assert self.pe.net.struct.MinorRuntimeVersion is not None

return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion

def get_meta_version_string(self) -> str:
return self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8")
assert self.pe.net is not None
assert self.pe.net.metadata is not None
assert self.pe.net.metadata.struct is not None
assert self.pe.net.metadata.struct.Version is not None

vbuf = self.pe.net.metadata.struct.Version
assert isinstance(vbuf, bytes)

return vbuf.rstrip(b"\x00").decode("utf-8")

def get_functions(self):
raise NotImplementedError("DotnetFileFeatureExtractor can only be used to extract file features")
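For reference, extract_file_arch in both extractors keys off two inputs: the CLR header's 32BITREQUIRED flag and the PE optional-header magic. A hedged sketch of that decision; the PE32+ branch is cut off in the diff above, so mapping it to amd64 is an assumption here, not something the diff confirms:

    import pefile  # dnfile builds on pefile, which defines these magic constants


    def guess_arch(clr_32bit_required: bool, pe_magic: int) -> str:
        if clr_32bit_required and pe_magic == pefile.OPTIONAL_HEADER_MAGIC_PE:
            return "i386"   # PE32 image that requires a 32-bit runtime
        if not clr_32bit_required and pe_magic == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS:
            return "amd64"  # PE32+ image (assumed; branch body not shown above)
        return "any"        # other combinations (e.g. "Any CPU") left unresolved here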
(diffs for the remaining 29 changed files not loaded)
