Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds error plot and colourbars to the animation. Start of a loss plot #112

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Updated the vortex_shedding_mgn inference to include an error plot and colourbars

### Deprecated

### Removed
Expand Down
125 changes: 109 additions & 16 deletions examples/cfd/vortex_shedding_mgn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch, dgl
from dgl.dataloading import GraphDataLoader
import torch
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from dgl.dataloading import GraphDataLoader
from matplotlib import animation
from matplotlib import tri as mtri
import os
from matplotlib.patches import Rectangle

from modulus.models.meshgraphnet import MeshGraphNet
from modulus.datapipes.gnn.vortex_shedding_dataset import VortexSheddingDataset
from modulus.models.meshgraphnet import MeshGraphNet

from constants import Constants
from modulus.launch.logging import PythonLogger
from modulus.launch.utils import load_checkpoint
from constants import Constants

# Instantiate constants
C = Constants()
Expand Down Expand Up @@ -64,6 +64,9 @@ def __init__(self, logger):
else:
self.model = self.model.to(self.device)

# instantiate loss
self.criterion = torch.nn.MSELoss()

# enable train mode
self.model.eval()

Expand All @@ -77,12 +80,13 @@ def __init__(self, logger):
self.var_identifier = {"u": 0, "v": 1, "p": 2}

def predict(self):
self.pred, self.exact, self.faces, self.graphs = [], [], [], []
self.pred, self.exact, self.faces, self.graphs, self.loss = [], [], [], [], []
stats = {
key: value.to(self.device) for key, value in self.dataset.node_stats.items()
}
for i, (graph, cells, mask) in enumerate(self.dataloader):
graph = graph.to(self.device)

# denormalize data
graph.ndata["x"][:, 0:2] = self.dataset.denormalize(
graph.ndata["x"][:, 0:2], stats["velocity_mean"], stats["velocity_std"]
Expand All @@ -107,6 +111,8 @@ def predict(self):
invar[:, 0:2] = self.dataset.normalize_node(
invar[:, 0:2], stats["velocity_mean"], stats["velocity_std"]
)

# Get the predition
pred_i = self.model(invar, graph.edata["x"], graph).detach() # predict

# denormalize prediction
Expand All @@ -116,6 +122,10 @@ def predict(self):
pred_i[:, 2] = self.dataset.denormalize(
pred_i[:, 2], stats["pressure_mean"], stats["pressure_std"]
)

loss = self.criterion(pred_i, graph.ndata["y"])
self.loss.append(loss.cpu().detach())

invar[:, 0:2] = self.dataset.denormalize(
invar[:, 0:2], stats["velocity_mean"], stats["velocity_std"]
)
Expand Down Expand Up @@ -146,61 +156,131 @@ def predict(self):
self.graphs.append(graph.cpu())

def init_animation(self, idx):
self.animation_variable = C.viz_vars[idx]
self.pred_i = [var[:, idx] for var in self.pred]
self.exact_i = [var[:, idx] for var in self.exact]

# fig configs
plt.rcParams["image.cmap"] = "inferno"
self.fig, self.ax = plt.subplots(2, 1, figsize=(16, 9))
self.fig, self.ax = plt.subplots(3, 1, figsize=(16, (9 / 2) * 3))

# Set background color to black
self.fig.set_facecolor("black")
self.ax[0].set_facecolor("black")
self.ax[1].set_facecolor("black")
self.ax[2].set_facecolor("black")
self.first_call = True
self.text = None

# make animations dir
if not os.path.exists("./animations"):
os.makedirs("./animations")

def animate(self, num):
# Setup the colour bar ranges
if self.animation_variable == "u":
min_var = -1.0
max_var = 4.0
min_delta_var = -0.25
max_delta_var = 0.25
elif self.animation_variable == "v":
min_var = -2.0
max_var = 2.0
min_delta_var = -0.25
max_delta_var = 0.25
elif self.animation_variable == "p":
min_var = -5.0
max_var = 5.0
min_delta_var = -0.25
max_delta_var = 0.25

num *= C.frame_skip
graph = self.graphs[num]
y_star = self.pred_i[num].numpy()
y_exact = self.exact_i[num].numpy()
y_error = y_star - y_exact
triang = mtri.Triangulation(
graph.ndata["mesh_pos"][:, 0].numpy(),
graph.ndata["mesh_pos"][:, 1].numpy(),
self.faces[num],
)

# Prediction plotting
self.ax[0].cla()
self.ax[0].set_aspect("equal")
self.ax[0].set_axis_off()
navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy")
self.ax[0].add_patch(navy_box) # Add a navy box to the first subplot
self.ax[0].tripcolor(triang, y_star, vmin=np.min(y_star), vmax=np.max(y_star))
tripcolor_plot = self.ax[0].tripcolor(triang, y_star, vmin=min_var, vmax=max_var)
self.ax[0].triplot(triang, "ko-", ms=0.5, lw=0.3)
self.ax[0].set_title("Modulus MeshGraphNet Prediction", color="white")
if num == 0 and self.first_call:
cb_ax = self.fig.add_axes([0.9525, 0.69, 0.01, 0.26])
self.setup_colourbars(tripcolor_plot, cb_ax)

# Update the text for the example number and step number
example_num = math.floor(num / C.num_test_time_steps)
if self.text is None:
self.text = plt.text(0.001, 0.9,
f"Example {example_num + 1}: "
f"{num - example_num * C.num_test_time_steps}/{C.num_test_time_steps}",
color="white", fontsize=20, transform=self.ax[0].transAxes)
else:
self.text.set_text(
f"Example {example_num + 1}: {num - example_num * C.num_test_time_steps}/{C.num_test_time_steps}")

# Truth plotting
self.ax[1].cla()
self.ax[1].set_aspect("equal")
self.ax[1].set_axis_off()
navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy")
self.ax[1].add_patch(navy_box) # Add a navy box to the second subplot
self.ax[1].tripcolor(
triang, y_exact, vmin=np.min(y_exact), vmax=np.max(y_exact)
)
tripcolor_plot = self.ax[1].tripcolor(triang, y_exact, vmin=min_var, vmax=max_var)
self.ax[1].triplot(triang, "ko-", ms=0.5, lw=0.3)
self.ax[1].set_title("Ground Truth", color="white")
if num == 0 and self.first_call:
cb_ax = self.fig.add_axes([0.9525, 0.37, 0.01, 0.26])
self.setup_colourbars(tripcolor_plot, cb_ax)

# Error plotting
self.ax[2].cla()
self.ax[2].set_aspect("equal")
self.ax[2].set_axis_off()
navy_box = Rectangle((0, 0), 1.4, 0.4, facecolor="navy")
self.ax[2].add_patch(navy_box) # Add a navy box to the second subplot
tripcolor_plot = self.ax[2].tripcolor(
triang, y_error, vmin=min_delta_var, vmax=max_delta_var, cmap="coolwarm"
)
self.ax[2].triplot(triang, "ko-", ms=0.5, lw=0.3)
self.ax[2].set_title(
"Absolute Error (Prediction - Ground Truth)", color="white"
)
if num == 0 and self.first_call:
cb_ax = self.fig.add_axes([0.9525, 0.055, 0.01, 0.26])
self.setup_colourbars(tripcolor_plot, cb_ax)

# Adjust subplots to minimize empty space
self.ax[0].set_aspect("auto", adjustable="box")
self.ax[1].set_aspect("auto", adjustable="box")
self.ax[0].autoscale(enable=True, tight=True)

self.ax[1].set_aspect("auto", adjustable="box")
self.ax[1].autoscale(enable=True, tight=True)

self.ax[2].set_aspect("auto", adjustable="box")
self.ax[2].autoscale(enable=True, tight=True)

self.fig.subplots_adjust(
left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.1, hspace=0.2
)
return self.fig

def setup_colourbars(self, tripcolor_plot, cb_ax):
cb = self.fig.colorbar(tripcolor_plot, orientation="vertical", cax=cb_ax)
cb.set_label(self.animation_variable, color="white")
cb.ax.yaxis.set_tick_params(color="white")
cb.outline.set_edgecolor("white")
plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white")


if __name__ == "__main__":
logger = PythonLogger("main") # General python logger
Expand All @@ -217,5 +297,18 @@ def animate(self, num):
frames=len(rollout.graphs) // C.frame_skip,
interval=C.frame_interval,
)
ani.save("animations/animation_" + C.viz_vars[i] + ".gif")
ani.save("animations/animation_" + C.viz_vars[i] + ".gif", dpi=50)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make the resolution a configurable variable?

logger.info(f"Created animation for {C.viz_vars[i]}")

# Plot the losses
plt.style.use('dark_background')
fig, ax = plt.subplots(1, 1, figsize=(16, 4.5))
ax.set_title("Rollout loss")
for i in range(C.num_test_samples):
start = i * (C.num_test_time_steps - 1)
end = i * (C.num_test_time_steps - 1) + (C.num_test_time_steps - 1)
ax.plot(rollout.loss[start:end])
ax.set_xlim([0, C.num_test_time_steps])
ax.set_xlabel("Rollout step")
ax.set_ylabel("Step loss")
plt.savefig("animations/rollout_loss.png")