Skip to content

Commit

Permalink
Update wandb and cli reports by adding optimisation completion times
Browse files (browse the repository at this point in the history)
  • Loading branch information
ptoupas committed Nov 23, 2023
1 parent e007ad4 commit 6ad8c6b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 30 deletions.
32 changes: 22 additions & 10 deletions fpgaconvnet/optimiser/solvers/greedy_partition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import copy
import time
import wandb
import itertools
from tabulate import tabulate
Expand Down Expand Up @@ -28,8 +29,8 @@ class GreedyPartition(Solver):

def __post_init__(self):
if self.wandb_enabled:
self.pre_merge_log_tbl = wandb.Table(columns=["partition", f"part_{'latency' if self.objective == LATENCY else 'throughput'}", "slowdown", f"net_{'latency' if self.objective == LATENCY else 'throughput'}", "URAM (%)", "BRAM (%)", "DSP (%)", "LUT (%)", "FF (%)", "BW"])
self.final_log_tbl = wandb.Table(columns=["partition", f"part_{'latency' if self.objective == LATENCY else 'throughput'}", "slowdown", f"net_{'latency' if self.objective == LATENCY else 'throughput'}", "URAM (%)", "BRAM (%)", "DSP (%)", "LUT (%)", "FF (%)", "BW"])
self.pre_merge_log_tbl = wandb.Table(columns=["part", f"part_{'latency' if self.objective == LATENCY else 'throughput'}", "part_optim_time", "slowdown", f"net_{'latency' if self.objective == LATENCY else 'throughput'}", "URAM %", "BRAM %", "DSP %", "LUT %", "FF %", "BW"])
self.final_log_tbl = wandb.Table(columns=["part", f"part_{'latency' if self.objective == LATENCY else 'throughput'}", "part_optim_time", "slowdown", f"net_{'latency' if self.objective == LATENCY else 'throughput'}", "URAM %", "BRAM %", "DSP %", "LUT %", "FF %", "BW"])

def check_targets_met(self):
# stop the optimiser if targets are already met
Expand Down Expand Up @@ -61,7 +62,7 @@ def merge_memory_bound_partitions(self):
# return

# cache the network
net= copy.deepcopy(self.net)
net = copy.deepcopy(self.net)

self.update_partitions()
cost = self.get_cost()
Expand Down Expand Up @@ -124,19 +125,23 @@ def merge_memory_bound_partitions(self):
transforms.merge_horizontal(self.net, *horizontal_merges[1])
current_merge = horizontal_merges[1]

print(current_merge)
data = [["Attemting Partition Merging:", f"(Part {current_merge[0] + 1}, Part {current_merge[1] + 1})", "", "Total Partitions:", len(self.net.partitions) + 1]]
data_table = tabulate(data, headers="firstrow", tablefmt="youtrack")
print(data_table)
self.update_partitions()
status = self.run_solver(log_final=True)

if not status or self.get_cost() >= cost:
self.net= net
self.net = net
reject_list.append(current_merge)
print("reject")
data = [["Outcome:","Merge Rejected"]]
else:
for i, merge in enumerate(reject_list):
if merge[0] >= current_merge[1]:
reject_list[i] = (merge[0]-1,merge[1]-1)
print("accept")
data = [["Outcome:", "Merge Accepted"]]
data_table = tabulate(data, headers="firstrow", tablefmt="youtrack")
print(data_table)

if self.wandb_enabled:
wandb.log({"final_solver": self.final_log_tbl})
Expand Down Expand Up @@ -450,6 +455,7 @@ def run_solver(self, log=True, log_final=False) -> bool:
assert "partition" not in self.transforms

for partition_index in range(len(self.net.partitions)):
part_start_time = time.perf_counter()
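# iterate by index instead of using enumerate: copy.deepcopy creates a new partition object, so the partition is looked up through self.net.partitions[partition_index] each time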
if not self.net.partitions[partition_index].need_optimise:
continue
Expand Down Expand Up @@ -496,21 +502,27 @@ def run_solver(self, log=True, log_final=False) -> bool:
if self.objective != 1:
break

part_opt_time = time.perf_counter() - part_start_time
self.total_opt_time += part_opt_time
part_cost = self.get_cost([partition_index]) if self.objective == LATENCY else -self.get_cost([partition_index])
data = [[f"{partition_index+1}/{len(self.net.partitions)} single partition cost ({'latency' if self.objective == LATENCY else 'throughput'}):",
f"{part_cost:.4f}",
"",
"slowdown:",
self.net.partitions[partition_index].slow_down_factor]]
self.net.partitions[partition_index].slow_down_factor,
"partition optimisation time:",
f"{part_opt_time:.2f} sec",
"total optimisation time:",
f"{self.total_opt_time:.2f} sec",]]
data_table = tabulate(data, tablefmt="double_outline")
print(data_table)
if self.wandb_enabled:
self.wandb_log()
if log_final:
self.final_log_tbl.add_data(partition_index+1, part_cost, self.net.partitions[partition_index].slow_down_factor, -1, -1, -1, -1, -1, -1, -1)
self.final_log_tbl.add_data(partition_index+1, part_cost, part_opt_time, self.net.partitions[partition_index].slow_down_factor, -1, -1, -1, -1, -1, -1, -1)
self.solver_status(self.final_log_tbl)
else:
self.pre_merge_log_tbl.add_data(partition_index+1, part_cost, self.net.partitions[partition_index].slow_down_factor, -1, -1, -1, -1, -1, -1, -1)
self.pre_merge_log_tbl.add_data(partition_index+1, part_cost, part_opt_time, self.net.partitions[partition_index].slow_down_factor, -1, -1, -1, -1, -1, -1, -1)
self.solver_status(self.pre_merge_log_tbl)
else:
self.solver_status()
Expand Down
2 changes: 1 addition & 1 deletion fpgaconvnet/optimiser/solvers/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def create_report(self, output_path):
latency = self.net.get_latency(self.platform.board_freq, self.multi_fpga, inter_delay)
throughput = self.net.get_throughput(self.platform.board_freq, self.multi_fpga, inter_delay)
if self.wandb_enabled:
self.wandb_log()
self.wandb_log(**{"optimisation_time_sec": self.total_opt_time})
report = {}
report = {
"name" : self.net.name,
Expand Down
39 changes: 20 additions & 19 deletions fpgaconvnet/optimiser/solvers/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class Solver:
ram_usage: float = 1.0
multi_fpga: bool = False
constrain_port_width: bool = True
total_opt_time: float = 0.0
wandb_enabled: bool = False

"""
Expand Down Expand Up @@ -215,22 +216,22 @@ def solver_status(self, wandb_tbl=None):

# Resources
resources = [ self.get_partition_resource(partition) for partition in self.net.partitions ]
BRAM = max([ resource['BRAM'] for resource in resources ])
DSP = max([ resource['DSP'] for resource in resources ])
LUT = max([ resource['LUT'] for resource in resources ])
FF = max([ resource['FF'] for resource in resources ])
BW = max([ partition.get_total_bandwidth(self.platform.board_freq) for partition in self.net.partitions ])
BRAM = np.mean([ resource['BRAM'] for resource in resources ])
DSP = np.mean([ resource['DSP'] for resource in resources ])
LUT = np.mean([ resource['LUT'] for resource in resources ])
FF = np.mean([ resource['FF'] for resource in resources ])
BW = np.mean([ partition.get_total_bandwidth(self.platform.board_freq) for partition in self.net.partitions ])

solver_data = [
["COST:", "", "RESOURCES:", "", "", "", ""],
["", "", "BRAM", "DSP", "LUT", "FF", "BW"],
[f"{cost:.6f} ({objective})",
"",
f"{BRAM}/{self.platform.get_bram()}",
f"{DSP}/{self.platform.get_dsp()}",
f"{LUT}/{self.platform.get_lut()}",
f"{FF}/{self.platform.get_ff()}",
f"{BW}/{self.platform.get_mem_bw()}"],
f"{BRAM:.2f}/{self.platform.get_bram()}",
f"{DSP:.2f}/{self.platform.get_dsp()}",
f"{LUT:.2f}/{self.platform.get_lut()}",
f"{FF:.2f}/{self.platform.get_ff()}",
f"{BW:.2f}/{self.platform.get_mem_bw()}"],
["",
"",
f"{BRAM/self.platform.get_bram() * 100:.2f} %",
Expand All @@ -241,21 +242,21 @@ def solver_status(self, wandb_tbl=None):
]

if self.platform.get_uram() > 0:
URAM = max([ resource['URAM'] for resource in resources ])
URAM = np.mean([ resource['URAM'] for resource in resources ])
solver_data[0].insert(3, "")
solver_data[1].insert(2, "URAM")
solver_data[2].insert(2, f"{URAM}/{self.platform.get_uram()}")
solver_data[2].insert(2, f"{URAM:.2f}/{self.platform.get_uram()}")
solver_data[3].insert(2, f"{URAM/self.platform.get_uram() * 100:.2f} %")

if wandb_tbl != None:
for _, row in list(wandb_tbl.iterrows())[-1:]:
row[3] = cost
row[4] = URAM/self.platform.get_uram() * 100 if self.platform.get_uram() > 0 else 0
row[5] = BRAM/self.platform.get_bram() * 100
row[6] = DSP/self.platform.get_dsp() * 100
row[7] = LUT/self.platform.get_lut() * 100
row[8] = FF/self.platform.get_ff() * 100
row[9] = BW
row[4] = cost
row[5] = URAM/self.platform.get_uram() * 100 if self.platform.get_uram() > 0 else 0
row[6] = BRAM/self.platform.get_bram() * 100
row[7] = DSP/self.platform.get_dsp() * 100
row[8] = LUT/self.platform.get_lut() * 100
row[9] = FF/self.platform.get_ff() * 100
row[10] = BW

solver_table = tabulate(solver_data, headers="firstrow", tablefmt="github")
print(solver_table)
Expand Down

0 comments on commit 6ad8c6b

Please sign in to comment.