Skip to content

Commit

Permalink
Improvements to flattenSubgraphs (#1101)
Browse files Browse the repository at this point in the history
- Extend flattenSubgraphs to handle a broader range of graph structures, including references to interfaces with 'defaultgeomprop' strings.
- Streamline the logic in flattenSubgraphs, merging duplicate code paths and clarifying documentation.
- Add a 'Flatten Subgraphs' option to the MaterialX viewer, allowing render comparisons between nested and flattened graphs.
  • Loading branch information
jstone-lucasfilm authored Oct 10, 2022
1 parent 01181c5 commit 7d4b758
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 142 deletions.
210 changes: 77 additions & 133 deletions source/MaterialXCore/Node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,45 +251,42 @@ NodePtr GraphElement::addMaterialNode(const string& name, ConstNodePtr shaderNod

void GraphElement::flattenSubgraphs(const string& target, NodePredicate filter)
{
vector<NodePtr> processNodeVec = getNodes();
while (!processNodeVec.empty())
vector<NodePtr> nodeQueue = getNodes();
while (!nodeQueue.empty())
{
// Precompute graph implementations and downstream ports for this node vector.
// Determine which nodes require processing, and precompute declarations
// and graph implementations for these nodes.
using PortElementVec = vector<PortElementPtr>;
std::vector<NodePtr> processNodeVec;
std::unordered_map<NodePtr, NodeGraphPtr> graphImplMap;
std::unordered_map<NodePtr, ConstNodeDefPtr> declarationMap;
std::unordered_map<NodePtr, PortElementVec> downstreamPortMap;
for (NodePtr cacheNode : processNodeVec)
for (NodePtr node : nodeQueue)
{
InterfaceElementPtr implement = cacheNode->getImplementation(target);
if (!implement || !implement->isA<NodeGraph>())
if (filter && !filter(node))
{
continue;
}
NodeGraphPtr subNodeGraph = implement->asA<NodeGraph>();
graphImplMap[cacheNode] = subNodeGraph;
downstreamPortMap[cacheNode] = cacheNode->getDownstreamPorts();
for (NodePtr subNode : subNodeGraph->getNodes())

InterfaceElementPtr implement = node->getImplementation(target);
if (implement && implement->isA<NodeGraph>())
{
downstreamPortMap[subNode] = subNode->getDownstreamPorts();
processNodeVec.push_back(node);
graphImplMap[node] = implement->asA<NodeGraph>();
declarationMap[node] = node->getDeclaration(target);
downstreamPortMap[node] = node->getDownstreamPorts();
for (NodePtr sourceSubNode : implement->asA<NodeGraph>()->getNodes())
{
downstreamPortMap[sourceSubNode] = sourceSubNode->getDownstreamPorts();
}
}
}
processNodeVec.clear();

// Attributes in addition to value to copy over
StringVec copyAttributes = { ValueElement::UNIT_ATTRIBUTE,
ValueElement::UNITTYPE_ATTRIBUTE,
ValueElement::COLOR_SPACE_ATTRIBUTE };
nodeQueue.clear();

// Iterate through nodes with graph implementations.
for (const auto& pair : graphImplMap)
for (NodePtr processNode : processNodeVec)
{
NodePtr processNode = pair.first;
if (filter && !filter(processNode))
{
continue;
}

NodeGraphPtr sourceSubGraph = pair.second;
NodeGraphPtr sourceSubGraph = graphImplMap[processNode];
std::unordered_map<NodePtr, NodePtr> subNodeMap;

// Create a new instance of each original subnode.
Expand All @@ -302,150 +299,97 @@ void GraphElement::flattenSubgraphs(const string& target, NodePredicate filter)
destSubNode->copyContentFrom(sourceSubNode);
setChildIndex(destSubNode->getName(), getChildIndex(processNode->getName()));

// Transfer interface properties from the reference node to the new subnode.
for (ValueElementPtr destValue : destSubNode->getChildrenOfType<ValueElement>())
{
if (!destValue->hasInterfaceName())
{
continue;
}

ValueElementPtr refValue = processNode->getChildOfType<ValueElement>(destValue->getInterfaceName());
if (refValue)
{
if (refValue->hasValueString())
{
destValue->setValueString(refValue->getValueString());
}
for (auto copyAttribute : copyAttributes)
{
if (refValue->hasAttribute(copyAttribute))
{
destValue->setAttribute(copyAttribute, refValue->getAttribute(copyAttribute));
}
}
if (destValue->isA<Input>() && refValue->isA<Input>())
{
InputPtr refInput = refValue->asA<Input>();
InputPtr newInput = destValue->asA<Input>();
if (refInput->hasNodeName())
{
newInput->setNodeName(refInput->getNodeName());
}
if (refInput->hasOutputString())
{
newInput->setOutputString(refInput->getOutputString());
}
if (refInput->hasNodeGraphString())
{
newInput->setNodeGraphString(refInput->getNodeGraphString());
}
}
}
destValue->removeAttribute(ValueElement::INTERFACE_NAME_ATTRIBUTE);
}

// Store the mapping between subgraphs.
subNodeMap[sourceSubNode] = destSubNode;

// Add the subnode to the queue, allowing processing of nested subgraphs.
processNodeVec.push_back(destSubNode);
nodeQueue.push_back(destSubNode);
}

// Transfer internal connections between subgraphs.
// Update properties of generated subnodes.
for (const auto& subNodePair : subNodeMap)
{
NodePtr sourceSubNode = subNodePair.first;
NodePtr destSubNode = subNodePair.second;

// Update node connections.
for (PortElementPtr sourcePort : downstreamPortMap[sourceSubNode])
{
if (sourcePort->isA<Input>())
{
auto it = subNodeMap.find(sourcePort->getParent()->asA<Node>());
if (it != subNodeMap.end())
{
it->second->setConnectedNode(sourcePort->getName(), destSubNode);
InputPtr processNodeInput = it->second->getInput(sourcePort->getName());
if (processNodeInput)
{
processNodeInput->setNodeName(destSubNode->getName());
}
}
}
else if (sourcePort->isA<Output>())
{
for (PortElementPtr processNodePort : downstreamPortMap[processNode])
{
processNodePort->setConnectedNode(destSubNode);
processNodePort->setNodeName(destSubNode->getName());
}
}
}
}

// Connect any nodegraph outputs within the graph which point to another
// flatten node within the nodegraph. As it's been flattened the previous
// reference is incorrect and needs to be updated.
if (sourceSubGraph->getOutputCount())
{
for (OutputPtr sourceOutput : getOutputs())
// Transfer interface properties.
for (InputPtr destInput : destSubNode->getInputs())
{
const string& nodeNameString = sourceOutput->getNodeName();
const string& outputString = sourceOutput->getOutputString();

if (nodeNameString != processNode->getName())
{
continue;
}

// Look for what the original output pointed to.
OutputPtr sourceSubGraphOutput = outputString.empty() ? sourceSubGraph->getOutputs()[0] : sourceSubGraph->getOutput(outputString);
if (!sourceSubGraphOutput)
if (destInput->hasInterfaceName())
{
continue;
}

string destName = sourceSubGraphOutput->getNodeName();
if (destName.empty())
{
destName = sourceSubGraphOutput->getNodeGraphString();
}
NodePtr sourceSubNode = sourceSubGraph->getNode(destName);
NodePtr destNode = sourceSubNode ? subNodeMap[sourceSubNode] : nullptr;
if (destNode)
{
destName = destNode->getName();
InputPtr sourceInput = processNode->getInput(destInput->getInterfaceName());
if (sourceInput)
{
destInput->copyContentFrom(sourceInput);
}
else
{
ConstNodeDefPtr declaration = declarationMap[processNode];
InputPtr declInput = declaration ? declaration->getActiveInput(destInput->getInterfaceName()) : nullptr;
if (declInput)
{
if (declInput->hasValueString())
{
destInput->setValueString(declInput->getValueString());
}
if (declInput->hasDefaultGeomPropString())
{
ConstGeomPropDefPtr geomPropDef = getDocument()->getGeomPropDef(declInput->getDefaultGeomPropString());
if (geomPropDef)
{
destInput->setConnectedNode(addGeomNode(geomPropDef, "geomNode"));
}
}
}
}
destInput->removeAttribute(ValueElement::INTERFACE_NAME_ATTRIBUTE);
}

// Point original output to this one
sourceOutput->setNodeName(destName);
}
}

// If the node was flattened then any downstream references
// need to be updated to point to the new root of the flatten node.
PortElementVec downstreamPorts = downstreamPortMap[processNode];
for (auto downstreamPort : downstreamPorts)
// Update downstream ports with connections to subgraph outputs.
for (PortElementPtr downstreamPort : downstreamPortMap[processNode])
{
const string& outputString = downstreamPort->getOutputString();

// Look for an output on the flattened graph
OutputPtr sourceSubGraphOutput = outputString.empty() ? sourceSubGraph->getOutputs()[0] : sourceSubGraph->getOutput(outputString);
if (!sourceSubGraphOutput)
{
continue;
}

// Find connected node to the output
string destName = sourceSubGraphOutput->getNodeName();
if (destName.empty())
if (downstreamPort->hasOutputString())
{
destName = sourceSubGraphOutput->getNodeGraphString();
}
NodePtr sourceSubNode = sourceSubGraph->getNode(destName);
NodePtr destNode = sourceSubNode ? subNodeMap[sourceSubNode] : nullptr;
if (destNode)
{
destName = destNode->getName();
OutputPtr subGraphOutput = sourceSubGraph->getOutput(downstreamPort->getOutputString());
if (subGraphOutput)
{
string destName = subGraphOutput->getNodeName();
NodePtr sourceSubNode = sourceSubGraph->getNode(destName);
NodePtr destNode = sourceSubNode ? subNodeMap[sourceSubNode] : nullptr;
if (destNode)
{
destName = destNode->getName();
}
downstreamPort->setNodeName(destName);
downstreamPort->setOutputString(EMPTY_STRING);
}
}

// Use that node to overwrite downstream port connection
downstreamPort->setNodeName(destName);
downstreamPort->setOutputString(EMPTY_STRING);
}

// The processed node has been replaced, so remove it from the graph.
Expand Down
11 changes: 8 additions & 3 deletions source/MaterialXCore/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,14 @@ class MX_CORE_API GraphElement : public InterfaceElement
/// @}
/// @name Utility
/// @{

/// Flatten any references to graph-based node definitions within this
/// node graph, replacing each reference with the equivalent node network.

/// Flatten all subgraphs at the root scope of this graph element,
/// recursively replacing each graph-defined node with its equivalent
/// node network.
/// @param target An optional target string to be used in specifying
/// which node definitions are used in this process.
/// @param filter An optional node predicate specifying which nodes
/// should be included and excluded from this process.
void flattenSubgraphs(const string& target = EMPTY_STRING, NodePredicate filter = nullptr);

/// Return a vector of all children (nodes and outputs) sorted in
Expand Down
12 changes: 6 additions & 6 deletions source/MaterialXTest/MaterialXFormat/XmlIo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ TEST_CASE("Load content", "[xmlio]")
mx::readFromXmlBuffer(writtenDoc, xmlString.c_str());
REQUIRE(*writtenDoc == *doc);

// Flatten subgraph references.
for (mx::NodeGraphPtr nodeGraph : doc->getNodeGraphs())
// Flatten all subgraphs.
doc->flattenSubgraphs();
for (mx::NodeGraphPtr graph : doc->getNodeGraphs())
{
if (nodeGraph->getActiveSourceUri() != doc->getSourceUri())
if (graph->getActiveSourceUri() == doc->getSourceUri())
{
continue;
graph->flattenSubgraphs();
}
nodeGraph->flattenSubgraphs();
REQUIRE(nodeGraph->validate());
}
REQUIRE(doc->validate());

// Verify that all referenced types and nodes are declared.
bool referencesValid = true;
Expand Down
21 changes: 21 additions & 0 deletions source/MaterialXView/Viewer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ Viewer::Viewer(const std::string& materialFilename,
_splitByUdims(true),
_mergeMaterials(false),
_showAllInputs(false),
_flattenSubgraphs(false),
_targetShader("standard_surface"),
_captureRequested(false),
_exitRequested(false),
Expand Down Expand Up @@ -892,6 +893,13 @@ void Viewer::createAdvancedSettings(Widget* parent)
_showAllInputs = enable;
});

ng::CheckBox* flattenBox = new ng::CheckBox(advancedPopup, "Flatten Subgraphs");
flattenBox->set_checked(_flattenSubgraphs);
flattenBox->set_callback([this](bool enable)
{
_flattenSubgraphs = enable;
});

ng::CheckBox* splitDirectLightBox = new ng::CheckBox(advancedPopup, "Split Direct Light");
splitDirectLightBox->set_checked(_splitDirectLight);
splitDirectLightBox->set_callback([this](bool enable)
Expand Down Expand Up @@ -1173,6 +1181,19 @@ void Viewer::loadDocument(const mx::FilePath& filename, mx::DocumentPtr librarie
// Apply modifiers to the content document.
applyModifiers(doc, _modifiers);

// Flatten subgraphs if requested.
if (_flattenSubgraphs)
{
doc->flattenSubgraphs();
for (mx::NodeGraphPtr graph : doc->getNodeGraphs())
{
if (graph->getActiveSourceUri() == doc->getActiveSourceUri())
{
graph->flattenSubgraphs();
}
}
}

// Validate the document.
std::string message;
if (!doc->validate(&message))
Expand Down
1 change: 1 addition & 0 deletions source/MaterialXView/Viewer.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ class Viewer : public ng::Screen
bool _splitByUdims;
bool _mergeMaterials;
bool _showAllInputs;
bool _flattenSubgraphs;

// Shader translation
std::string _targetShader;
Expand Down

0 comments on commit 7d4b758

Please sign in to comment.