Performance issues w/ mesh attribute retrieval #216

Open
tradeqvest opened this issue Mar 21, 2024 · 3 comments

@tradeqvest

Hello,

For my application, I constantly need to retrieve mesh element attributes, e.g. `mesh.GetElementTransformation(i)` or `mesh.GetElementVertices(i)`. Since this requires looping over every element, performance suffers significantly. Is there a more efficient way that I am overlooking? Is there a way to vectorize the retrieval?

I would appreciate any insights! Thanks in advance for your time and help!

@sshiraiwa
Member

As for `GetElementVertices`, there is `Mesh::GetVertexToElementTable`. It returns the vertex-to-element mapping as a `Table`. Using the I and J arrays of this table, you can create the reverse mapping from elements to vertices.
In the following, I construct a `scipy.sparse.csr_matrix` from I and J, then take its transpose and convert it back with `tocsr()`.
You can use the `indices` and `indptr` of the resulting matrix as the mapping from elements to vertices.

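```python
import numpy as np
import mfem.ser as mfem
from scipy.sparse import csr_matrix

# Obtain an mfem.Mesh; this line is specific to my environment, any mfem.Mesh works.
mesh = proj.model1.mfem.variables.eval("mesh")

tb = mesh.GetVertexToElementTable()   # rows: vertices, columns: elements
nnz = tb.Size_of_connections()
indptr = mfem.intArray((tb.GetI(), mesh.GetNV())).GetDataArray()
indptr = np.hstack((indptr, nnz))     # need to append the total length
indices = mfem.intArray((tb.GetJ(), nnz)).GetDataArray()

# Vertex-to-element matrix -> transpose -> element-to-vertex matrix in CSR form.
v2e = csr_matrix((np.ones(nnz, dtype=int), indices, indptr),
                 shape=(mesh.GetNV(), mesh.GetNE()))
mat = v2e.transpose().tocsr()

# well.. let's check if this is correct ;D
# Each row holds the element's vertex set, not necessarily in the element's
# local ordering, so both sides are sorted before comparing.
for e in range(mesh.GetNE()):
    iverts = mat.indices[mat.indptr[e]:mat.indptr[e + 1]]
    iverts2 = mesh.GetElementVertices(e)
    if not np.array_equal(np.sort(iverts), np.sort(iverts2)):
        print("error", e, iverts, iverts2)
```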
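To see what `transpose().tocsr()` computes, or to avoid the SciPy dependency, the same reverse mapping can be built with NumPy alone. This is a sketch, not part of PyMFEM: `transpose_csr_pattern` is a hypothetical helper, and the toy arrays below are made-up stand-ins for `tb.GetI()`/`tb.GetJ()` of a two-triangle mesh. Like the SciPy version, it returns each element's vertices in ascending order rather than in the element's local ordering.

```python
import numpy as np


def transpose_csr_pattern(indptr, indices, n_cols):
    """Transpose the sparsity pattern of a CSR matrix (hypothetical helper).

    indptr/indices map rows -> columns (here: vertex -> elements);
    the result maps columns -> rows (element -> vertices).
    """
    indptr = np.asarray(indptr)
    indices = np.asarray(indices)
    n_rows = len(indptr) - 1
    # Row index of every stored entry, e.g. [0, 1, 1, 2, 2, 3].
    rows = np.repeat(np.arange(n_rows), np.diff(indptr))
    # Stable sort by column keeps row indices ascending within each column.
    order = np.argsort(indices, kind="stable")
    t_indices = rows[order]
    # Number of entries per column, turned into CSR offsets.
    counts = np.bincount(indices, minlength=n_cols)
    t_indptr = np.concatenate(([0], np.cumsum(counts)))
    return t_indptr, t_indices


# Toy vertex -> element table for two triangles sharing an edge:
# element 0 = vertices (0, 1, 2), element 1 = vertices (1, 3, 2).
v2e_indptr = [0, 1, 3, 5, 6]
v2e_indices = [0, 0, 1, 0, 1, 1]
t_indptr, t_indices = transpose_csr_pattern(v2e_indptr, v2e_indices, n_cols=2)
for e in range(2):
    print(e, t_indices[t_indptr[e]:t_indptr[e + 1]].tolist())
# 0 [0, 1, 2]
# 1 [1, 2, 3]
```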

As for `mesh.GetElementTransformation(i)`, I realized that the wrapper calls `Tr = IsoparametricTransformation()` every time, meaning it creates a new object on every call. If this allocation is an issue, we could change the wrapper so that you can pass `Tr` as a keyword argument. Otherwise, I am not sure there is a simple way to make this faster.
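The keyword-argument idea can be illustrated with plain Python stand-ins. Everything here is hypothetical (`Transformation`, `MeshWrapper`, `get_element_transformation`, and its `tr=` keyword are not PyMFEM API); the sketch only shows the pattern of allocating once and refilling the same object on later calls.

```python
class Transformation:
    """Hypothetical stand-in for an element transformation object."""

    def __init__(self):
        self.element = -1
        self.vertex_coords = []


class MeshWrapper:
    """Hypothetical wrapper; illustrates the proposed keyword, not PyMFEM's API."""

    def __init__(self, element_vertices, coords):
        self.element_vertices = element_vertices  # element -> vertex ids
        self.coords = coords                      # vertex id -> coordinates
        self.allocations = 0                      # counts new Transformation objects

    def get_element_transformation(self, i, tr=None):
        # Allocate only when the caller did not pass an object to reuse.
        if tr is None:
            tr = Transformation()
            self.allocations += 1
        # Overwrite the object's state in place for element i.
        tr.element = i
        tr.vertex_coords = [self.coords[v] for v in self.element_vertices[i]]
        return tr


mesh = MeshWrapper([(0, 1, 2), (1, 3, 2)],
                   [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)])

tr = None
for i in range(len(mesh.element_vertices)):
    # The first call allocates; later calls refill the same object.
    tr = mesh.get_element_transformation(i, tr=tr)
    # ... use tr here, before the next iteration overwrites it ...
```

The trade-off is that the contents are overwritten on each call, so a caller must not keep the transformation for element `i` after requesting element `i + 1`. As far as I know, the C++ `Mesh::GetElementTransformation(int)` has the same property, since it returns a pointer to a single object owned by the mesh.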

@tradeqvest
Author

Thank you very much for your answer! 🙂 The first part worked really well!

Regarding `mesh.GetElementTransformation(i)`, this is the function I want to speed up:

```python
import numpy as np
from mfem.ser import GridFunction, IntegrationPoint, Vector


def interpolate_solution_at_points(
    fespace, mesh, solution, integration_points, corresponding_elements
):
    """
    Interpolate a finite element solution at given points.

    Args:
    - fespace: The finite element space (mfem.FiniteElementSpace)
    - mesh: The mesh (mfem.Mesh)
    - solution: The finite element solution (np.array)
    - integration_points: The points where the solution is to be interpolated (numpy array of shape (n_points, dim))
    - corresponding_elements: For each point, the index of the mesh element that contains it

    Returns:
    - interpolated_values: The interpolated solution values at the given points (numpy array of shape (n_points, 1))
    """
    dim = fespace.GetMesh().Dimension()
    assert (
        integration_points.shape[1] == dim
    ), "Dimension of points must match the mesh dimension"
    grid_function = GridFunction(fespace)
    grid_function.Assign(np.ravel(solution))
    n_points = integration_points.shape[0]
    interpolated_values = np.zeros(n_points)
    ip = IntegrationPoint()
    for i, elem in enumerate(corresponding_elements):
        # Map the physical point back to reference coordinates in element `elem`.
        trans = mesh.GetElementTransformation(elem)
        point = Vector(integration_points[i, :])
        trans.TransformBack(point, ip)
        interpolated_values[i] = grid_function.GetValue(elem, ip)
    return interpolated_values.reshape(-1, 1)
```

If you see a way to make it more efficient, please let me know! 🙂 Thank you in advance for your time and effort!

@justinlaughlin
Contributor

Hi @tradeqvest

What is the size of the problem you are working with?

  • How big is your for loop (how many points are you interpolating over)?
  • How many elements are in your mesh?

I ask because if Nel << Npoints, it may be worthwhile, as a first pass, to construct the transformations for all elements once and then access them in the for loop, rather than reinitializing them for every point.
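One way to act on this without holding many transformations at once is to group the points by element, so that each element's transformation is built once and reused for all of its points. Below is a NumPy sketch of the grouping step; `group_points_by_element` and `build_transformation` are hypothetical names, and in the real loop `build_transformation(int(e))` would be replaced by `mesh.GetElementTransformation(int(e))`. The element index is converted to a Python `int` to be safe, in case the SWIG wrapper rejects NumPy integer types.

```python
import numpy as np


def group_points_by_element(elements):
    """Group point indices by the element that contains them (hypothetical helper).

    Returns (unique_elements, order, starts): the points in unique_elements[k]
    are order[starts[k]:starts[k + 1]].
    """
    elements = np.asarray(elements)
    order = np.argsort(elements, kind="stable")  # point indices sorted by element
    sorted_elems = elements[order]
    # First position of each distinct element within the sorted array.
    unique_elements, first = np.unique(sorted_elems, return_index=True)
    starts = np.append(first, len(elements))
    return unique_elements, order, starts


elements = [3, 1, 3, 0, 1, 3]  # element index for each of six points (toy data)
uniq, order, starts = group_points_by_element(elements)

calls = []


def build_transformation(e):
    # Hypothetical stand-in for mesh.GetElementTransformation(e).
    calls.append(e)
    return ("transformation", e)


processed = []
for k, e in enumerate(uniq):
    tr = build_transformation(int(e))  # once per distinct element
    for p in order[starts[k]:starts[k + 1]]:
        # TransformBack / GetValue for point p would go here.
        processed.append((int(p), tr[1]))
```

Because results are written back by point index `p`, the output keeps the original point order even though the points are visited element by element.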

I ran a quick profile, and although `mesh.GetElementTransformation` does take some time, a lot of the time was spent initializing `Vector`. Perhaps you could construct a single `Vector` before your loop and change its values inside the loop.

I'm not aware of a vectorized solution (maybe @sshiraiwa knows). Could you try those two things and see if they improve your speed?
