Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Barebone Functionality #6

Merged
merged 7 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,41 @@ Easily inspect the TensorRT log files. Check out statistics about optimizations,

## Developing

This project is built on [rye](https://rye.astral.sh/), make sure rye is available in your $PATH. After cloning, follow these steps to install dependencies and build the wheel:
This project is built on [rye](https://rye.astral.sh/), make sure rye is available in your $PATH. After cloning, install or update the dependencies by syncing:

```bash
rye sync
rye build
```
Please make and run tests when developing. Place test files inside the `tests/` directory, and make the `tests/` directory structure follows `src/trt_log_inspector/`. Note below's example where `test_file_helper.py` is under the `tests/utils/`, equivalent to `src/trt_log_inspector/utils/`.

The wheel will be available inside the `dist/` directory. Install it in a virtualenv by doing:
```bash
├── src
│   └── trt_log_inspector
│   ├── core
│   │   └── __init__.py
│   ├── __init__.py
│   └── utils
│   ├── file_helper.py
│   └── __init__.py
└── tests
├── core
└── utils
└── test_file_helper.py
```

Below is an example command to make the test file, and to run them. Tests are made with [pytest](https://docs.pytest.org/en/8.2.x/index.html):

```bash
touch tests/[test_file_name.py]
rye test
```

Build the package as a wheel to be able to use this package in a script. Below is an example to build the wheel and to install it in a virtualenv. The generated wheel and tar.gz file is available in `dist/`.

```bash
rye build
pip install trt_log_inspector-0.1.0-py3-none-any.whl`
```

## License

The TensorRT Log Inspector is open-sourced, licensed under the [MIT License](https://opensource.org/licenses/MIT).
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ authors = [
]
dependencies = [
"prettytable>=3.10.0",
"pandas>=2.2.2",
"matplotlib>=3.9.0",
]
readme = "README.md"
requires-python = ">= 3.8"
Expand All @@ -19,7 +21,6 @@ build-backend = "hatchling.build"
managed = true
dev-dependencies = [
"pytest>=8.2.2",
"coverage>=7.5.3",
]

[tool.hatch.metadata]
Expand Down
31 changes: 30 additions & 1 deletion requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,44 @@
# generate-hashes: false

-e file:.
coverage==7.5.3
contourpy==1.2.1
# via matplotlib
cycler==0.12.1
# via matplotlib
fonttools==4.53.0
# via matplotlib
iniconfig==2.0.0
# via pytest
kiwisolver==1.4.5
# via matplotlib
matplotlib==3.9.0
# via trt-log-inspector
numpy==2.0.0
# via contourpy
# via matplotlib
# via pandas
packaging==24.1
# via matplotlib
# via pytest
pandas==2.2.2
# via trt-log-inspector
pillow==10.3.0
# via matplotlib
pluggy==1.5.0
# via pytest
prettytable==3.10.0
# via trt-log-inspector
pyparsing==3.1.2
# via matplotlib
pytest==8.2.2
python-dateutil==2.9.0.post0
# via matplotlib
# via pandas
pytz==2024.1
# via pandas
six==1.16.0
# via python-dateutil
tzdata==2024.1
# via pandas
wcwidth==0.2.13
# via prettytable
31 changes: 31 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,38 @@
# generate-hashes: false

-e file:.
contourpy==1.2.1
# via matplotlib
cycler==0.12.1
# via matplotlib
fonttools==4.53.0
# via matplotlib
kiwisolver==1.4.5
# via matplotlib
matplotlib==3.9.0
# via trt-log-inspector
numpy==2.0.0
# via contourpy
# via matplotlib
# via pandas
packaging==24.1
# via matplotlib
pandas==2.2.2
# via trt-log-inspector
pillow==10.3.0
# via matplotlib
prettytable==3.10.0
# via trt-log-inspector
pyparsing==3.1.2
# via matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
# via pandas
pytz==2024.1
# via pandas
six==1.16.0
# via python-dateutil
tzdata==2024.1
# via pandas
wcwidth==0.2.13
# via prettytable
11 changes: 11 additions & 0 deletions src/trt_log_inspector/coreutils/display_log_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod


class DisplayLogData(ABC):
@abstractmethod
def display_table(self):
pass

@abstractmethod
def display_plot(self):
pass
7 changes: 7 additions & 0 deletions src/trt_log_inspector/coreutils/validations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class InvalidLogFileError(Exception):
"""Raised if log file is missing identities"""
def __init__(self, missing: set[str], message: str="Log file is invalid, missing identities") -> None:
self.message = message
self.missing = missing
super().__init__(f"{message}: {missing}")

139 changes: 139 additions & 0 deletions src/trt_log_inspector/trt_log_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import re
from typing import Self, override
from collections.abc import Generator

from prettytable import PrettyTable
import pandas as pd # type: ignore[reportMissingStubs]
import matplotlib.pyplot as plt

from trt_log_inspector.coreutils.display_log_data import DisplayLogData
from trt_log_inspector.coreutils.validations import InvalidLogFileError


class TrtLogFile(DisplayLogData):
"""
Core class for examining the TensorRT (TRT) log file.
"""

def __init__(self, name: str, path: str) -> None:
self.name: str = name
self.path: str = path
self.results: list[dict[str, object]] = []

self._check_trt_log_validity()

def _parse_line(
self, query: None | list[re.Pattern[str]] = None
) -> Generator[str, None | list[re.Pattern[str]], None]:
"""
Attempt to open the file from `self.path` and apply query (if specified) to filter the text.

Args:
- query: compiled regex pattern to match against.

Yields:
- str: A stripped line from the file that matched the query.

Raises:
- FileNotFoundError: If file from self.path does not exist.
"""

try:
with open(self.path) as f:
for line in f:
if query is not None and any(p.search(line) for p in query):
yield line.strip()
elif query is None:
yield line.strip()
except FileNotFoundError:
raise FileNotFoundError(f"{self.path} cannot be opened")

def _check_trt_log_validity(self) -> None:
"""
Checking validity of log file by matching against some information:
- Input filename
- ONNX IR version
- Opset version
- Producer name
- Producer version

Raises:
- InvalidLogFileError: If any information is missing in the log file.
"""

identities = [
"Input filename",
"ONNX IR version",
"Opset version",
"Producer name",
"Producer version",
]
re_ident_filter = [re.compile(f"{pattern}.*") for pattern in identities]
line_parser = self._parse_line(query=re_ident_filter)

matches: set[str] = set()
for pr in line_parser:
for id in identities:
if id in pr:
matches.add(id)

missing = set(identities) - matches
if missing:
raise InvalidLogFileError(missing)

def conversion_duration_info(self) -> Self:
"""
Extract conversion duration information for each stages of conversion.

Returns:
- List[Dict]: Information about each stage and duration (in seconds)
"""

stages = [
"Formats and tactics selection",
"Engine generation",
"Calibration",
"Post Processing Calibration",
"Configuring builder",
"Graph construction and optimization",
"Finished engine building",
]
re_stages_filter = [
re.compile(rf"{pattern}.*?(\d+\.\d+)") for pattern in stages
]
line_parser = self._parse_line(query=re_stages_filter)

for pr in line_parser:
detected_stage = ""
for s in stages:
if s in pr:
detected_stage = s

self.results.append(
{
"stage": detected_stage,
"duration_s": float(re.findall(r"\d+\.\d+", pr)[0]), # type: ignore[reportAny]
}
)

return self

@override
def display_table(self) -> None:
table = PrettyTable()
table.field_names = list(self.results[0].keys())
table.align = "l"

for r in self.results:
table.add_row(list(r.values())) # type: ignore[reportUnknownMemberType]

print(table)

@override
def display_plot(self):
rec = pd.DataFrame.from_records(self.results) # type:ignore[reportUnknownMemberType]

rec.plot(x="stage", y="duration_s", kind="bar") # type:ignore[reportUnknownMemberType]
plt.tight_layout()
plt.grid() # type:ignore[reportUnknownMemberType]
plt.show() # type:ignore[reportUnknownMemberType]
20 changes: 0 additions & 20 deletions src/trt_log_inspector/utils/file_helper.py

This file was deleted.

Empty file added tests/test_trt_log_file.py
Empty file.
10 changes: 0 additions & 10 deletions tests/utils/test_file_helper.py

This file was deleted.

Loading