Skip to content

Commit

Permalink
Support ${{ matrix.ORDINAL }} for matrix rows context (#716)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov authored Dec 23, 2021
1 parent 8793a9c commit cdee0ad
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.D/693.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Support ${{ matrix.ORDINAL }} as unique 0-based index for selected rows.

If a batch flow has a matrix, all matrix rows are enumerated.
The ordinal number of each row is available as `${{ matrix.ORDINAL }}` system value.
2 changes: 1 addition & 1 deletion neuro_flow/batch_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ async def _process_task(self, full_id: FullID) -> None:
log.debug(f"BatchExecutor: processing task {full_id}")
task = await self._get_task(full_id)

# Check is is task fits in max_parallel
# Check is task fits max_parallel
for n in range(1, len(full_id) + 1):
node = full_id[:n]
prefix = node[:-1]
Expand Down
8 changes: 6 additions & 2 deletions neuro_flow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,8 @@ async def setup_matrix(ast_matrix: Optional[ast.Matrix]) -> Sequence[MatrixCtx]:
ast_matrix._end,
)
matrices.append({k: await v.eval(EMPTY_ROOT) for k, v in inc_spec.items()})
for pos, dct in enumerate(matrices):
dct["ORDINAL"] = pos
return matrices


Expand Down Expand Up @@ -2739,8 +2741,10 @@ async def _setup_ids(
task_id = await ast_task.id.eval(ctx)
if task_id is None:
# Dash is not allowed in identifier, so the generated read id
# never clamps with user_provided one.
suffix = [str(ctx.matrix[k]) for k in sorted(ctx.matrix)]
# never clamps with user-provided one.
# Filter system properties
keys = [key for key in sorted(ctx.matrix) if key == key.lower()]
suffix = [str(ctx.matrix[key]) for key in keys]
real_id = "-".join(["task", str(num), *suffix])
else:
real_id = task_id
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/batch-matrix-doubles.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
kind: batch
tasks:
- id: task_${{ replace(join('__', values(matrix), True), '.', '_') }}
- id: taskN${{ matrix.ORDINAL }}__${{ replace(join('__', values(matrix), True), '.', '_') }}
strategy:
matrix:
x: [0.1, 0.2]
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,10 @@ async def test_pipeline_matrix_with_doubles(batch_config_loader: ConfigLoader) -
)

assert flow.graph == {
"task_0_1__0_3": {},
"task_0_1__0_5": {},
"task_0_2__0_3": {},
"task_0_2__0_5": {},
"taskN0__0_1__0_3__0": {},
"taskN1__0_1__0_5__1": {},
"taskN2__0_2__0_3__2": {},
"taskN3__0_2__0_5__3": {},
}


Expand Down

0 comments on commit cdee0ad

Please sign in to comment.