Skip to content

Commit

Permalink
finished the rollout plot and did some tidy up
Browse files Browse the repository at this point in the history
Signed-off-by: LimitingFactor <[email protected]>
  • Loading branch information
LimitingFactor authored and cfd1 committed Oct 4, 2023
1 parent 0a6e5fd commit d99c745
Showing 1 changed file with 42 additions and 51 deletions.
93 changes: 42 additions & 51 deletions examples/cfd/vortex_shedding_mgn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -169,15 +170,17 @@ def init_animation(self, idx):
self.ax[1].set_facecolor("black")
self.ax[2].set_facecolor("black")
self.first_call = True
self.text = None

# make animations dir
if not os.path.exists("./animations"):
os.makedirs("./animations")

def animate(self, num):
# Setup the colour bar ranges
if self.animation_variable == "u":
min_var = -1.0
max_var = 4.5
max_var = 4.0
min_delta_var = -0.25
max_delta_var = 0.25
elif self.animation_variable == "v":
Expand All @@ -186,8 +189,8 @@ def animate(self, num):
min_delta_var = -0.25
max_delta_var = 0.25
elif self.animation_variable == "p":
min_var = -6.0
max_var = 6.0
min_var = -5.0
max_var = 5.0
min_delta_var = -0.25
max_delta_var = 0.25

Expand All @@ -208,59 +211,44 @@ def animate(self, num):
self.ax[0].set_axis_off()
navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy")
self.ax[0].add_patch(navy_box) # Add a navy box to the first subplot
ans = self.ax[0].tripcolor(triang, y_star, vmin=min_var, vmax=max_var)
tripcolor_plot = self.ax[0].tripcolor(triang, y_star, vmin=min_var, vmax=max_var)
self.ax[0].triplot(triang, "ko-", ms=0.5, lw=0.3)
self.ax[0].set_title("Modulus MeshGraphNet Prediction", color="white")
if num == 0 and self.first_call:
cb_ax = self.fig.add_axes([0.9525, 0.69, 0.01, 0.26])
cb = self.fig.colorbar(ans, orientation="vertical", cax=cb_ax)
# cb = self.fig.colorbar(ans, ax=self.ax[0], location="right")
# COLORBAR
# set colorbar label plus label color
cb.set_label(self.animation_variable, color="white")

# set colorbar tick color
cb.ax.yaxis.set_tick_params(color="white")

# set colorbar edgecolor
cb.outline.set_edgecolor("white")

# set colorbar ticklabels
plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white")
self.setup_colourbars(tripcolor_plot, cb_ax)

# Update the text for the example number and step number
example_num = math.floor(num / C.num_test_time_steps)
if self.text is None:
self.text = plt.text(0.001, 0.9,
f"Example {example_num + 1}: "
f"{num - example_num * C.num_test_time_steps}/{C.num_test_time_steps}",
color="white", fontsize=20, transform=self.ax[0].transAxes)
else:
self.text.set_text(
f"Example {example_num + 1}: {num - example_num * C.num_test_time_steps}/{C.num_test_time_steps}")

# Truth plotting
self.ax[1].cla()
self.ax[1].set_aspect("equal")
self.ax[1].set_axis_off()
navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy")
self.ax[1].add_patch(navy_box) # Add a navy box to the second subplot
ans = self.ax[1].tripcolor(triang, y_exact, vmin=min_var, vmax=max_var)
tripcolor_plot = self.ax[1].tripcolor(triang, y_exact, vmin=min_var, vmax=max_var)
self.ax[1].triplot(triang, "ko-", ms=0.5, lw=0.3)
self.ax[1].set_title("Ground Truth", color="white")
if num == 0 and self.first_call:
cb_ax = self.fig.add_axes([0.9525, 0.37, 0.01, 0.26])
cb = self.fig.colorbar(ans, orientation="vertical", cax=cb_ax)
# cb = self.fig.colorbar(ans, ax=self.ax[0], location="right")
# COLORBAR
# set colorbar label plus label color
cb.set_label(self.animation_variable, color="white")

# set colorbar tick color
cb.ax.yaxis.set_tick_params(color="white")

# set colorbar edgecolor
cb.outline.set_edgecolor("white")

# set colorbar ticklabels
plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white")
self.setup_colourbars(tripcolor_plot, cb_ax)

# Error plotting
self.ax[2].cla()
self.ax[2].set_aspect("equal")
self.ax[2].set_axis_off()
navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy")
self.ax[2].add_patch(navy_box) # Add a navy box to the second subplot
ans = self.ax[2].tripcolor(
tripcolor_plot = self.ax[2].tripcolor(
triang, y_error, vmin=min_delta_var, vmax=max_delta_var, cmap="coolwarm"
)
self.ax[2].triplot(triang, "ko-", ms=0.5, lw=0.3)
Expand All @@ -269,20 +257,7 @@ def animate(self, num):
)
if num == 0 and self.first_call:
cb_ax = self.fig.add_axes([0.9525, 0.055, 0.01, 0.26])
cb = self.fig.colorbar(ans, orientation="vertical", cax=cb_ax)
# cb = self.fig.colorbar(ans, ax=self.ax[0], location="right")
# COLORBAR
# set colorbar label plus label color
cb.set_label(self.animation_variable, color="white")

# set colorbar tick color
cb.ax.yaxis.set_tick_params(color="white")

# set colorbar edgecolor
cb.outline.set_edgecolor("white")

# set colorbar ticklabels
plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white")
self.setup_colourbars(tripcolor_plot, cb_ax)

# Adjust subplots to minimize empty space
self.ax[0].set_aspect("auto", adjustable="box")
Expand All @@ -299,6 +274,13 @@ def animate(self, num):
)
return self.fig

def setup_colourbars(self, tripcolor_plot, cb_ax):
cb = self.fig.colorbar(tripcolor_plot, orientation="vertical", cax=cb_ax)
cb.set_label(self.animation_variable, color="white")
cb.ax.yaxis.set_tick_params(color="white")
cb.outline.set_edgecolor("white")
plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white")


if __name__ == "__main__":
logger = PythonLogger("main") # General python logger
Expand All @@ -315,9 +297,18 @@ def animate(self, num):
frames=len(rollout.graphs) // C.frame_skip,
interval=C.frame_interval,
)
ani.save("animations/animation_" + C.viz_vars[i] + ".gif")
ani.save("animations/animation_" + C.viz_vars[i] + ".gif", dpi=50)
logger.info(f"Created animation for {C.viz_vars[i]}")

# Plot the losses
plt.style.use('dark_background')
fig, ax = plt.subplots(1, 1, figsize=(16, 4.5))
ax.plot(rollout.loss)
plt.savefig("animations/loss.png")
ax.set_title("Rollout loss")
for i in range(C.num_test_samples):
start = i * (C.num_test_time_steps - 1)
end = i * (C.num_test_time_steps - 1) + (C.num_test_time_steps - 1)
ax.plot(rollout.loss[start:end])
ax.set_xlim([0, C.num_test_time_steps])
ax.set_xlabel("Rollout step")
ax.set_ylabel("Step loss")
plt.savefig("animations/rollout_loss.png")

0 comments on commit d99c745

Please sign in to comment.