Skip to content

Commit

Permalink
Merge pull request #128 from s-fong/brainstem
Browse files Browse the repository at this point in the history
Brainstem
  • Loading branch information
rchristie authored May 27, 2021
2 parents 5796e3b + c5f198f commit 924fb6c
Show file tree
Hide file tree
Showing 3 changed files with 383 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/scaffoldmaker/annotation/brainstem_terms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Common resource for testing annotation terms.
"""

# convention: preferred name, preferred id, followed by any other ids and alternative names
brainstem_terms = [
("medulla oblongata", "UBERON:0001896"),
("pons", "UBERON:0000988"),
("midbrain", "UBERON:0001891"),
("diencephalon", "UBERON:0001894"),
("brainstem", "UBERON:0002298")
]

def get_brainstem_annotation_term(name : str):
"""
Find term by matching name to any identifier held for a term.
Raise exception if name not found.
:return ( preferred name, preferred id )
"""
for term in brainstem_terms:
if name in term:
return ( term[0], term[1] )
raise NameError("Brainstem annotation term '" + name + "' not found.")
358 changes: 358 additions & 0 deletions src/scaffoldmaker/meshtypes/meshtype_3d_brainstem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,358 @@
"""
Brainstem mesh using a tapered cylinder
"""

from __future__ import division
import copy
from opencmiss.zinc.element import Element
from opencmiss.zinc.node import Node
from opencmiss.utils.zinc.field import Field, findOrCreateFieldCoordinates, findOrCreateFieldGroup, findOrCreateFieldNodeGroup, findOrCreateFieldStoredMeshLocation, findOrCreateFieldStoredString
from opencmiss.utils.zinc.finiteelement import getMaximumNodeIdentifier
from scaffoldmaker.annotation.annotationgroup import AnnotationGroup, findOrCreateAnnotationGroupForTerm
from scaffoldmaker.annotation.brainstem_terms import get_brainstem_annotation_term
from scaffoldmaker.meshtypes.scaffold_base import Scaffold_base
from scaffoldmaker.meshtypes.meshtype_1d_path1 import MeshType_1d_path1
from scaffoldmaker.utils.meshrefinement import MeshRefinement
from scaffoldmaker.utils.cylindermesh import CylinderMesh, CylinderShape, CylinderEnds, Tapered, ConeBaseProgression, CylinderCentralPath
from scaffoldmaker.utils.zinc_utils import exnodeStringFromNodeValues
from scaffoldmaker.scaffoldpackage import ScaffoldPackage


class MeshType_3d_brainstem1(Scaffold_base):
"""
Generates a tapered cylinder for the brainstem based on solid cylinder mesh, with variable numbers of elements in major, minor and length directions. Regions of the brainstem are annotated.
"""

centralPathDefaultScaffoldPackages = {
'Cylinder 1': ScaffoldPackage(MeshType_1d_path1, {
'scaffoldSettings': {
'Coordinate dimensions': 3,
'D2 derivatives': True,
'D3 derivatives': True,
'Length': 3.0,
'Number of elements': 3
},
'meshEdits': exnodeStringFromNodeValues( # dimensional.
[Node.VALUE_LABEL_VALUE, Node.VALUE_LABEL_D_DS1, Node.VALUE_LABEL_D_DS2, Node.VALUE_LABEL_D2_DS1DS2,
Node.VALUE_LABEL_D_DS3, Node.VALUE_LABEL_D2_DS1DS3], [
[[0.0, -1.0, 5.0], [0.0, 0.0, -4.5], [5.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, -2.4, 0.0], [0.0, -2.2, 0.0]],
[[0.0, -1.0, 0.5], [0.0, 0.0, -4.5], [6.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, -4, 0.0], [0.0, -1.1, 0.0]],
[[0.0, -1.0, -4.0], [0.0, 0.0, -4.5], [7.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, -4.5, 0.0], [0.0, -0.8, 0.0]],
[[0.0, -1.0, -8.5], [0.0, 0.0, -4.5], [8.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, -5.5, 0.0], [0.0, -0.8, 0.0]],
[[0.0, -1.0, -13.0], [0.0, 0.0, -4.5], [9.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, -6.0, 0.0], [0.0, -0.2, 0.0]]
])
})
}

@staticmethod
def getName():
return '3D Brainstem 1'

@classmethod
def getDefaultOptions(cls, parameterSetName='Default'):
centralPathOption = cls.centralPathDefaultScaffoldPackages['Cylinder 1']
options = {
'Central path': copy.deepcopy(centralPathOption),
'Number of elements across major': 6,
'Number of elements across minor': 6,
'Number of elements along': 12,
'Use cross derivatives': False,
'Refine': False,
'Refine number of elements across major and minor': 1,
'Refine number of elements along': 1
}
return options

@classmethod
def getOptionValidScaffoldTypes(cls, optionName):
if optionName == 'Central path':
return [MeshType_1d_path1]
return []

@classmethod
def getOptionScaffoldTypeParameterSetNames(cls, optionName, scaffoldType):
if optionName == 'Central path':
return list(cls.centralPathDefaultScaffoldPackages.keys())
assert scaffoldType in cls.getOptionValidScaffoldTypes(optionName), \
cls.__name__ + '.getOptionScaffoldTypeParameterSetNames. ' + \
'Invalid option \'' + optionName + '\' scaffold type ' + scaffoldType.getName()
return scaffoldType.getParameterSetNames()

@classmethod
def getOptionScaffoldPackage(cls, optionName, scaffoldType, parameterSetName=None):
'''
:param parameterSetName: Name of valid parameter set for option Scaffold, or None for default.
:return: ScaffoldPackage.
'''
if parameterSetName:
assert parameterSetName in cls.getOptionScaffoldTypeParameterSetNames(optionName, scaffoldType), \
'Invalid parameter set ' + str(parameterSetName) + ' for scaffold ' + str(scaffoldType.getName()) + \
' in option ' + str(optionName) + ' of scaffold ' + cls.getName()
if optionName == 'Central path':
if not parameterSetName:
parameterSetName = list(cls.centralPathDefaultScaffoldPackages.keys())[0]
return copy.deepcopy(cls.centralPathDefaultScaffoldPackages[parameterSetName])
assert False, cls.__name__ + '.getOptionScaffoldPackage: Option ' + optionName + ' is not a scaffold'


@staticmethod
def getOrderedOptionNames():
return [
'Central path',
'Number of elements across major',
'Number of elements across minor',
'Number of elements along',
'Refine',
'Refine number of elements across major and minor',
'Refine number of elements along'
]


@classmethod
def checkOptions(cls, options):

if not options['Central path'].getScaffoldType() in cls.getOptionValidScaffoldTypes('Central path'):
options['Central path'] = cls.getOptionScaffoldPackage('Central path', MeshType_1d_path1)
dependentChanges = False
if options['Number of elements across major'] < 4:
options['Number of elements across major'] = 4
if options['Number of elements across major'] % 2:
options['Number of elements across major'] += 1

if options['Number of elements across minor'] < 4:
options['Number of elements across minor'] = 4
if options['Number of elements across minor'] % 2:
options['Number of elements across minor'] += 1

if options['Number of elements along'] < 2:
options['Number of elements along'] = 2

return dependentChanges

@staticmethod
def generateBaseMesh(region, options):
"""
Generate the base tricubic Hermite mesh. See also generateMesh().
:param region: Zinc region to define model in. Must be empty.
:param options: Dict containing options. See getDefaultOptions().
:return: None
"""

fm = region.getFieldmodule()
mesh = fm.findMeshByDimension(3)
coordinates = findOrCreateFieldCoordinates(fm)

centralPath = options['Central path']
full = True
elementsCountAcrossMajor = options['Number of elements across major']
if not full:
elementsCountAcrossMajor //= 2
elementsCountAcrossMinor = options['Number of elements across minor']
elementsCountAlong = options['Number of elements along']


elemPerLayer = ((elementsCountAcrossMajor - 2) * elementsCountAcrossMinor) + (
2 * (elementsCountAcrossMinor - 2))
brainstemGroup = AnnotationGroup(region, get_brainstem_annotation_term('brainstem'))
midbrainGroup = AnnotationGroup(region, get_brainstem_annotation_term('midbrain'))
ponsGroup = AnnotationGroup(region, get_brainstem_annotation_term('pons'))
medullaGroup = AnnotationGroup(region, get_brainstem_annotation_term('medulla oblongata'))
brainstemMeshGroup = brainstemGroup.getMeshGroup(mesh)
midbrainMeshGroup = midbrainGroup.getMeshGroup(mesh)
ponsMeshGroup = ponsGroup.getMeshGroup(mesh)
medullaMeshGroup = medullaGroup.getMeshGroup(mesh)
annotationGroups = [brainstemGroup, midbrainGroup, ponsGroup, medullaGroup]

#######################
# CREATE MAIN BODY MESH
#######################
cylinderCentralPath = CylinderCentralPath(region, centralPath, elementsCountAlong)

cylinderShape = CylinderShape.CYLINDER_SHAPE_FULL if full else CylinderShape.CYLINDER_SHAPE_LOWER_HALF

base = CylinderEnds(elementsCountAcrossMajor, elementsCountAcrossMinor,
centre=[0.0, 0.0, 0.0],
alongAxis=cylinderCentralPath.alongAxis[0], majorAxis=cylinderCentralPath.majorAxis[0],
minorRadius=cylinderCentralPath.minorRadii[0])

cylinder1 = CylinderMesh(fm, coordinates, elementsCountAlong, base,
cylinderShape=cylinderShape,
cylinderCentralPath=cylinderCentralPath, useCrossDerivatives=False)

iRegionBoundaries = [int(7*elementsCountAlong/15),int(14*elementsCountAlong/15)]
for elementIdentifier in range(1, mesh.getSize()+1):
element = mesh.findElementByIdentifier(elementIdentifier)
brainstemMeshGroup.addElement(element)
if elementIdentifier > (iRegionBoundaries[-1]*elemPerLayer):
midbrainMeshGroup.addElement(element)
elif elementIdentifier > (iRegionBoundaries[0]*elemPerLayer) and elementIdentifier <= (iRegionBoundaries[-1]*elemPerLayer):
ponsMeshGroup.addElement(element)
else:
medullaMeshGroup.addElement(element)

##############################
# point markers
##############################
eIndexPM = {}
xiPM = {}
pointMarkers = {}
eIndexPM['caudal-dorsal'] = int(elemPerLayer/2)
eIndexPM['midRostralCaudal-dorsal'] = int(elemPerLayer / 2) + (elemPerLayer * int((elementsCountAlong/2)-1))
eIndexPM['rostral-dorsal'] = (elemPerLayer*(elementsCountAlong-1)) + int(elemPerLayer/2)
eIndexPM['caudal-ventral'] = int(elemPerLayer/2) - (elementsCountAcrossMinor-1)
eIndexPM['midRostralCaudal-ventral'] = eIndexPM['midRostralCaudal-dorsal'] - (elementsCountAcrossMinor-1)
eIndexPM['rostral-ventral'] = int((elemPerLayer*(elementsCountAlong-1)) + (int(elemPerLayer/2) - elementsCountAcrossMinor + 1))
xiPM['caudal-ventral'] = [1.0, 0.0, 0.0]
xiPM['caudal-dorsal'] = [1.0, 0.0, 1.0]
xiPM['midRostralCaudal-ventral'] = [1.0, 1.0, 0.0]
xiPM['midRostralCaudal-dorsal'] = [1.0, 1.0, 1.0]
xiPM['rostral-ventral'] = [1.0, 1.0, 0.0]
xiPM['rostral-dorsal'] = [1.0, 1.0, 1.0]
for key in eIndexPM.keys():
pointMarkers[key] = {"elementID": eIndexPM[key], "xi": xiPM[key]}
# the following emergent markers are in bodyCoordinates. Will not work in normal coordinates system.
emergentMarkers = createCranialNerveEmergentMarkers(region, mesh, "coordinates")
pointMarkers.update(emergentMarkers)

nodes = fm.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES)
cache = fm.createFieldcache()
nodeIdentifier = max(1, getMaximumNodeIdentifier(nodes) + 1)
markerGroup = findOrCreateFieldGroup(fm, "marker")
markerName = findOrCreateFieldStoredString(fm, name="marker_name")
markerLocation = findOrCreateFieldStoredMeshLocation(fm, mesh, name="marker_location")
markerPoints = findOrCreateFieldNodeGroup(markerGroup, nodes).getNodesetGroup()
markerTemplateInternal = nodes.createNodetemplate()
markerTemplateInternal.defineField(markerName)
markerTemplateInternal.defineField(markerLocation)
if pointMarkers:
for key in pointMarkers:
addMarker = {"name": key, "xi": pointMarkers[key]["xi"]}
markerPoint = markerPoints.createNode(nodeIdentifier, markerTemplateInternal)
nodeIdentifier += 1
cache.setNode(markerPoint)
markerName.assignString(cache, addMarker["name"])
elementID = pointMarkers[key]["elementID"]
element = mesh.findElementByIdentifier(elementID)
markerLocation.assignMeshLocation(cache, element, addMarker["xi"])

return annotationGroups

@classmethod
def refineMesh(cls, meshRefinement, options):
"""
Refine source mesh into separate region, with change of basis.
:param meshRefinement: MeshRefinement, which knows source and target region.
:param options: Dict containing options. See getDefaultOptions().
"""
assert isinstance(meshRefinement, MeshRefinement)
refineElementsCountAcrossMajor = options['Refine number of elements across major and minor']
refineElementsCountAlong = options['Refine number of elements along']
meshRefinement.refineAllElementsCubeStandard3d(refineElementsCountAcrossMajor, refineElementsCountAlong, refineElementsCountAcrossMajor)

@classmethod
def defineFaceAnnotations(cls, region, options, annotationGroups):
"""
Add face annotation groups from the 1D mesh.
:param region: Zinc region containing model.
:param options: Dict containing options. See getDefaultOptions().
:param annotationGroups: List of annotation groups for ventral-level elements.
New point annotation groups are appended to this list.
"""
# create groups
fm = region.getFieldmodule()
mesh2d = fm.findMeshByDimension(2)
is_exterior = fm.createFieldIsExterior()
is_exterior_face_xi1 = fm.createFieldOr(
fm.createFieldAnd(is_exterior, fm.createFieldIsOnFace(Element.FACE_TYPE_XI1_0)),
fm.createFieldAnd(is_exterior, fm.createFieldIsOnFace(Element.FACE_TYPE_XI1_1)))
is_exterior_face_xi3 = fm.createFieldOr(fm.createFieldAnd(is_exterior, fm.createFieldIsOnFace(Element.FACE_TYPE_XI3_0)), fm.createFieldAnd(is_exterior, fm.createFieldIsOnFace(Element.FACE_TYPE_XI3_1)))

annoGroup = AnnotationGroup(region, get_brainstem_annotation_term('brainstem'))
isGroup = annoGroup.getFieldElementGroup(mesh2d)
is_face1 = fm.createFieldAnd(isGroup, is_exterior_face_xi1)
is_face3 = fm.createFieldAnd(isGroup, is_exterior_face_xi3)
is_face_ext = fm.createFieldOr(is_face1, is_face3)
faceGroup = findOrCreateAnnotationGroupForTerm(annotationGroups, region, ("brainstem_exterior", None))
faceGroup.getMeshGroup(mesh2d).addElementsConditional(is_face_ext)

# external regions
namelist = ['midbrain', 'pons', 'medulla oblongata']
for subregion in namelist:
subGroup = AnnotationGroup(region, get_brainstem_annotation_term(subregion))
issub = subGroup.getFieldElementGroup(mesh2d)
is_subface = fm.createFieldOr(fm.createFieldAnd(issub, is_exterior_face_xi1), fm.createFieldAnd(issub, is_exterior_face_xi3))
subFaceGroup = findOrCreateAnnotationGroupForTerm(annotationGroups, region, (subregion+'_exterior', None))
subFaceGroup.getMeshGroup(mesh2d).addElementsConditional(is_subface)

def createCranialNerveEmergentMarkers(region, mesh, coordinatesName):
# create marker points for locations the cranial nerves emerge from brainstem mesh, based on the USF cat brainstem data.
# return element xi
# use findMeshLocation to find the elementxi in an arbitrary mesh of given number of elements.

if coordinatesName == "bodyCoordinates":
# brainstem_coordinates: the left-side nerves
nerveDict = {'OCULOMOTOR_left':[-0.13912257342955267, -0.5345161733750351, -0.7374762051676923],
'TROCHLEAR_left':[-0.13148279719950992, 0.4218745504359067, -0.7375838988856348],
'TRIGEMINAL_left': [-0.7605971693047597, -0.4025791045292648, -0.6862730212268676],
'ABDUCENS_left': [-0.19517975766630574, -0.6252563181242173, -0.8205128205130072],
'FACIAL_left': [-0.5824675040481234, -0.3554448371502354, -0.24509655553058302],
'VESTIBULOCOCHLEAR_left': [-0.6147505791411602, -0.32790803815838, -0.24509655403515848],
'GLOSSOPHARYNGEAL_left': [-0.7307312460087607, -0.2576952819028721, -0.39215539053073717],
'VAGUS_left': [-0.6741855912315219, -0.25981298010131126, -0.24509655277992023],
'ACCESSORY_cranialRoot_left':[-0.6741855912315219, -0.25981298010131126, -0.24509655277992023],
'HYPOGLOSSAL_left': [-0.044776303107883636, -0.5027870527016534, -0.10510117079651562]
}

rightDict = {}
for key in nerveDict.keys():
nerveName = key.split('_')[0]+'_right'
xyz = [-1*nerveDict[key][0], nerveDict[key][1], nerveDict[key][2]]
rightDict.update({nerveName:xyz})
nerveDict.update(rightDict)

nerveNames = list(nerveDict.keys())

# add to data coordinates
markerNameField = 'marker_name'
fm = region.getFieldmodule()
cache = fm.createFieldcache()
datapoints = fm.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
data_coordinates = findOrCreateFieldCoordinates(fm, "data_coordinates")
markerName = findOrCreateFieldStoredString(fm, name="marker_name")
dnodetemplate = datapoints.createNodetemplate()
dnodetemplate.defineField(data_coordinates)
dnodetemplate.setValueNumberOfVersions(data_coordinates, -1, Node.VALUE_LABEL_VALUE, 1)
dnodetemplate.defineField(markerName)
dnodeIdentifier = 1
for nerveName in nerveNames:
node = datapoints.createNode(dnodeIdentifier, dnodetemplate)
cache.setNode(node)
addEnd = nerveDict[nerveName].copy()
data_coordinates.setNodeParameters(cache, -1, Node.VALUE_LABEL_VALUE, 1, addEnd)
markerName.assignString(cache, nerveName)
dnodeIdentifier += 1

# find element-xi for these data_coordinates
dataNamesField = fm.findFieldByName(markerNameField)
coordinates = findOrCreateFieldCoordinates(fm, coordinatesName)
found_mesh_location = fm.createFieldFindMeshLocation(data_coordinates, coordinates, mesh)
found_mesh_location.setSearchMode(found_mesh_location.SEARCH_MODE_NEAREST)
xi_projected_data = {}
nodeIter = datapoints.createNodeiterator()
node = nodeIter.next()
while node.isValid():
cache.setNode(node)
element, xi = found_mesh_location.evaluateMeshLocation(cache, 3)
marker_name = dataNamesField.evaluateString(cache)
if element.isValid():
addProjection = {marker_name:{"elementID": element.getIdentifier(), "xi": xi,"nodeID": node.getIdentifier()}}
xi_projected_data.update(addProjection)
node = nodeIter.next()

result = datapoints.destroyAllNodes()

else:
xi_projected_data = {}

return xi_projected_data

Loading

0 comments on commit 924fb6c

Please sign in to comment.