diff --git a/tests/test_metric.py b/tests/test_metric.py index afca7469..00484486 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -99,7 +99,7 @@ def test_energy_histogram( cpu_histogram.labels = MagicMock(return_value=cpu_histogram) cpu_histogram.observe = MagicMock() - if histogram_metric.dram_histogram: + if histogram_metric.dram_histograms: for _dram_index, dram_histogram in histogram_metric.dram_histograms.items(): dram_histogram.labels = MagicMock(return_value=dram_histogram) dram_histogram.observe = MagicMock() diff --git a/zeus/metric.py b/zeus/metric.py index 91edfc90..47e5b446 100644 --- a/zeus/metric.py +++ b/zeus/metric.py @@ -60,9 +60,9 @@ def __init__( gpu_indices: list, prometheus_url: str, job: str, - gpu_bucket_range: list[float] | None, - cpu_bucket_range: list[float] | None, - dram_bucket_range: list[float] | None, + gpu_bucket_range: list[float] | None = None, + cpu_bucket_range: list[float] | None = None, + dram_bucket_range: list[float] | None = None, ) -> None: """Initialize the EnergyHistogram class. @@ -84,29 +84,41 @@ def __init__( Raises: ValueError: If any of the bucket ranges (GPU, CPU, DRAM) is an empty list. """ - if not gpu_bucket_range: + self.gpu_bucket_range = ( + [50.0, 100.0, 200.0, 500.0, 1000.0] + if gpu_bucket_range is None + else gpu_bucket_range + ) + self.cpu_bucket_range = ( + [10.0, 20.0, 50.0, 100.0, 200.0] + if cpu_bucket_range is None + else cpu_bucket_range + ) + self.dram_bucket_range = ( + [5.0, 10.0, 20.0, 50.0, 150.0] + if dram_bucket_range is None + else dram_bucket_range + ) + self.cpu_indices = cpu_indices + self.gpu_indices = gpu_indices + self.prometheus_url = prometheus_url + self.job = job + + self.registry = CollectorRegistry() + + if gpu_bucket_range == []: raise ValueError( "GPU bucket range cannot be empty. Please provide a valid range or omit the argument to use defaults." ) - if not cpu_bucket_range: + if cpu_bucket_range == []: raise ValueError( "CPU bucket range cannot be empty. Please provide a valid range or omit the argument to use defaults." ) - if not dram_bucket_range: + if dram_bucket_range == []: raise ValueError( "DRAM bucket range cannot be empty. Please provide a valid range or omit the argument to use defaults." ) - self.gpu_bucket_range = gpu_bucket_range or [50.0, 100.0, 200.0, 500.0, 1000.0] - self.cpu_bucket_range = cpu_bucket_range or [10.0, 20.0, 50.0, 100.0, 200.0] - self.dram_bucket_range = dram_bucket_range or [5.0, 10.0, 20.0, 50.0, 150.0] - self.cpu_indices = cpu_indices - self.gpu_indices = gpu_indices - self.prometheus_url = prometheus_url - self.job = job - - self.registry = CollectorRegistry() - # Initialize GPU histograms self.gpu_histograms = {} if self.gpu_indices: