From 125f217631b864b8af4fdc0b77f4d769b3286f89 Mon Sep 17 00:00:00 2001 From: Remi Gau Date: Mon, 7 Aug 2023 12:50:53 -0400 Subject: [PATCH] fix bug with unsorted model (#613) --- +bids/Model.m | 90 ++++++++++++++++++++++++++++------------- tests/test_bids_model.m | 57 ++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 29 deletions(-) diff --git a/+bids/Model.m b/+bids/Model.m index fb2129ba..a50619dd 100644 --- a/+bids/Model.m +++ b/+bids/Model.m @@ -69,6 +69,8 @@ verbose = true % hides warning if ``false`` + dag_built = false % if the directed acyclic graph has been built + end methods @@ -187,6 +189,7 @@ end + obj = obj.build_dag; obj.validate(); end @@ -367,61 +370,68 @@ end + function source_nodes = get_parent(obj, node_name) + source_nodes = obj.get_source_node(node_name); + end + function source_nodes = get_source_node(obj, node_name) - source_nodes = {}; + obj = obj.build_dag; - if isempty(obj.Edges) - obj = obj.get_edges_from_nodes; - end + source_nodes = {}; - if strcmp(node_name, obj.Edges{1}.Source) - % The root node cannot have a source + % The root node cannot have a source + [~, root_node_name] = obj.get_root_node(); + if strcmp(node_name, root_node_name) return end - % we should only get 1 value - for i = 1:numel(obj.Edges) - if strcmp(obj.Edges{i}.Destination, node_name) - source = obj.Edges{i}.Source; - source_nodes{end + 1} = obj.get_nodes('Name', source); - end - end - - assert(numel(source_nodes) == 1); - - if numel(source_nodes) == 1 - source_nodes = source_nodes{1}; - end + node = obj.get_nodes('Name', node_name); + source_nodes = obj.get_nodes('Name', node.parent); end function [root_node, root_node_name] = get_root_node(obj) + obj = obj.build_dag; edges = obj.Edges; if isempty(edges) + % assume a serial model root_node = obj.Nodes(1); + if iscell(root_node) + root_node = root_node{1}; + end + root_node_name = root_node.Name; + return + end - elseif iscell(edges) - root_node_name = edges{1}.Source; - root_node = obj.get_nodes('Name', root_node_name); - + % start from the first edge and go up the DAG + if iscell(edges) + current_node_name = edges{1}.Source; elseif isstruct(edges(1)) - root_node_name = edges(1).Source; - root_node = obj.get_nodes('Name', root_node_name); + current_node_name = edges(1).Source; + end - else - root_node = obj.Nodes(1); + while true + current_node = obj.get_nodes('Name', current_node_name); + has_parent = isfield(current_node, 'parent'); + + if ~has_parent + root_node_name = current_node.Name; + break + end + + current_node_name = current_node.parent; end + root_node = current_node; + if iscell(root_node) root_node = root_node{1}; end - root_node_name = root_node.Name; - end function edge = get_edge(obj, field, value) @@ -497,6 +507,24 @@ value = cellfun(@(x) x.Name, obj.Nodes, 'UniformOutput', false); end + function obj = build_dag(obj) + if obj.dag_built + return + end + if isempty(obj.Edges) + obj = obj.get_edges_from_nodes; + end + for iEdge = 1:numel(obj.Edges) + source = obj.Edges{iEdge}.Source; + destination = obj.Edges{iEdge}.Destination; + [~, idx] = obj.get_nodes('Name', destination); + % assume can only have a single parent + % so we use a char and note cellstr + obj.Nodes{idx}.parent = source; + end + obj.dag_built = true; + end + function validate(obj) % % Very light validation of fields that were not checked on loading. @@ -873,6 +901,10 @@ function validate_edges(obj) end end + if isfield(this_node, 'parent') + this_node = rmfield(this_node, 'parent'); + end + obj.content.Nodes{i} = this_node; end diff --git a/tests/test_bids_model.m b/tests/test_bids_model.m index 3ee03d07..6ce3b95e 100644 --- a/tests/test_bids_model.m +++ b/tests/test_bids_model.m @@ -6,6 +6,36 @@ initTestSuite; end +function test_build_dag() + % + % model is run --> subject --> dataset + % \ + % --> session + % + % but nodes and edges and not ordered properly + % + + bm = bids.Model('init', true); + + bm.Nodes{1} = bm.empty_node('dataset'); + bm.Nodes{4} = bm.empty_node('session'); + bm.Nodes{2} = bm.empty_node('run'); + bm.Nodes{3} = bm.empty_node('subject'); + + bm.Edges{1} = struct('Source', 'subject', 'Destination', 'dataset'); + bm.Edges{3} = struct('Source', 'run', 'Destination', 'session'); + bm.Edges{2} = struct('Source', 'run', 'Destination', 'subject'); + + bm = bm.build_dag(); + + assertEqual(bm.dag_built, true); + assertEqual(bm.Nodes{1}.parent, 'subject'); + assert(~isfield(bm.Nodes{2}, 'parent')); + assertEqual(bm.Nodes{3}.parent, 'run'); + assertEqual(bm.Nodes{4}.parent, 'run'); + +end + function test_model_node_not_in_edges() bm = bids.Model('file', model_file('narps'), 'verbose', false); @@ -35,6 +65,33 @@ function test_model_load_edges() end +function test_model_with_edges_and_node_not_in_order() + % + % model is run --> subject --> dataset + % \ + % --> session + % + % but nodes and edges and not ordered properly + % + + bm = bids.Model('init', true); + + bm.Nodes{1} = bm.empty_node('dataset'); + bm.Nodes{4} = bm.empty_node('session'); + bm.Nodes{2} = bm.empty_node('run'); + bm.Nodes{3} = bm.empty_node('subject'); + + bm.Edges{1} = struct('Source', 'subject', 'Destination', 'dataset'); + bm.Edges{3} = struct('Source', 'run', 'Destination', 'session'); + bm.Edges{2} = struct('Source', 'run', 'Destination', 'subject'); + + assertEqual(bm.get_source_node('session'), bm.get_nodes('Name', 'run')); + + [root_node, root_node_name] = bm.get_root_node(); + assertEqual(root_node_name, 'run'); + +end + function test_model_get_root_node() bm = bids.Model('file', model_file('narps'), 'verbose', false);