Skip to content

Commit

Permalink
fix bug with unsorted model (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
Remi-Gau authored Aug 7, 2023
1 parent 87b1216 commit 125f217
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 29 deletions.
90 changes: 61 additions & 29 deletions +bids/Model.m
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@

verbose = true % hides warning if ``false``

dag_built = false % if the directed acyclic graph has been built

end

methods
Expand Down Expand Up @@ -187,6 +189,7 @@

end

obj = obj.build_dag;
obj.validate();

end
Expand Down Expand Up @@ -367,61 +370,68 @@

end

function source_nodes = get_parent(obj, node_name)
source_nodes = obj.get_source_node(node_name);
end

function source_nodes = get_source_node(obj, node_name)

source_nodes = {};
obj = obj.build_dag;

if isempty(obj.Edges)
obj = obj.get_edges_from_nodes;
end
source_nodes = {};

if strcmp(node_name, obj.Edges{1}.Source)
% The root node cannot have a source
% The root node cannot have a source
[~, root_node_name] = obj.get_root_node();
if strcmp(node_name, root_node_name)
return
end

% we should only get 1 value
for i = 1:numel(obj.Edges)
if strcmp(obj.Edges{i}.Destination, node_name)
source = obj.Edges{i}.Source;
source_nodes{end + 1} = obj.get_nodes('Name', source);
end
end

assert(numel(source_nodes) == 1);

if numel(source_nodes) == 1
source_nodes = source_nodes{1};
end
node = obj.get_nodes('Name', node_name);
source_nodes = obj.get_nodes('Name', node.parent);

end

function [root_node, root_node_name] = get_root_node(obj)

obj = obj.build_dag;
edges = obj.Edges;

if isempty(edges)
% assume a serial model
root_node = obj.Nodes(1);
if iscell(root_node)
root_node = root_node{1};
end
root_node_name = root_node.Name;
return
end

elseif iscell(edges)
root_node_name = edges{1}.Source;
root_node = obj.get_nodes('Name', root_node_name);

% start from the first edge and go up the DAG
if iscell(edges)
current_node_name = edges{1}.Source;
elseif isstruct(edges(1))
root_node_name = edges(1).Source;
root_node = obj.get_nodes('Name', root_node_name);
current_node_name = edges(1).Source;
end

else
root_node = obj.Nodes(1);
while true

current_node = obj.get_nodes('Name', current_node_name);
has_parent = isfield(current_node, 'parent');

if ~has_parent
root_node_name = current_node.Name;
break
end

current_node_name = current_node.parent;
end

root_node = current_node;

if iscell(root_node)
root_node = root_node{1};
end

root_node_name = root_node.Name;

end

function edge = get_edge(obj, field, value)
Expand Down Expand Up @@ -497,6 +507,24 @@
value = cellfun(@(x) x.Name, obj.Nodes, 'UniformOutput', false);
end

function obj = build_dag(obj)
if obj.dag_built
return
end
if isempty(obj.Edges)
obj = obj.get_edges_from_nodes;
end
for iEdge = 1:numel(obj.Edges)
source = obj.Edges{iEdge}.Source;
destination = obj.Edges{iEdge}.Destination;
[~, idx] = obj.get_nodes('Name', destination);
% assume can only have a single parent
% so we use a char and note cellstr
obj.Nodes{idx}.parent = source;
end
obj.dag_built = true;
end

function validate(obj)
%
% Very light validation of fields that were not checked on loading.
Expand Down Expand Up @@ -873,6 +901,10 @@ function validate_edges(obj)
end
end

if isfield(this_node, 'parent')
this_node = rmfield(this_node, 'parent');
end

obj.content.Nodes{i} = this_node;

end
Expand Down
57 changes: 57 additions & 0 deletions tests/test_bids_model.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,36 @@
initTestSuite;
end

function test_build_dag()
%
% model is run --> subject --> dataset
% \
% --> session
%
% but nodes and edges and not ordered properly
%

bm = bids.Model('init', true);

bm.Nodes{1} = bm.empty_node('dataset');
bm.Nodes{4} = bm.empty_node('session');
bm.Nodes{2} = bm.empty_node('run');
bm.Nodes{3} = bm.empty_node('subject');

bm.Edges{1} = struct('Source', 'subject', 'Destination', 'dataset');
bm.Edges{3} = struct('Source', 'run', 'Destination', 'session');
bm.Edges{2} = struct('Source', 'run', 'Destination', 'subject');

bm = bm.build_dag();

assertEqual(bm.dag_built, true);
assertEqual(bm.Nodes{1}.parent, 'subject');
assert(~isfield(bm.Nodes{2}, 'parent'));
assertEqual(bm.Nodes{3}.parent, 'run');
assertEqual(bm.Nodes{4}.parent, 'run');

end

function test_model_node_not_in_edges()

bm = bids.Model('file', model_file('narps'), 'verbose', false);
Expand Down Expand Up @@ -35,6 +65,33 @@ function test_model_load_edges()

end

function test_model_with_edges_and_node_not_in_order()
%
% model is run --> subject --> dataset
% \
% --> session
%
% but nodes and edges and not ordered properly
%

bm = bids.Model('init', true);

bm.Nodes{1} = bm.empty_node('dataset');
bm.Nodes{4} = bm.empty_node('session');
bm.Nodes{2} = bm.empty_node('run');
bm.Nodes{3} = bm.empty_node('subject');

bm.Edges{1} = struct('Source', 'subject', 'Destination', 'dataset');
bm.Edges{3} = struct('Source', 'run', 'Destination', 'session');
bm.Edges{2} = struct('Source', 'run', 'Destination', 'subject');

assertEqual(bm.get_source_node('session'), bm.get_nodes('Name', 'run'));

[root_node, root_node_name] = bm.get_root_node();
assertEqual(root_node_name, 'run');

end

function test_model_get_root_node()

bm = bids.Model('file', model_file('narps'), 'verbose', false);
Expand Down

0 comments on commit 125f217

Please sign in to comment.