Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn on potential data races #1712

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions dace/config_schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,13 @@ required:
description: >
Check for undefined symbols in memlets during SDFG validation.

check_race_conditions:
type: bool
default: false
title: Check race conditions
description: >
Check for potential race conditions during validation.

#############################################
# Features for unit testing

Expand Down
51 changes: 44 additions & 7 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Exception classes and methods for validation of SDFGs. """

import copy
from dace.dtypes import DebugInfo
import os
from typing import TYPE_CHECKING, Dict, List, Set
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Set

import networkx as nx

from dace import dtypes, subsets, symbolic
from dace.dtypes import DebugInfo

if TYPE_CHECKING:
import dace
from dace.memlet import Memlet
from dace.sdfg import SDFG
from dace.sdfg import graph as gr
from dace.memlet import Memlet
from dace.sdfg.state import ControlFlowRegion

###########################################
Expand All @@ -34,8 +39,8 @@ def validate_control_flow_region(sdfg: 'SDFG',
symbols: dict,
references: Set[int] = None,
**context: bool):
from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalBlock
from dace.sdfg.scope import is_in_scope
from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, SDFGState

if len(region.source_nodes()) > 1 and region.start_block is None:
raise InvalidSDFGError("Starting block undefined", sdfg, None)
Expand Down Expand Up @@ -200,7 +205,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
# Avoid import loop
from dace import data as dt
from dace.codegen.targets import fpga
from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga
from dace.sdfg.scope import is_devicelevel_fpga, is_devicelevel_gpu

references = references or set()

Expand Down Expand Up @@ -383,7 +388,8 @@ def validate_state(state: 'dace.sdfg.SDFGState',
from dace.sdfg import SDFG
from dace.sdfg import nodes as nd
from dace.sdfg import utils as sdutil
from dace.sdfg.scope import scope_contains_scope, is_devicelevel_gpu, is_devicelevel_fpga
from dace.sdfg.scope import (is_devicelevel_fpga, is_devicelevel_gpu,
scope_contains_scope)

sdfg = sdfg or state.parent
state_id = state_id if state_id is not None else state.parent_graph.node_id(state)
Expand Down Expand Up @@ -839,6 +845,37 @@ def validate_state(state: 'dace.sdfg.SDFGState',
continue
raise error

if Config.get_bool('experimental.check_race_conditions'):
node_labels = []
write_accesses = defaultdict(list)
read_accesses = defaultdict(list)
for node in state.data_nodes():
node_labels.append(node.label)
write_accesses[node.label].extend(
[{'subset': e.data.dst_subset, 'node': node, 'wcr': e.data.wcr} for e in state.in_edges(node)])
read_accesses[node.label].extend(
[{'subset': e.data.src_subset, 'node': node} for e in state.out_edges(node)])

for node_label in node_labels:
writes = write_accesses[node_label]
reads = read_accesses[node_label]
# Check write-write data races.
for i in range(len(writes)):
for j in range(i+1, len(writes)):
same_or_unreachable_nodes = (writes[i]['node'] == writes[j]['node'] or
not nx.has_path(state.nx, writes[i]['node'], writes[j]['node']))
no_wcr = writes[i]['wcr'] is None and writes[j]['wcr'] is None
if same_or_unreachable_nodes and no_wcr:
subsets_intersect = subsets.intersects(writes[i]['subset'], writes[j]['subset'])
if subsets_intersect:
warnings.warn(f'Memlet range overlap while writing to "{node}" in state "{state.label}"')
# Check read-write data races.
for write in writes:
for read in reads:
if (not nx.has_path(state.nx, read['node'], write['node']) and
subsets.intersects(write['subset'], read['subset'])):
warnings.warn(f'Memlet range overlap while writing to "{node}" in state "{state.label}"')

########################################


Expand Down
Loading
Loading