pw_tokenizer: Refactor DatabaseFile
Split DatabaseFile into multiple classes, one for each type of
database.

Change-Id: I8b29a8315646d60849a4cff75c7eed29bb9d7c0a
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/103084
Reviewed-by: Wyatt Hepler <[email protected]>
Commit-Queue: Carlos Chinchilla <[email protected]>
Leo Acosta authored and CQ Bot Account committed Jul 22, 2022
1 parent 33fe972 commit 85ccddb
Showing 3 changed files with 68 additions and 39 deletions.
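
The refactor turns DatabaseFile into an abstract base: a static create() factory sniffs the file once and returns a private, format-specific subclass (_BinaryDatabase or _CSVDatabase), each of which knows how to rewrite itself. A minimal sketch of the call-site migration, assuming a hypothetical database path; everything else comes from the diffs below:

    from pathlib import Path
    from pw_tokenizer import tokens

    # Before this change: the constructor detected the format itself.
    #     db = tokens.DatabaseFile('strings_db.csv')
    # After: a factory returns a subclass that remembers its format.
    db = tokens.DatabaseFile.create(Path('strings_db.csv'))  # hypothetical path
    db.write_to_file()  # rewrites strings_db.csv in its original CSV format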
40 changes: 23 additions & 17 deletions pw_tokenizer/py/pw_tokenizer/database.py
@@ -28,8 +28,8 @@
 import re
 import struct
 import sys
-from typing import (Any, Callable, Dict, Iterable, Iterator, List, Pattern,
-                    Set, TextIO, Tuple, Union)
+from typing import (Any, Callable, Dict, Iterable, Iterator, List, Optional,
+                    Pattern, Set, TextIO, Tuple, Union)
 
 try:
     from pw_tokenizer import elf_reader, tokens
@@ -218,7 +218,7 @@ def _load_token_database(db, domain: Pattern[str]) -> tokens.Database:
                 return _database_from_json(json_fd)
 
         # Read the path as a packed binary or CSV file.
-        return tokens.DatabaseFile(db)
+        return tokens.DatabaseFile.create(Path(db))
 
     # Assume that it's a file object and check if it's an ELF.
     if elf_reader.compatible_file(db):
@@ -230,7 +230,7 @@ def _load_token_database(db, domain: Pattern[str]) -> tokens.Database:
         if db.name.endswith('.json'):
             return _database_from_json(db)
 
-        return tokens.DatabaseFile(db.name)
+        return tokens.DatabaseFile.create(Path(db.name))
 
     # Read CSV directly from the file object.
     return tokens.Database(tokens.parse_csv(db))
@@ -293,9 +293,8 @@ def generate_reports(paths: Iterable[Path]) -> _DatabaseReport:
 
 
 def _handle_create(databases, database, force, output_type, include, exclude,
-                   replace):
+                   replace) -> None:
     """Creates a token database file from one or more ELF files."""
-
     if database == '-':
         # Must write bytes to stdout; use sys.stdout.buffer.
         fd = sys.stdout.buffer
@@ -320,19 +319,24 @@ def _handle_create(databases, database, force, output_type, include, exclude,
               fd.name, output_type)
 
 
-def _handle_add(token_database, databases):
+def _handle_add(token_database: tokens.DatabaseFile,
+                databases: List[tokens.Database]) -> None:
     initial = len(token_database)
 
     for source in databases:
         token_database.add(source.entries())
 
+    number_of_entries_added = len(token_database) - initial
+
     token_database.write_to_file()
 
-    _LOG.info('Added %d entries to %s',
-              len(token_database) - initial, token_database.path)
+    _LOG.info('Added %d entries to %s', number_of_entries_added,
+              token_database.path)
 
 
-def _handle_mark_removed(token_database, databases, date):
+def _handle_mark_removed(token_database: tokens.DatabaseFile,
+                         databases: List[tokens.Database],
+                         date: Optional[datetime]):
     marked_removed = token_database.mark_removed(
         (entry for entry in tokens.Database.merged(*databases).entries()
          if not entry.date_removed), date)
@@ -343,7 +347,8 @@ def _handle_mark_removed(token_database, databases, date):
               len(token_database), token_database.path)
 
 
-def _handle_purge(token_database, before):
+def _handle_purge(token_database: tokens.DatabaseFile,
+                  before: Optional[datetime]):
     purged = token_database.purge(before)
     token_database.write_to_file()
 
@@ -460,12 +465,13 @@ def year_month_day(value) -> datetime:
 
     # Shared command line options.
     option_db = argparse.ArgumentParser(add_help=False)
-    option_db.add_argument('-d',
-                           '--database',
-                           dest='token_database',
-                           type=tokens.DatabaseFile,
-                           required=True,
-                           help='The database file to update.')
+    option_db.add_argument(
+        '-d',
+        '--database',
+        dest='token_database',
+        type=lambda arg: tokens.DatabaseFile.create(Path(arg)),
+        required=True,
+        help='The database file to update.')
 
     option_tokens = token_databases_parser('*')
 
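One subtlety in the hunk above: argparse passes the raw command-line string to whatever callable is given as type=, so type=tokens.DatabaseFile only worked while the constructor accepted a string. The replacement factory takes a Path, hence the lambda wrapper. A self-contained sketch of the same pattern (option names copied from the diff; using Path as the stand-in target is an assumption for illustration):

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d',
        '--database',
        dest='token_database',
        # argparse calls this with the raw string; wrap factories that
        # expect richer argument types. Path stands in here for
        # tokens.DatabaseFile.create(Path(arg)).
        type=lambda arg: Path(arg),
        required=True)

    args = parser.parse_args(['-d', 'strings_db.csv'])
    print(args.token_database)  # strings_db.csv as a Path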
59 changes: 41 additions & 18 deletions pw_tokenizer/py/pw_tokenizer/tokens.py
@@ -13,6 +13,7 @@
 # the License.
 """Builds and manages databases of tokenized strings."""
 
+from abc import abstractmethod
 import collections
 import csv
 from dataclasses import dataclass
@@ -23,7 +24,8 @@
 import re
 import struct
 from typing import (BinaryIO, Callable, Dict, Iterable, Iterator, List,
-                    NamedTuple, Optional, Pattern, Tuple, Union, ValuesView)
+                    NamedTuple, Optional, Pattern, TextIO, Tuple, Union,
+                    ValuesView)
 
 DATE_FORMAT = '%Y-%m-%d'
 DEFAULT_DOMAIN = ''
@@ -295,7 +297,7 @@ def __str__(self) -> str:
         return csv_output.getvalue().decode()
 
 
-def parse_csv(fd) -> Iterable[TokenizedStringEntry]:
+def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
     """Parses TokenizedStringEntries from a CSV token database file."""
     for line in csv.reader(fd):
         try:
@@ -445,23 +447,44 @@ class DatabaseFile(Database):
     This class adds the write_to_file() method that writes to file from which it
     was created in the correct format (CSV or binary).
     """
-    def __init__(self, path: Union[Path, str]):
-        self.path = Path(path)
-
+    def __init__(self, path: Path,
+                 entries: Iterable[TokenizedStringEntry]) -> None:
+        super().__init__(entries)
+        self.path = path
+
+    @staticmethod
+    def create(path: Path) -> 'DatabaseFile':
+        """Creates a DatabaseFile that coincides to the file type."""
         # Read the path as a packed binary file.
-        with self.path.open('rb') as fd:
+        with path.open('rb') as fd:
             if file_is_binary_database(fd):
-                super().__init__(parse_binary(fd))
-                self._export = write_binary
-                return
+                return _BinaryDatabase(path, fd)
 
         # Read the path as a CSV file.
-        _check_that_file_is_csv_database(self.path)
-        with self.path.open('r', newline='', encoding='utf-8') as file:
-            super().__init__(parse_csv(file))
-            self._export = write_csv
-
-    def write_to_file(self, path: Optional[Union[Path, str]] = None) -> None:
-        """Exports in the original format to the original or provided path."""
-        with open(self.path if path is None else path, 'wb') as fd:
-            self._export(self, fd)
+        _check_that_file_is_csv_database(path)
+        with path.open('r', newline='', encoding='utf-8') as csv_fd:
+            return _CSVDatabase(path, csv_fd)
+
+    @abstractmethod
+    def write_to_file(self) -> None:
+        """Exports in the original format to the original path."""
+
+
+class _BinaryDatabase(DatabaseFile):
+    def __init__(self, path: Path, fd: BinaryIO) -> None:
+        super().__init__(path, parse_binary(fd))
+
+    def write_to_file(self) -> None:
+        """Exports in the binary format to the original path."""
+        with self.path.open('wb') as fd:
+            write_binary(self, fd)
+
+
+class _CSVDatabase(DatabaseFile):
+    def __init__(self, path: Path, fd: TextIO) -> None:
+        super().__init__(path, parse_csv(fd))
+
+    def write_to_file(self) -> None:
+        """Exports in the csv format to the original path."""
+        with self.path.open('wb') as fd:
+            write_csv(self, fd)
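
Taken together, these changes implement a small factory-method pattern: DatabaseFile.create() performs format detection once, and the returned subclass's write_to_file() reuses that decision, so a database always round-trips in the format it was read from. A usage sketch, assuming a CSV database at a hypothetical path (the entry values are borrowed from the tests below):

    from pathlib import Path
    from pw_tokenizer import tokens

    db = tokens.DatabaseFile.create(Path('strings_db.csv'))  # -> _CSVDatabase
    db.add([tokens.TokenizedStringEntry(0xffffffff, 'New entry!')])
    db.write_to_file()  # writes back to strings_db.csv as CSV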
8 changes: 4 additions & 4 deletions pw_tokenizer/py/tokens_test.py
@@ -405,7 +405,7 @@ def tearDown(self):
 
     def test_update_csv_file(self):
         self._path.write_text(CSV_DATABASE)
-        db = tokens.DatabaseFile(self._path)
+        db = tokens.DatabaseFile.create(self._path)
         self.assertEqual(str(db), CSV_DATABASE)
 
         db.add([tokens.TokenizedStringEntry(0xffffffff, 'New entry!')])
@@ -419,19 +419,19 @@ def test_csv_file_too_short_raises_exception(self):
         self._path.write_text('1234')
 
         with self.assertRaises(tokens.DatabaseFormatError):
-            tokens.DatabaseFile(self._path)
+            tokens.DatabaseFile.create(self._path)
 
     def test_csv_invalid_format_raises_exception(self):
         self._path.write_text('MK34567890')
 
         with self.assertRaises(tokens.DatabaseFormatError):
-            tokens.DatabaseFile(self._path)
+            tokens.DatabaseFile.create(self._path)
 
     def test_csv_not_utf8(self):
         self._path.write_bytes(b'\x80' * 20)
 
         with self.assertRaises(tokens.DatabaseFormatError):
-            tokens.DatabaseFile(self._path)
+            tokens.DatabaseFile.create(self._path)
 
 
 class TestFilter(unittest.TestCase):
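All three failure tests exercise the same code path: when a file is neither a valid packed binary database nor well-formed CSV, DatabaseFile.create() raises tokens.DatabaseFormatError. A sketch of handling that at a call site (the path is hypothetical):

    from pathlib import Path
    from pw_tokenizer import tokens

    try:
        db = tokens.DatabaseFile.create(Path('maybe_a_database.bin'))
    except tokens.DatabaseFormatError:
        print('Not a recognized token database format')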
