diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 78f2a06a720..2b0bc94b6b5 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -97,6 +97,42 @@ py_test( ], ) +py_binary( + name = "compress", + srcs = [ + "compress.py", + ], + deps = [ + ":metadata_py", + ":model_facade", + ":spec", + "@absl_py//absl:app", + "@absl_py//absl/flags", + "@flatbuffers//:runtime_py", + requirement("bitarray"), + requirement("numpy"), + ], +) + +py_test( + name = "compress_test", + size = "small", + srcs = [ + "compress_test.py", + ], + deps = [ + ":compress", + ":metadata_py", + ":model_facade", + ":spec", + ":test_models", + "//tensorflow/lite/python:schema_py", + requirement("bitarray"), + requirement("numpy"), + requirement("tensorflow"), + ], +) + py_library( name = "model_facade", srcs = ["model_facade.py"], diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py new file mode 100644 index 00000000000..18390d0d735 --- /dev/null +++ b/tensorflow/lite/micro/compression/compress.py @@ -0,0 +1,306 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model compression library and CLI. + +See USAGE. +""" + +import bitarray +import bitarray.util +from dataclasses import dataclass, field +import sys +from typing import ByteString, Iterable + +import absl.app +import absl.flags +import flatbuffers +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema + +USAGE = f"""\ +Usage: compress.py --input --spec [--output ] + +Produce a compressed model from the input model by compressing tensors +according to the instructions in the spec file. The spec file lists the tensors +to compress, the compression methods to use on each tensor, and any parameters +for each compression method. + +The spec file is a YAML-format file with a dictionary at the root, containing a +key "tensors" with a list of tensors to compress as its value. E.g.: + +--- +{spec.EXAMPLE_YAML_SPEC} +--- + +The only compression method currently implemented is "lut", i.e., +Look-Up-Table. This method requires the tensor in the input model to have a +small number of unique values, fewer than or equal to 2**index_bitwidth. LUT +compression collects these values into a lookup table, and rewrites the tensor +as bitwidth-wide integer indices into that lookup table. Presumably, the input +model has been trained or preprocessed in a way that the tensor values +are binned into a meaningful, limited set. +""" + +# A compressed model augments the usual .tflite flatbuffer with a flatbuffer of +# its own containing compression metadata, stored at the buffer index stored at +# the following key in the .tflite flatbuffer's metadata map. +TFLITE_METADATA_KEY = "COMPRESSION_METADATA" + + +class CompressionError(Exception): + """Raised when compression fails for the reason documented in the message.""" + + def __init__(self, message, wrapped_exception=None): + super().__init__(f"{message}: {str(wrapped_exception)}") + self.original_exception = wrapped_exception + + +class _MetadataBuilder: + """Builder for the compression metadata flatbuffer.""" + + def __init__(self): + self._metadata = schema.MetadataT() + self._metadata.subgraphs = [] + + def compile(self) -> bytearray: + """Packs the metadata into a binary array and returns it. + """ + builder = flatbuffers.Builder(1 * 2**10) + root = self._metadata.Pack(builder) + builder.Finish(root) + return builder.Output() + + def subgraph(self, index: int): + """Return subgraph at index, adding subgraphs if necessary. + """ + while len(self._metadata.subgraphs) <= index: + self._add_subgraph() + return self._metadata.subgraphs[index] + + def add_lut_tensor(self, subgraph_id: int): + """Add LUT tensor to the given subgraph and return it. + """ + tensor = schema.LutTensorT() + self.subgraph(subgraph_id).lutTensors.append(tensor) + return tensor + + def _add_subgraph(self): + subgraph = schema.SubgraphT() + subgraph.lutTensors = [] + self._metadata.subgraphs.append(subgraph) + return subgraph + + +@dataclass +class _LutCompressedArray: + compression_axis: int = 0 + lookup_tables: list[np.ndarray] = field(default_factory=list) + indices: np.ndarray = field(default_factory=lambda: np.array([])) + + @property + def index_bitwidth(self) -> int: + """Returns the number of bits required to encode the indices.""" + if self.indices is None: + raise ValueError + + max_index = int(np.max(self.indices)) + return max_index.bit_length() or 1 + + +def _lut_compress_array(tensor: np.ndarray, axis: int) -> _LutCompressedArray: + """Compresses using a lookup table per subarray along the given axis. + + Compressing a tensor with a lookup table per subarray along a particular axis + is analogous to quantizing a tensor with different quantization parameters + per subarray along a particular axis (dimension). + """ + compressed = _LutCompressedArray() + compressed.compression_axis = axis + + # Iterate over subarrays along the compression axis + subarray_indices = [] + for subarray in np.moveaxis(tensor, axis, 0): + values, indices = np.unique(subarray, return_inverse=True) + compressed.lookup_tables.append(values) + indices = indices.reshape(subarray.shape) + subarray_indices.append(indices) + + # Reconstruct a tensor of indices from the subarrays + stacked = np.stack(subarray_indices, axis=0) + compressed.indices = np.moveaxis(stacked, 0, axis) + + return compressed + + +def _check_lut_compression(compression) -> spec.LookUpTableCompression: + if len(compression) != 1: + raise CompressionError("Each tensor must have exactly one compression") + if not isinstance(compression[0], spec.LookUpTableCompression): + raise CompressionError('Only "lut" compression may be specified') + + return compression[0] + + +def _identify_compression_axis(tensor: model_facade._Tensor) -> int: + """Finds the axis along which to compress. + + Use the quantization axis, else the NWHC channel dimension. If necessary, + an user-specified override could be added to the compression spec schema. + """ + if tensor.quantization is not None: + axis = tensor.quantization.quantizedDimension + else: + axis = tensor.array.ndim - 1 + + return axis + + +def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor): + """Applies business logic regarding specified bitwidth. + + It is an error if the bitwidth required to compress a tensor exceeds the + specified bitwith, and a warning if the tensor can be compressed in less than + the specified bitwidth. The latter is allowed, and is not an error, to permit + testing with larger bitwidths without re-binning a model. + """ + if compressed > specified: + raise CompressionError( + f"index_bitwidth too small: {compressed} bits needed to " + f"enumerate unique values in tensor specified in {spec}") + elif compressed < specified: + print( + f"warning: index_bitwidth too large: only {compressed} " + f"bits needed to enumerate unique values in tensor specified in {spec}", + file=sys.stderr) + + +def _pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: + """Packs indices into a bytearray using bitwidth-sized fields. + """ + endianness = "big" + bits = bitarray.bitarray(endian=endianness) + for i in indices.ravel(): + bits.extend( + bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) + return bits.tobytes() + + +def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray: + """Packs the value tables of a LutCompressedArray. + + Pack the value tables of a LutCompressedArray into a bytes object in the + format writable to a value_table buffer in the .tflite flatbuffer. The + tables, one per subarray, are concatinated. + """ + buffer = bytearray() + for t in tables: + padding_needed = table_len - len(t) + padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) + buffer.extend(padded.tobytes()) + + return buffer + + +def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: + """Compresses a model .tflite flatbuffer. + + Args: + model_in: the original, uncompressed .tflite flatbuffer + specs: an iterable of compression specs, see module spec.py + + Returns: + A compressed flatbuffer. + """ + model = model_facade.read(model_in) + metadata = _MetadataBuilder() + + for spec in specs: + try: + tensor = model.subgraphs[spec.subgraph].tensors[spec.tensor] + lut_compression = _check_lut_compression(spec.compression) + spec_bitwidth = lut_compression.index_bitwidth + axis = _identify_compression_axis(tensor) + compressed = _lut_compress_array(tensor.array, axis) + _check_bitwidth(compressed.index_bitwidth, spec_bitwidth, spec) + + # overwrite tensor data with indices + tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) + + # write value buffer + value_buffer = model.add_buffer() + value_buffer.data = _pack_lookup_tables(compressed.lookup_tables, + 2**spec_bitwidth) + # add compression metadata for tensor + lut_tensor = metadata.add_lut_tensor(subgraph_id=tensor.subgraph.index) + lut_tensor.tensor = tensor.index + lut_tensor.valueBuffer = value_buffer.index + lut_tensor.indexBitwidth = spec_bitwidth + + except Exception as e: + raise CompressionError(f"error compressing {spec}") from e + + # add compression metadata to model + model.add_metadata(TFLITE_METADATA_KEY, metadata.compile()) + + return model.compile() + + +def _fail_w_usage() -> int: + absl.app.usage() + return 1 + + +FLAGS = absl.flags.FLAGS +absl.flags.DEFINE_string("input", None, help="uncompressed .tflite flatbuffer") +absl.flags.DEFINE_string("spec", None, help="specfile (see module spec.py)") +absl.flags.DEFINE_string("output", None, help="compressed .tflite flatbuffer") + + +def main(argv): + if len(argv) > 1: + # no positional arguments accepted + return _fail_w_usage() + + in_path = FLAGS.input + if in_path is None: + return _fail_w_usage() + else: + with open(in_path, "rb") as in_file: + in_model = in_file.read() + + spec_path = FLAGS.spec + if spec_path is None: + return _fail_w_usage() + else: + with open(spec_path, "r") as spec_file: + specs = spec.parse_yaml(spec_file.read()) + + out_path = FLAGS.output + if out_path is None: + out_path = in_path.split(".tflite")[0] + ".compressed.tflite" + + compressed = compress(in_model, specs) + + with open(out_path, "wb") as out_file: + out_file.write(compressed) + + return 0 + + +if __name__ == "__main__": + sys.modules['__main__'].__doc__ = USAGE # for absl's use + absl.app.run(main) diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py new file mode 100644 index 00000000000..4c9d2fa3d01 --- /dev/null +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -0,0 +1,561 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bitarray +import bitarray.util +import numpy as np +import tensorflow as tf + +from tflite_micro.tensorflow.lite.micro.compression import compress +from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema +from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.micro.compression import test_models +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +class TestPackIndices(tf.test.TestCase): + + def test_basic_case(self): + indices = np.array([1, 2, 3]) + bitwidth = 4 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b0001_0010, 0b0011_0000]) + self.assertEqual(result, expected_bytes) + + def test_single_element(self): + indices = np.array([10]) + bitwidth = 8 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b0000_1010]) + self.assertEqual(result, expected_bytes) + + def test_different_bitwidth(self): + indices = np.array([1, 2, 3]) + bitwidth = 8 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b0000_0001, 0b0000_0010, 0b0000_0011]) + self.assertEqual(result, expected_bytes) + + def test_large_numbers(self): + indices = np.array([255, 128, 64]) + bitwidth = 8 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b1111_1111, 0b1000_0000, 0b0100_0000]) + self.assertEqual(result, expected_bytes) + + def test_multidimensional_array(self): + indices = np.array([[1, 2], [3, 4]]) + bitwidth = 4 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b0001_0010, 0b0011_0100]) + self.assertEqual(result, expected_bytes) + + def test_zero_bitwidth(self): + indices = np.array([0, 1, 2]) + bitwidth = 0 + with self.assertRaises(ValueError): + compress._pack_indices(indices, bitwidth) + + def test_empty_array(self): + indices = np.array([]) + bitwidth = 4 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = b"" + self.assertEqual(result, expected_bytes) + + def test_bitwidth_1(self): + indices = np.array([1, 0, 1, 1, 0, 1]) + bitwidth = 1 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b101101_00]) + self.assertEqual(result, expected_bytes) + + def test_bitwidth_2(self): + indices = np.array([1, 2, 3, 0]) + bitwidth = 2 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b01_10_11_00]) + self.assertEqual(result, expected_bytes) + + def test_bitwidth_3(self): + indices = np.array([1, 3, 5, 7]) + bitwidth = 3 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b001_011_10, 0b1_111_0000]) + self.assertEqual(result, expected_bytes) + + def test_bitwidth_5(self): + indices = np.array([1, 2, 16, 31]) + bitwidth = 5 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b00001_000, 0b10_10000_1, 0b1111_0000]) + self.assertEqual(result, expected_bytes) + + def test_bitwidth_7(self): + indices = np.array([1, 64, 127, 32]) + bitwidth = 7 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes( + [0b0000001_1, 0b000000_11, 0b11111_010, 0b0000_0000]) + self.assertEqual(result, expected_bytes) + + +class TestPackLookupTables(tf.test.TestCase): + + def test_int16_positive(self): + tables = [np.array([0x1234, 0x5678], dtype=' tuple[int, bitarray.bitarray, np.ndarray]: + """Helper: extracts the compressed tensor parts for a given spec. + + Returns: + bitwidth + indices + values + """ + subgraph_obj = self.compressed.subgraphs[subgraph] + tensor_obj = subgraph_obj.tensors[tensor] + lut_tensors = self.metadata.subgraphs[subgraph_obj.index].lutTensors + lut_tensor = next(t for t in lut_tensors if t.tensor == tensor_obj.index) + bitwidth = lut_tensor.indexBitwidth + + indices = bitarray.bitarray(buffer=tensor_obj.buffer.data, endian="big") + n_indices = np.prod(tensor_obj.shape) + indices = indices[:n_indices * bitwidth] # trim possible padding + + value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] + values = np.frombuffer(value_buffer.data, dtype=tensor_obj.dtype) + + return bitwidth, indices, values + + def _make_indices(self, s: str) -> bitarray.bitarray: + """Helper: makes indices from "01" strings for use as expected values.""" + return bitarray.bitarray(s, endian="big") + + def test_compressed_uint8(self): + bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=0) + self.assertEqual(bitwidth, 4) + + # yapf: disable + expected_indices = self._make_indices(""" + 0000 0001 0010 0011 + 0100 0101 0110 0111 + 1000 1001 1010 1011 + 1100 1101 1110 1111 + """) + # yapf: enable + self.assertEqual(indices, expected_indices) + + expected_values = np.array(range(16), dtype=" _Buffer: + return self.subgraph.model.buffers[self._tensor_t.buffer] @property - def data(self): + def data(self) -> bytes: return self.buffer.data @property def dtype(self) -> np.dtype: - return _NP_DTYPES[self._tensor.type] + return _NP_DTYPES[self._tensor_t.type] @property def array(self) -> np.ndarray: @@ -165,62 +167,61 @@ def array(self) -> np.ndarray: is an array of fixed-width, integer fields. """ return np.frombuffer(self.data, - dtype=self.dtype).reshape(self._tensor.shape) + dtype=self.dtype).reshape(self._tensor_t.shape) @property - def quantization(self): - return self._tensor.quantization + def quantization(self) -> tflite.QuantizationParametersT | None: + return self._tensor_t.quantization class _Buffer: - def __init__(self, buffer, index, model): - self.buffer = buffer + def __init__(self, buffer_t: tflite.BufferT, index, model): + self._buffer_t = buffer_t self.index = index self.model = model @property - def data(self): - return self.buffer.data + def data(self) -> bytes: + return bytes(self._buffer_t.data) @data.setter - def data(self, value): - self.buffer.data = value - return self.data + def data(self, value: ByteString): + self._buffer_t.data = list(value) def extend(self, values: NDArray): - self.buffer.data.extend(values.tobytes()) + self._buffer_t.data.extend(values.tobytes()) class _Subgraph: - def __init__(self, subgraph, index, model): - self.subgraph = subgraph + def __init__(self, subgraph_t: tflite.SubGraphT, index: int, model: _Model): + self._subgraph_t = subgraph_t self.index = index self.model = model @property def operators(self) -> _Iterator[_Operator]: - return _Iterator(self.subgraph.operators, _Operator, parent=self) + return _Iterator(self._subgraph_t.operators, _Operator, parent=self) @property def tensors(self) -> _Iterator[_Tensor]: - return _Iterator(self.subgraph.tensors, _Tensor, parent=self) + return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self) class _Model: """A facade for manipulating tflite.Model. """ - def __init__(self, representation: tflite.ModelT): - self.root = representation + def __init__(self, model_t: tflite.ModelT): + self._model_t = model_t def compile(self) -> bytearray: """Returns a tflite.Model flatbuffer. """ size_hint = 4 * 2**10 builder = flatbuffers.Builder(size_hint) - builder.Finish(self.root.Pack(builder)) + builder.Finish(self._model_t.Pack(builder)) return builder.Output() def add_buffer(self) -> _Buffer: @@ -228,9 +229,9 @@ def add_buffer(self) -> _Buffer: """ buffer = tflite.BufferT() buffer.data = [] - self.root.buffers.append(buffer) - index = len(self.root.buffers) - 1 - return _Buffer(buffer, index, self.root) + self._model_t.buffers.append(buffer) + index = len(self._model_t.buffers) - 1 + return _Buffer(buffer, index, self._model_t) def add_metadata(self, key, value): """Adds a key-value pair, writing value to a newly created buffer. @@ -240,19 +241,32 @@ def add_metadata(self, key, value): buffer = self.add_buffer() buffer.data = value metadata.buffer = buffer.index - self.root.metadata.append(metadata) + self._model_t.metadata.append(metadata) + + @property + def metadata(self) -> dict[str, _Buffer]: + """Returns the model's metadata as a dictionary to Buffer objects. + """ + result = {} + for m in self._model_t.metadata: + name = m.name.decode("utf-8") # type: ignore (fb library is wrong) + buffer = _Buffer(self._model_t.buffers[m.buffer], m.buffer, + self._model_t) + result[name] = buffer + + return result @property def operatorCodes(self): - return self.root.operatorCodes + return self._model_t.operatorCodes @property def subgraphs(self) -> _Iterator[_Subgraph]: - return _Iterator(self.root.subgraphs, _Subgraph, parent=self) + return _Iterator(self._model_t.subgraphs, _Subgraph, parent=self) @property def buffers(self) -> _Iterator[_Buffer]: - return _Iterator(self.root.buffers, _Buffer, parent=self) + return _Iterator(self._model_t.buffers, _Buffer, parent=self) def read(buffer: ByteString): diff --git a/tensorflow/lite/micro/compression/model_facade_test.py b/tensorflow/lite/micro/compression/model_facade_test.py index 892ef641359..e931e578f2b 100644 --- a/tensorflow/lite/micro/compression/model_facade_test.py +++ b/tensorflow/lite/micro/compression/model_facade_test.py @@ -27,6 +27,16 @@ "builtin_code": tflite.BuiltinOperator.ADD, }, }, + "metadata": { + 0: { + "name": "metadata0", + "buffer": 0 + }, + 1: { + "name": "metadata1", + "buffer": 0 + }, + }, "subgraphs": { 0: { "operators": { @@ -100,6 +110,11 @@ def testSubgraphIteration(self): for i, subgraph in enumerate(self.facade.subgraphs): self.assertEqual(i, subgraph.index) + def testMetadata(self): + self.assertIn("metadata0", self.facade.metadata) + self.assertIn("metadata1", self.facade.metadata) + self.assertNotIn("metadata2", self.facade.metadata) + class TestTensors(tf.test.TestCase): diff --git a/tensorflow/lite/micro/compression/test_models.py b/tensorflow/lite/micro/compression/test_models.py index 4896706cb1f..b2e79cbd02d 100644 --- a/tensorflow/lite/micro/compression/test_models.py +++ b/tensorflow/lite/micro/compression/test_models.py @@ -40,6 +40,12 @@ "builtin_code": tflite.BuiltinOperator.ADD, }, }, + "metadata": { + 0: { + "name": "metadata0", + "buffer": 0 + }, + }, "subgraphs": { 0: { "operators": { @@ -80,6 +86,9 @@ "shape": (16, 1), "type": tflite.TensorType.INT8, "buffer": 1, + "quantization": { + "quantized_dimension": 0, + }, }, }, }, @@ -116,6 +125,14 @@ def build(model_definition: dict) -> bytearray: root.operatorCodes.append(opcode_t) opcode_t.builtinCode = operator_code["builtin_code"] + root.metadata = [] + if "metadata" in model_definition: + for _, metadata in model_definition["metadata"].items(): + metadata_t = tflite.MetadataT() + metadata_t.name = metadata["name"] + metadata_t.buffer = metadata["buffer"] + root.metadata.append(metadata_t) + root.subgraphs = [] for id, subgraph in model_definition["subgraphs"].items(): assert id == len(root.subgraphs) @@ -139,6 +156,14 @@ def build(model_definition: dict) -> bytearray: tensor_t.shape = tensor["shape"] tensor_t.type = tensor["type"] tensor_t.buffer = tensor["buffer"] + + try: + d = tensor["quantization"]["quantized_dimension"] + tensor_t.quantization = tflite.QuantizationParametersT() + tensor_t.quantization.quantizedDimension = d + except KeyError: + tensor_t.quantization = None + subgraph_t.tensors.append(tensor_t) root.buffers = [] diff --git a/third_party/python_requirements.in b/third_party/python_requirements.in index 7b1cb26f57b..b1c06efca84 100644 --- a/third_party/python_requirements.in +++ b/third_party/python_requirements.in @@ -26,6 +26,7 @@ # is sensitive to the Python environment (interpreter version, etc.) in which # it is run. +bitarray hexdump tensorflow twine diff --git a/third_party/python_requirements.txt b/third_party/python_requirements.txt index 21f39145ad0..f7b94c1c50c 100644 --- a/third_party/python_requirements.txt +++ b/third_party/python_requirements.txt @@ -19,6 +19,145 @@ backports-tarfile==1.2.0 \ --hash=sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34 \ --hash=sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991 # via jaraco-context +bitarray==3.0.0 \ + --hash=sha256:000df24c183011b5d27c23d79970f49b6762e5bb5aacd25da9c3e9695c693222 \ + --hash=sha256:0027b8f3bb2bba914c79115e96a59b9924aafa1a578223a7c4f0a7242d349842 \ + --hash=sha256:00f9a88c56e373009ac3c73c55205cfbd9683fbd247e2f9a64bae3da78795252 \ + --hash=sha256:041c889e69c847b8a96346650e50f728b747ae176889199c49a3f31ae1de0e23 \ + --hash=sha256:0879f839ec8f079fa60c3255966c2e1aa7196699a234d4e5b7898fbc321901b5 \ + --hash=sha256:0b555006a7dea53f6bebc616a4d0249cecbf8f1fadf77860120a2e5dbdc2f167 \ + --hash=sha256:0b655c3110e315219e266b2732609fddb0857bc69593de29f3c2ba74b7d3f51a \ + --hash=sha256:0cecaf2981c9cd2054547f651537b4f4939f9fe225d3fc2b77324b597c124e40 \ + --hash=sha256:0e104f9399144fab6a892d379ba1bb4275e56272eb465059beef52a77b4e5ce6 \ + --hash=sha256:0ef5c787c8263c082a73219a69eb60a500e157a4ac69d1b8515ad836b0e71fb4 \ + --hash=sha256:12f19ede03e685c5c588ab5ed63167999295ffab5e1126c5fe97d12c0718c18f \ + --hash=sha256:1414a7102a3c4986f241480544f5c99f5d32258fb9b85c9c04e84e48c490ab35 \ + --hash=sha256:147542299f458bdb177f798726e5f7d39ab8491de4182c3c6d9885ed275a3c2b \ + --hash=sha256:150b7b29c36d9f1a24779aea723fdfc73d1c1c161dc0ea14990da27d4e947092 \ + --hash=sha256:153d7c416a70951dcfa73487af05d2f49c632e95602f1620cd9a651fa2033695 \ + --hash=sha256:184972c96e1c7e691be60c3792ca1a51dd22b7f25d96ebea502fe3c9b554f25d \ + --hash=sha256:18abdce7ab5d2104437c39670821cba0b32fdb9b2da9e6d17a4ff295362bd9dc \ + --hash=sha256:2055206ed653bee0b56628f6a4d248d53e5660228d355bbec0014bdfa27050ae \ + --hash=sha256:20f30373f0af9cb583e4122348cefde93c82865dbcbccc4997108b3d575ece84 \ + --hash=sha256:22b00f65193fafb13aa644e16012c8b49e7d5cbb6bb72825105ff89aadaa01e3 \ + --hash=sha256:251cd5bd47f542893b2b61860eded54f34920ea47fd5bff038d85e7a2f7ae99b \ + --hash=sha256:2855cc01ee370f7e6e3ec97eebe44b1453c83fb35080313145e2c8c3c5243afb \ + --hash=sha256:2ac67b658fa5426503e9581a3fb44a26a3b346c1abd17105735f07db572195b3 \ + --hash=sha256:2d9fe3ee51afeb909b68f97e14c6539ace3f4faa99b21012e610bbe7315c388d \ + --hash=sha256:2da91ab3633c66999c2a352f0ca9ae064f553e5fc0eca231d28e7e305b83e942 \ + --hash=sha256:2dad7ba2af80f9ec1dd988c3aca7992408ec0d0b4c215b65d353d95ab0070b10 \ + --hash=sha256:34fc13da3518f14825b239374734fce93c1a9299ed7b558c3ec1d659ec7e4c70 \ + --hash=sha256:369b6d457af94af901d632c7e625ca6caf0a7484110fc91c6290ce26bc4f1478 \ + --hash=sha256:37be5482b9df3105bad00fdf7dc65244e449b130867c3879c9db1db7d72e508b \ + --hash=sha256:3963b80a68aedcd722a9978d261ae53cb9bb6a8129cc29790f0f10ce5aca287a \ + --hash=sha256:39b38a3d45dac39d528c87b700b81dfd5e8dc8e9e1a102503336310ef837c3fd \ + --hash=sha256:3cd565253889940b4ec4768d24f101d9fe111cad4606fdb203ea16f9797cf9ed \ + --hash=sha256:3d47bc4ff9b0e1624d613563c6fa7b80aebe7863c56c3df5ab238bb7134e8755 \ + --hash=sha256:3fa5d8e4b28388b337face6ce4029be73585651a44866901513df44be9a491ab \ + --hash=sha256:42bf1b222c698b467097f58b9f59dc850dfa694dde4e08237407a6a103757aa3 \ + --hash=sha256:43b6c7c4f4a7b80e86e24a76f4c6b9b67d03229ea16d7d403520616535c32196 \ + --hash=sha256:44c3e78b60070389b824d5a654afa1c893df723153c81904088d4922c3cfb6ac \ + --hash=sha256:4683bff52f5a0fd523fb5d3138161ef87611e63968e1fcb6cf4b0c6a86970fe0 \ + --hash=sha256:47ccf9887bd595d4a0536f2310f0dcf89e17ab83b8befa7dc8727b8017120fda \ + --hash=sha256:4800c91a14656789d2e67d9513359e23e8a534c8ee1482bb9b517a4cfc845200 \ + --hash=sha256:4817d73d995bd2b977d9cde6050be8d407791cf1f84c8047fa0bea88c1b815bc \ + --hash=sha256:4839d3b64af51e4b8bb4a602563b98b9faeb34fd6c00ed23d7834e40a9d080fc \ + --hash=sha256:4ac2027ca650a7302864ed2528220d6cc6921501b383e9917afc7a2424a1e36d \ + --hash=sha256:4cb5702dd667f4bb10fed056ffdc4ddaae8193a52cd74cb2cdb54e71f4ef2dd1 \ + --hash=sha256:53e002ac1073ac70e323a7a4bfa9ab95e7e1a85c79160799e265563f342b1557 \ + --hash=sha256:545d36332de81e4742a845a80df89530ff193213a50b4cbef937ed5a44c0e5e5 \ + --hash=sha256:572a61fba7e3a710a8324771322fba8488d134034d349dcd036a7aef74723a80 \ + --hash=sha256:57d5ef854f8ec434f2ffd9ddcefc25a10848393fe2976e2be2c8c773cf5fef42 \ + --hash=sha256:5ddbf71a97ad1d6252e6e93d2d703b624d0a5b77c153b12f9ea87d83e1250e0c \ + --hash=sha256:5fa4b4d9fa90124b33b251ef74e44e737021f253dc7a9174e1b39f097451f7ca \ + --hash=sha256:628f93e9c2c23930bd1cfe21c634d6c84ec30f45f23e69aefe1fcd262186d7bb \ + --hash=sha256:648e7ce794928e8d11343b5da8ecc5b910af75a82ea1a4264d5d0a55c3785faa \ + --hash=sha256:656db7bdf1d81ec3b57b3cad7ec7276765964bcfd0eb81c5d1331f385298169c \ + --hash=sha256:666e44b0458bb2894b64264a29f2cc7b5b2cbcc4c5e9cedfe1fdbde37a8e329a \ + --hash=sha256:66a33a537e781eac3a352397ce6b07eedf3a8380ef4a804f8844f3f45e335544 \ + --hash=sha256:66d6134b7bb737b88f1d16478ad0927c571387f6054f4afa5557825a4c1b78e2 \ + --hash=sha256:67a0b56dd02f2713f6f52cacb3f251afd67c94c5f0748026d307d87a81a8e15c \ + --hash=sha256:6c33129b49196aa7965ac0f16fcde7b6ad8614b606caf01669a0277cef1afe1d \ + --hash=sha256:6d2a2ce73f9897268f58857ad6893a1a6680c5a6b28f79d21c7d33285a5ae646 \ + --hash=sha256:71ad0139c95c9acf4fb62e203b428f9906157b15eecf3f30dc10b55919225896 \ + --hash=sha256:7814c9924a0b30ecd401f02f082d8697fc5a5be3f8d407efa6e34531ff3c306a \ + --hash=sha256:787db8da5e9e29be712f7a6bce153c7bc8697ccc2c38633e347bb9c82475d5c9 \ + --hash=sha256:7cb885c043000924554fe2124d13084c8fdae03aec52c4086915cd4cb87fe8be \ + --hash=sha256:7cd021ada988e73d649289cee00428b75564c46d55fbdcb0e3402e504b0ae5ea \ + --hash=sha256:7e51e7f8289bf6bb631e1ef2a8f5e9ca287985ff518fe666abbdfdb6a848cb26 \ + --hash=sha256:7e9eee03f187cef1e54a4545124109ee0afc84398628b4b32ebb4852b4a66393 \ + --hash=sha256:7edb83089acbf2c86c8002b96599071931dc4ea5e1513e08306f6f7df879a48b \ + --hash=sha256:7f1c24be7519f16a47b7e2ad1a1ef73023d34d8cbe1a3a59b185fc14baabb132 \ + --hash=sha256:8330912be6cb8e2fbfe8eb69f82dee139d605730cadf8d50882103af9ac83bb4 \ + --hash=sha256:8a9eb510cde3fa78c2e302bece510bf5ed494ec40e6b082dec753d6e22d5d1b1 \ + --hash=sha256:8c9733d2ff9b7838ac04bf1048baea153174753e6a47312be14c83c6a395424b \ + --hash=sha256:904c1d5e3bd24f0c0d37a582d2461312033c91436a6a4f3bdeeceb4bea4a899d \ + --hash=sha256:928b8b6dfcd015e1a81334cfdac02815da2a2407854492a80cf8a3a922b04052 \ + --hash=sha256:9502c2230d59a4ace2fddfd770dad8e8b414cbd99517e7e56c55c20997c28b8d \ + --hash=sha256:96cf0898f8060b2d3ae491762ae871b071212ded97ff9e1e3a5229e9fefe544c \ + --hash=sha256:98a4070ddafabddaee70b2aa7cc6286cf73c37984169ab03af1782da2351059a \ + --hash=sha256:9929051feeaf8d948cc0b1c9ce57748079a941a1a15c89f6014edf18adaade84 \ + --hash=sha256:996d1b83eb904589f40974538223eaed1ab0f62be8a5105c280b9bd849e685c4 \ + --hash=sha256:9c6e52005e91803eb4e08c0a08a481fb55ddce97f926bae1f6fa61b3396b5b61 \ + --hash=sha256:9e3727ab63dfb6bde00b281934e2212bb7529ea3006c0031a556a84d2268bea5 \ + --hash=sha256:a0255bd05ec7165e512c115423a5255a3f301417973d20a80fc5bfc3f3640bcb \ + --hash=sha256:a2083dc20f0d828a7cdf7a16b20dae56aab0f43dc4f347a3b3039f6577992b03 \ + --hash=sha256:a3c36b2fcfebe15ad1c10a90c1d52a42bebe960adcbce340fef867203028fbe7 \ + --hash=sha256:a4f49ac31734fe654a68e2515c0da7f5bbdf2d52755ba09a42ac406f1f08c9d0 \ + --hash=sha256:a667ea05ba1ea81b722682276dbef1d36990f8908cf51e570099fd505a89f931 \ + --hash=sha256:a754c1464e7b946b1cac7300c582c6fba7d66e535cd1dab76d998ad285ac5a37 \ + --hash=sha256:a817ad70c1aff217530576b4f037dd9b539eb2926603354fcac605d824082ad1 \ + --hash=sha256:aa54c7e1da8cf4be0aab941ea284ec64033ede5d6de3fd47d75e77cafe986e9d \ + --hash=sha256:ab37da66a8736ad5a75a58034180e92c41e864da0152b84e71fcc253a2f69cd4 \ + --hash=sha256:ac06dd72ee1e1b6e312504d06f75220b5894af1fb58f0c20643698f5122aea76 \ + --hash=sha256:aca0a9cd376beaccd9f504961de83e776dd209c2de5a4c78dc87a78edf61839b \ + --hash=sha256:acc07211a59e2f245e9a06f28fa374d094fb0e71cf5366eef52abbb826ddc81e \ + --hash=sha256:aef404d5400d95c6ec86664df9924bde667c8865f8e33c9b7bd79823d53b3e5d \ + --hash=sha256:b1047999f1797c3ea7b7c85261649249c243308dcf3632840d076d18fa72f142 \ + --hash=sha256:b7d09ef06ba57bea646144c29764bf6b870fb3c5558ca098191e07b6a1d40bf7 \ + --hash=sha256:bcf0150ae0bcc4aa97bdfcb231b37bad1a59083c1b5012643b266012bf420e68 \ + --hash=sha256:bcf524a087b143ba736aebbb054bb399d49e77cf7c04ed24c728e411adc82bfa \ + --hash=sha256:beeb79e476d19b91fd6a3439853e4e5ba1b3b475920fa40d62bde719c8af786f \ + --hash=sha256:bf90aba4cff9e72e24ecdefe33bad608f147a23fa5c97790a5bab0e72fe62b6d \ + --hash=sha256:c23286abba0cb509733c6ce8f4013cd951672c332b2e184dbefbd7331cd234c8 \ + --hash=sha256:c2945e0390d1329c585c584c6b6d78be017d9c6a1288f9c92006fe907f69cc28 \ + --hash=sha256:c756a92cf1c1abf01e56a4cc40cb89f0ff9147f2a0be5b557ec436a23ff464d8 \ + --hash=sha256:c9e9fef0754867d88e948ce8351c9fd7e507d8514e0f242fd67c907b9cdf98b3 \ + --hash=sha256:ca79f02a98cbda1472449d440592a2fe2ad96fe55515a0447fa8864a38017cf8 \ + --hash=sha256:cb7302dbcfcb676f0b66f15891f091d0233c4fc23e1d4b9dc9b9e958156e347f \ + --hash=sha256:cb98d5b6eac4b2cf2a5a69f60a9c499844b8bea207059e9fc45c752436e6bb49 \ + --hash=sha256:cc83ea003dd75e9ade3291ef0585577dd5524aec0c8c99305c0aaa2a7570d6db \ + --hash=sha256:ce249ed981f428a8b61538ca82d3875847733d579dd40084ab8246549160f8a4 \ + --hash=sha256:cf0cc2e91dd38122dec2e6541efa99aafb0a62e118179218181eff720b4b8153 \ + --hash=sha256:d1a199e6d7c3bad5ba9d0e4dc00dde70ee7d111c9dfc521247fa646ef59fa57e \ + --hash=sha256:d1d5abf1d6d910599ac16afdd9a0ed3e24f3b46af57f3070cf2792f236f36e0b \ + --hash=sha256:d3f761184b93092077c7f6b7dad7bd4e671c1620404a76620da7872ceb576a94 \ + --hash=sha256:d756bfeb62ca4fe65d2af7a39249d442c05070c047d03729ad6cd4c2e9b0f0bd \ + --hash=sha256:d8c36ddc1923bcc4c11b9994c54eaae25034812a42400b7b8a86fe6d242166a2 \ + --hash=sha256:dbe1084935b942fab206e609fa1ed3f46ad1f2612fb4833e177e9b2a5e006c96 \ + --hash=sha256:dc1937a0ff2671797d35243db4b596329842480d125a65e9fe964bcffaf16dfc \ + --hash=sha256:dfea514e665af278b2e1d4deb542de1cd4f77413bee83dd15ae16175976ea8d5 \ + --hash=sha256:e008b7b4ce6c7f7a54b250c45c28d4243cc2a3bbfd5298fa7dac92afda229842 \ + --hash=sha256:e0e7f24a0b01e6e6a0191c50b06ca8edfdec1988d9d2b264d669d2487f4f4680 \ + --hash=sha256:e15c94d79810c5ab90ddf4d943f71f14332890417be896ca253f21fa3d78d2b1 \ + --hash=sha256:e56ba8be5f17dee0ffa6d6ce85251e062ded2faa3cbd2558659c671e6c3bf96d \ + --hash=sha256:e89ea59a3ed86a6eb150d016ed28b1bedf892802d0ed32b5659d3199440f3ced \ + --hash=sha256:e91d46d12781a14ccb8b284566b14933de4e3b29f8bc5e1c17de7a2001ad3b5b \ + --hash=sha256:ea40e98d751ed4b255db4a88fe8fb743374183f78470b9e9305aab186bf28ede \ + --hash=sha256:eb27c01b747649afd7e1c342961680893df6d8d81f832a6f04d8c8e03a8a54cc \ + --hash=sha256:ec5b0f2d13da53e0975ac15ecbe8badb463bdb0bebaa09457f4df3320421915c \ + --hash=sha256:ee040ad3b7dfa05e459713099f16373c1f2a6f68b43cb0575a66718e7a5daef4 \ + --hash=sha256:f12cc7c7638074918cdcc7491aff897df921b092ffd877227892d2686e98f876 \ + --hash=sha256:f536fc4d1a683025f9caef0bebeafd60384054579ffe0825bb9bd8c59f8c55b8 \ + --hash=sha256:f71f24b58e75a889b9915e3197865302467f13e7390efdea5b6afc7424b3a2ea \ + --hash=sha256:f75fc0198c955d840b836059bd43e0993edbf119923029ca60c4fc017cefa54a \ + --hash=sha256:f785af6b7cb07a9b1e5db0dea9ef9e3e8bb3d74874a0a61303eab9c16acc1999 \ + --hash=sha256:fbb645477595ce2a0fbb678d1cfd08d3b896e5d56196d40fb9e114eeab9382b3 \ + --hash=sha256:fcef31b062f756ba7eebcd7890c5d5de84b9d64ee877325257bcc9782288564a \ + --hash=sha256:fe606e728842389943a939258809dc5db2de831b1d2e0118515059e87f7bbc1a \ + --hash=sha256:fef4e3b3f2084b4dae3e5316b44cda72587dcc81f68b4eb2dbda1b8d15261b61 \ + --hash=sha256:ffd94b4803811c738e504a4b499fb2f848b2f7412d71e6b517508217c1d7929d + # via -r third_party/python_requirements.in certifi==2024.8.30 \ --hash=sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8 \ --hash=sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9