Skip to content

Commit

Permalink
Refactor the variable usage for improved clarity
Browse files Browse the repository at this point in the history
Signed-off-by: Carles Pey <[email protected]>
  • Loading branch information
cpey committed Nov 23, 2024
1 parent 5f82015 commit b386748
Showing 1 changed file with 67 additions and 54 deletions.
121 changes: 67 additions & 54 deletions chipsec/modules/tools/smm/smm_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,21 +199,54 @@ def get_info(self):
return f'duration {self.duration} code {self.code:02X} data {self.data:02X} ({gprs_info(self.gprs)})'


class smi_stats:
def __init__(self):
self.clear()

def clear(self):
self.count = 0
self.mean = 0
self.m2 = 0
self.stdev = 0
self.outliers = 0

#
# Computes the standard deviation using the Welford's online algorithm
#
def update_stats(self, duration):
self.count += 1
difference = duration - self.mean
self.mean += difference / self.count
self.m2 += difference * (duration - self.mean)
variance = self.m2 / self.count
self.stdev = math.sqrt(variance)

def get_info(self):
info = f'average {round(self.mean)} stddev {self.stdev:.2f} checked {self.count + self.outliers}'
return info

#
# Combines the statistics of the two data sets using parallel variance computation
#
def combine(self, partial):
self.outliers += partial.outliers
total_count = self.count + partial.count
difference = partial.mean - self.mean
self.m2 += partial.m2 + difference**2 * self.count * partial.count / total_count
self.count = total_count
variance = self.m2 / self.count
self.stdev = math.sqrt(variance)


class scan_track:
def __init__(self):
self.clear()
self.hist_smi_duration = 0
self.hist_smi_num = 0
self.outliers_hist = 0
current_smi_stats = smi_stats()
history_smi_stats = smi_stats()
history_smi_stats_tmp = smi_stats()
self.helper = OsHelper().get_default_helper()
self.helper.init()
self.smi_count = self.get_smi_count()
self.needs_calibration = True
self.calib_samples = 0
self.stdev = 0
self.m2 = 0
self.stdev_hist = 0
self.m2_hist = 0

def __del__(self):
self.helper.close()
Expand Down Expand Up @@ -251,73 +284,49 @@ def find_address_in_regs(self, gprs):
return key

def clear(self):
self.max = smi_info(0)
self.min = smi_info(2**32 - 1)
self.outlier = smi_info(0)
self.avg_smi_duration = 0
self.avg_smi_num = 0
self.outliers = 0
self.code = None
self.confirmed = False
self.contents_changed = False
self.needs_calibration = True
self.calib_samples = 0
self.stdev = 0
self.m2 = 0
current_smi_stats.clear()

def add(self, duration, code, data, gprs, confirmed=False):
def add(self, duration, code, data, gprs, contents_changed=False):
if not self.code:
self.code = code
outlier = self.is_outlier(duration)
if not outlier:
self.update_stdev(duration)
if duration > self.max.duration:
self.max.update(duration, code, data, gprs.copy())
elif duration < self.min.duration:
self.min.update(duration, code, data, gprs.copy())
self.current_smi_stats.update_stats(duration)
self.history_smi_stats_tmp.update_stats(duration)
elif self.is_slow_outlier(duration):
self.outliers += 1
self.outliers_hist += 1
self.current_smi_stats.outliers += 1
self.outlier.update(duration, code, data, gprs.copy())
self.confirmed = confirmed

#
# Computes the standard deviation using the Welford's online algorithm
#
def update_stdev(self, value):
self.avg_smi_num += 1
self.hist_smi_num += 1
difference = value - self.avg_smi_duration
difference_hist = value - self.hist_smi_duration
self.avg_smi_duration += difference / self.avg_smi_num
self.hist_smi_duration += difference_hist / self.hist_smi_num
self.m2 += difference * (value - self.avg_smi_duration)
self.m2_hist += difference_hist * (value - self.hist_smi_duration)
variance = self.m2 / self.avg_smi_num
variance_hist = self.m2_hist / self.hist_smi_num
self.stdev = math.sqrt(variance)
self.stdev_hist = math.sqrt(variance_hist)
self.contents_changed = contents_changed

def update_calibration(self, duration):
if not self.needs_calibration:
return
self.update_stdev(duration)
self.current_smi_stats.update_stats(duration)
self.history_smi_stats_tmp.update_stats(duration)
self.calib_samples += 1
if self.calib_samples >= SCAN_CALIB_SAMPLES:
self.needs_calibration = False

def is_slow_outlier(self, value):
ret = False
if value > self.avg_smi_duration + OUTLIER_STD_DEV * self.stdev:
if value > self.current_smi_stats.mean + OUTLIER_STD_DEV * self.current_smi_stats.stdev:
ret = True
if value > self.hist_smi_duration + OUTLIER_STD_DEV * self.stdev_hist:
if self.history_smi_stats.count and
value > self.history_smi_stats.mean + OUTLIER_STD_DEV * self.history_smi_stats.stdev:
ret = True
return ret

def is_fast_outlier(self, value):
ret = False
if value < self.avg_smi_duration - OUTLIER_STD_DEV * self.stdev:
if value < self.current_smi_stats.mean - OUTLIER_STD_DEV * self.current_smi_stat.stdev:
ret = True
if value < self.hist_smi_duration - OUTLIER_STD_DEV * self.stdev_hist:
if self.history_smi_stats.count and
value < self.history_smi_stats.mean - OUTLIER_STD_DEV * self.history_smi_stats.stdev:
ret = True
return ret

Expand All @@ -332,18 +341,17 @@ def is_outlier(self, value):
return ret

def skip(self):
return self.outliers or self.confirmed
return self.current_smi_stats.outliers or self.contents_changed

def found_outlier(self):
return bool(self.outliers)
return bool(self.current_smi_stats.outliers)

def get_total_outliers(self):
return self.outliers_hist
return self.history_smi_stats.outliers

def get_info(self):
avg = self.avg_smi_duration or self.hist_smi_duration
info = f'average {round(avg)} stddev {self.stdev:.2f} checked {self.avg_smi_num + self.outliers}'
if self.outliers:
info = self.current_smi_stats.get_info()
if self.current_smi_stats.outliers:
info += f'\n Identified outlier: {self.outlier.get_info()}'
return info

Expand All @@ -354,6 +362,10 @@ def log_smi_result(self, logger):
else:
logger.log(f'[*] {msg}')

def update_history_stats(self):
self.history_smi_stats.combine(self.current_smi_stats)
print(f"{self.history_smi_stats.get_info()}\n {self.history_smi_stats_tmp.get_info()}")


class smi_desc:
def __init__(self):
Expand Down Expand Up @@ -699,6 +711,7 @@ def test_fuzz(self, thread_id, smic_start, smic_end, _addr, _addr1, scan_mode=Fa
break
if scan_mode:
scan.log_smi_result(self.logger)
scan.update_history_stats()
scan.clear()

return bad_ptr_cnt, scan
Expand Down

0 comments on commit b386748

Please sign in to comment.