Skip to content

Commit

Permalink
Apply black formatting and checking to python files
Browse files Browse the repository at this point in the history
Signed-off-by: Pierre R. Mai <[email protected]>
  • Loading branch information
pmai committed Jan 15, 2024
1 parent ad258d9 commit 493bd51
Show file tree
Hide file tree
Showing 18 changed files with 852 additions and 396 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/protobuf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ jobs:
python -m pip install --upgrade pip
python -m pip install -r requirements_develop.txt
- name: Check black format
run: black --check --diff .

- name: Install Doxygen
run: sudo apt-get install doxygen graphviz

Expand Down
73 changes: 44 additions & 29 deletions format/OSITrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from osi3.osi_groundtruth_pb2 import GroundTruth
from osi3.osi_sensordata_pb2 import SensorData
import warnings
warnings.simplefilter('default')

SEPARATOR = b'$$__$$'
warnings.simplefilter("default")

SEPARATOR = b"$$__$$"
SEPARATOR_LENGTH = len(SEPARATOR)
BUFFER_SIZE = 1000000

Expand All @@ -31,7 +32,7 @@ def get_size_from_file_stream(file_object):
MESSAGES_TYPE = {
"SensorView": SensorView,
"GroundTruth": GroundTruth,
"SensorData": SensorData
"SensorData": SensorData,
}


Expand All @@ -49,15 +50,15 @@ def __init__(self, path=None, type_name="SensorView"):
def from_file(self, path, type_name="SensorView", max_index=-1, format_type=None):
"""Import a scenario from a file"""

if path.lower().endswith(('.lzma', '.xz')):
if path.lower().endswith((".lzma", ".xz")):
self.scenario_file = lzma.open(path, "rb")
else:
self.scenario_file = open(path, "rb")

self.type_name = type_name
self.format_type = format_type

if self.format_type == 'separated':
if self.format_type == "separated":
# warnings.warn("The separated trace files will be completely removed in the near future. Please convert them to *.osi files with the converter in the main OSI repository.", PendingDeprecationWarning)
self.timestep_count = self.retrieve_message_offsets(max_index)
else:
Expand All @@ -73,7 +74,7 @@ def retrieve_message_offsets(self, max_index):
scenario_size = get_size_from_file_stream(self.scenario_file)

if max_index == -1:
max_index = float('inf')
max_index = float("inf")

buffer_deque = deque(maxlen=2)

Expand All @@ -100,7 +101,7 @@ def retrieve_message_offsets(self, max_index):
self.scenario_file.seek(message_offset)

while eof and found != -1:
buffer = buffer[found + SEPARATOR_LENGTH:]
buffer = buffer[found + SEPARATOR_LENGTH :]
found = buffer.find(SEPARATOR)

buffer_offset = scenario_size - len(buffer)
Expand All @@ -126,7 +127,7 @@ def retrieve_message(self):
self.message_offsets = [0]
eof = False

# TODO Implement buffering for the scenarios
# TODO Implement buffering for the scenarios
self.scenario_file.seek(0)
serialized_message = self.scenario_file.read()
INT_LENGTH = len(struct.pack("<L", 0))
Expand All @@ -135,8 +136,12 @@ def retrieve_message(self):
i = 0
while i < len(serialized_message):
message = MESSAGES_TYPE[self.type_name]()
message_length = struct.unpack("<L", serialized_message[i:INT_LENGTH+i])[0]
message.ParseFromString(serialized_message[i+INT_LENGTH:i+INT_LENGTH+message_length])
message_length = struct.unpack(
"<L", serialized_message[i : INT_LENGTH + i]
)[0]
message.ParseFromString(
serialized_message[i + INT_LENGTH : i + INT_LENGTH + message_length]
)
i += message_length + INT_LENGTH
self.message_offsets.append(i)

Expand All @@ -153,7 +158,7 @@ def get_message_by_index(self, index):
Get a message by its index. Try first to get it from the cache made
by the method ``cache_messages_in_index_range``.
"""
return next(self.get_messages_in_index_range(index, index+1))
return next(self.get_messages_in_index_range(index, index + 1))

def get_messages(self):
return self.get_messages_in_index_range(0, len(self.message_offsets))
Expand All @@ -164,26 +169,28 @@ def get_messages_in_index_range(self, begin, end):
"""
self.scenario_file.seek(self.message_offsets[begin])
abs_first_offset = self.message_offsets[begin]
abs_last_offset = self.message_offsets[end] \
if end < len(self.message_offsets) \
abs_last_offset = (
self.message_offsets[end]
if end < len(self.message_offsets)
else self.retrieved_scenario_size
)

rel_message_offsets = [
abs_message_offset - abs_first_offset
for abs_message_offset in self.message_offsets[begin:end]
]

if self.format_type == "separated":
message_sequence_len = abs_last_offset - \
abs_first_offset - SEPARATOR_LENGTH
serialized_messages_extract = self.scenario_file.read(
message_sequence_len)
message_sequence_len = abs_last_offset - abs_first_offset - SEPARATOR_LENGTH
serialized_messages_extract = self.scenario_file.read(message_sequence_len)

for rel_index, rel_message_offset in enumerate(rel_message_offsets):
rel_begin = rel_message_offset
rel_end = rel_message_offsets[rel_index + 1] - SEPARATOR_LENGTH \
if rel_index + 1 < len(rel_message_offsets) \
rel_end = (
rel_message_offsets[rel_index + 1] - SEPARATOR_LENGTH
if rel_index + 1 < len(rel_message_offsets)
else message_sequence_len
)
message = MESSAGES_TYPE[self.type_name]()
serialized_message = serialized_messages_extract[rel_begin:rel_end]
message.ParseFromString(serialized_message)
Expand Down Expand Up @@ -212,27 +219,35 @@ def get_messages_in_index_range(self, begin, end):

def make_readable(self, name, interval=None, index=None):
self.scenario_file.seek(0)
serialized_message = self.scenario_file.read()
serialized_message = self.scenario_file.read()
message_length = len(serialized_message)

if message_length > 1000000000:
# Throw a warning if trace file is bigger than 1GB
gb_size_input = round(message_length/1000000000, 2)
gb_size_output = round(3.307692308*message_length/1000000000, 2)
warnings.warn(f"The trace file you are trying to make readable has the size {gb_size_input}GB. This will generate a readable file with the size {gb_size_output}GB. Make sure you have enough disc space and memory to read the file with your text editor.", ResourceWarning)

with open(name, 'a') as f:

gb_size_input = round(message_length / 1000000000, 2)
gb_size_output = round(3.307692308 * message_length / 1000000000, 2)
warnings.warn(
f"The trace file you are trying to make readable has the size {gb_size_input}GB. This will generate a readable file with the size {gb_size_output}GB. Make sure you have enough disc space and memory to read the file with your text editor.",
ResourceWarning,
)

with open(name, "a") as f:
if interval is None and index is None:
for i in self.get_messages():
f.write(str(i))

if interval is not None and index is None:
if type(interval) == tuple and len(interval) == 2 and interval[0]<interval[1]:
if (
type(interval) == tuple
and len(interval) == 2
and interval[0] < interval[1]
):
for i in self.get_messages_in_index_range(interval[0], interval[1]):
f.write(str(i))
else:
raise Exception("Argument 'interval' needs to be a tuple of length 2! The first number must be smaller then the second.")
raise Exception(
"Argument 'interval' needs to be a tuple of length 2! The first number must be smaller then the second."
)

if interval is None and index is not None:
if type(index) == int:
Expand Down
75 changes: 44 additions & 31 deletions format/osi2read.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,60 @@
'''
"""
This program converts serialized txt/osi trace files into a human readable txth file.
Example usage:
python3 osi2read.py -d trace.osi -o myreadableosifile
python3 osi2read.py -d trace.txt -f separated -o myreadableosifile
'''
"""

from OSITrace import OSITrace
import struct
import lzma
import argparse
import os


def command_line_arguments():
""" Define and handle command line interface """
"""Define and handle command line interface"""

dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

parser = argparse.ArgumentParser(
description='Convert a serialized osi/txt trace file to a readable txth output.',
prog='osi2read converter')
parser.add_argument('--data', '-d',
help='Path to the file with serialized data.',
type=str)
parser.add_argument('--type', '-t',
help='Name of the type used to serialize data.',
choices=['SensorView', 'GroundTruth', 'SensorData'],
default='SensorView',
type=str,
required=False)
parser.add_argument('--output', '-o',
help='Output name of the file.',
default='converted.txth',
type=str,
required=False)
parser.add_argument('--format', '-f',
help='Set the format type of the trace.',
choices=['separated', None],
default=None,
type=str,
required=False)
description="Convert a serialized osi/txt trace file to a readable txth output.",
prog="osi2read converter",
)
parser.add_argument(
"--data", "-d", help="Path to the file with serialized data.", type=str
)
parser.add_argument(
"--type",
"-t",
help="Name of the type used to serialize data.",
choices=["SensorView", "GroundTruth", "SensorData"],
default="SensorView",
type=str,
required=False,
)
parser.add_argument(
"--output",
"-o",
help="Output name of the file.",
default="converted.txth",
type=str,
required=False,
)
parser.add_argument(
"--format",
"-f",
help="Set the format type of the trace.",
choices=["separated", None],
default=None,
type=str,
required=False,
)

return parser.parse_args()


def main():
# Handling of command line arguments
args = command_line_arguments()
Expand All @@ -51,13 +63,14 @@ def main():
trace = OSITrace()
trace.from_file(path=args.data, type_name=args.type, format_type=args.format)

args.output = args.output.split('.', 1)[0] + '.txth'
args.output = args.output.split(".", 1)[0] + ".txth"

if args.output == 'converted.txth':
args.output = args.data.split('.', 1)[0] + '.txth'
if args.output == "converted.txth":
args.output = args.data.split(".", 1)[0] + ".txth"

trace.make_readable(args.output)
trace.scenario_file.close()

trace.scenario_file.close()


if __name__ == "__main__":
main()
main()
Loading

0 comments on commit 493bd51

Please sign in to comment.