From 12d0aa4c704b55b8ccc629f6741e1d690906cbce Mon Sep 17 00:00:00 2001 From: LimitingFactor Date: Mon, 25 Sep 2023 13:09:12 +0100 Subject: [PATCH] finished the rollout plot and did some tidy up Signed-off-by: LimitingFactor --- examples/cfd/vortex_shedding_mgn/inference.py | 93 +++++++++---------- 1 file changed, 42 insertions(+), 51 deletions(-) diff --git a/examples/cfd/vortex_shedding_mgn/inference.py b/examples/cfd/vortex_shedding_mgn/inference.py index 81f3879..93c84a9 100644 --- a/examples/cfd/vortex_shedding_mgn/inference.py +++ b/examples/cfd/vortex_shedding_mgn/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import os import matplotlib.pyplot as plt @@ -169,15 +170,17 @@ def init_animation(self, idx): self.ax[1].set_facecolor("black") self.ax[2].set_facecolor("black") self.first_call = True + self.text = None # make animations dir if not os.path.exists("./animations"): os.makedirs("./animations") def animate(self, num): + # Setup the colour bar ranges if self.animation_variable == "u": min_var = -1.0 - max_var = 4.5 + max_var = 4.0 min_delta_var = -0.25 max_delta_var = 0.25 elif self.animation_variable == "v": @@ -186,8 +189,8 @@ def animate(self, num): min_delta_var = -0.25 max_delta_var = 0.25 elif self.animation_variable == "p": - min_var = -6.0 - max_var = 6.0 + min_var = -5.0 + max_var = 5.0 min_delta_var = -0.25 max_delta_var = 0.25 @@ -208,25 +211,23 @@ def animate(self, num): self.ax[0].set_axis_off() navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy") self.ax[0].add_patch(navy_box) # Add a navy box to the first subplot - ans = self.ax[0].tripcolor(triang, y_star, vmin=min_var, vmax=max_var) + tripcolor_plot = self.ax[0].tripcolor(triang, y_star, vmin=min_var, vmax=max_var) self.ax[0].triplot(triang, "ko-", ms=0.5, lw=0.3) self.ax[0].set_title("Modulus MeshGraphNet Prediction", color="white") if num == 0 and self.first_call: cb_ax = self.fig.add_axes([0.9525, 0.69, 0.01, 0.26]) - cb = self.fig.colorbar(ans, orientation="vertical", cax=cb_ax) - # cb = self.fig.colorbar(ans, ax=self.ax[0], location="right") - # COLORBAR - # set colorbar label plus label color - cb.set_label(self.animation_variable, color="white") - - # set colorbar tick color - cb.ax.yaxis.set_tick_params(color="white") - - # set colorbar edgecolor - cb.outline.set_edgecolor("white") - - # set colorbar ticklabels - plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white") + self.setup_colourbars(tripcolor_plot, cb_ax) + + # Update the text for the example number and step number + example_num = math.floor(num / C.num_test_time_steps) + if self.text is None: + self.text = plt.text(0.001, 0.9, + f"Example {example_num + 1}: " + f"{num - example_num * C.num_test_time_steps}/{C.num_test_time_steps}", + color="white", fontsize=20, transform=self.ax[0].transAxes) + else: + self.text.set_text( + f"Example {example_num + 1}: {num - example_num * C.num_test_time_steps}/{C.num_test_time_steps}") # Truth plotting self.ax[1].cla() @@ -234,25 +235,12 @@ def animate(self, num): self.ax[1].set_axis_off() navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy") self.ax[1].add_patch(navy_box) # Add a navy box to the second subplot - ans = self.ax[1].tripcolor(triang, y_exact, vmin=min_var, vmax=max_var) + tripcolor_plot = self.ax[1].tripcolor(triang, y_exact, vmin=min_var, vmax=max_var) self.ax[1].triplot(triang, "ko-", ms=0.5, lw=0.3) self.ax[1].set_title("Ground Truth", color="white") if num == 0 and self.first_call: cb_ax = self.fig.add_axes([0.9525, 0.37, 0.01, 0.26]) - cb = self.fig.colorbar(ans, orientation="vertical", cax=cb_ax) - # cb = self.fig.colorbar(ans, ax=self.ax[0], location="right") - # COLORBAR - # set colorbar label plus label color - cb.set_label(self.animation_variable, color="white") - - # set colorbar tick color - cb.ax.yaxis.set_tick_params(color="white") - - # set colorbar edgecolor - cb.outline.set_edgecolor("white") - - # set colorbar ticklabels - plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white") + self.setup_colourbars(tripcolor_plot, cb_ax) # Error plotting self.ax[2].cla() @@ -260,7 +248,7 @@ def animate(self, num): self.ax[2].set_axis_off() navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy") self.ax[2].add_patch(navy_box) # Add a navy box to the second subplot - ans = self.ax[2].tripcolor( + tripcolor_plot = self.ax[2].tripcolor( triang, y_error, vmin=min_delta_var, vmax=max_delta_var, cmap="coolwarm" ) self.ax[2].triplot(triang, "ko-", ms=0.5, lw=0.3) @@ -269,20 +257,7 @@ def animate(self, num): ) if num == 0 and self.first_call: cb_ax = self.fig.add_axes([0.9525, 0.055, 0.01, 0.26]) - cb = self.fig.colorbar(ans, orientation="vertical", cax=cb_ax) - # cb = self.fig.colorbar(ans, ax=self.ax[0], location="right") - # COLORBAR - # set colorbar label plus label color - cb.set_label(self.animation_variable, color="white") - - # set colorbar tick color - cb.ax.yaxis.set_tick_params(color="white") - - # set colorbar edgecolor - cb.outline.set_edgecolor("white") - - # set colorbar ticklabels - plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white") + self.setup_colourbars(tripcolor_plot, cb_ax) # Adjust subplots to minimize empty space self.ax[0].set_aspect("auto", adjustable="box") @@ -299,6 +274,13 @@ def animate(self, num): ) return self.fig + def setup_colourbars(self, tripcolor_plot, cb_ax): + cb = self.fig.colorbar(tripcolor_plot, orientation="vertical", cax=cb_ax) + cb.set_label(self.animation_variable, color="white") + cb.ax.yaxis.set_tick_params(color="white") + cb.outline.set_edgecolor("white") + plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white") + if __name__ == "__main__": logger = PythonLogger("main") # General python logger @@ -315,9 +297,18 @@ def animate(self, num): frames=len(rollout.graphs) // C.frame_skip, interval=C.frame_interval, ) - ani.save("animations/animation_" + C.viz_vars[i] + ".gif") + ani.save("animations/animation_" + C.viz_vars[i] + ".gif", dpi=50) logger.info(f"Created animation for {C.viz_vars[i]}") + # Plot the losses + plt.style.use('dark_background') fig, ax = plt.subplots(1, 1, figsize=(16, 4.5)) - ax.plot(rollout.loss) - plt.savefig("animations/loss.png") + ax.set_title("Rollout loss") + for i in range(C.num_test_samples): + start = i * (C.num_test_time_steps - 1) + end = i * (C.num_test_time_steps - 1) + (C.num_test_time_steps - 1) + ax.plot(rollout.loss[start:end]) + ax.set_xlim([0, C.num_test_time_steps]) + ax.set_xlabel("Rollout step") + ax.set_ylabel("Step loss") + plt.savefig("animations/rollout_loss.png")