diff --git a/requirements.txt b/requirements.txt
index f4609ee6..a4f0442a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,5 @@ dynet==2.1
 logbook>=1.5.2
 word2number>=1.1
 git+https://github.com/cfmrp/mtool.git#egg=mtool
+networkx==2.4
+matplotlib==3.2.1
diff --git a/tupa/config.py b/tupa/config.py
index 65490d68..1daa1b42 100644
--- a/tupa/config.py
+++ b/tupa/config.py
@@ -22,9 +22,9 @@ COMPOUND = "compound"
 
 
 # Required number of edge labels per framework
-NODE_LABELS_NUM = {"amr": 1000, "dm": 1000, "psd": 1000, "eds": 1000, "ucca": 0}
-NODE_PROPERTY_NUM = {"amr": 1000, "dm": 510, "psd": 1000, "eds": 1000, "ucca": 0}
-EDGE_LABELS_NUM = {"amr": 141, "dm": 59, "psd": 90, "eds": 10, "ucca": 15}
+NODE_LABELS_NUM = {"amr": 1000, "dm": 1000, "psd": 1000, "eds": 1000, "ucca": 0, "ptg": 1000}
+NODE_PROPERTY_NUM = {"amr": 1000, "dm": 510, "psd": 1000, "eds": 1000, "ucca": 0, "ptg": 1000}
+EDGE_LABELS_NUM = {"amr": 141, "dm": 59, "psd": 90, "eds": 10, "ucca": 15, "ptg": 150}
 EDGE_ATTRIBUTE_NUM = {"amr": 0, "dm": 0, "psd": 0, "eds": 0, "ucca": 2}
 NN_ARG_NAMES = set()
 DYNET_ARG_NAMES = set()
@@ -537,11 +537,11 @@ def __str__(self):
 
 
 def requires_node_labels(framework):
-    return framework != "ucca"
+    return framework not in ("ucca", "drg", "ptg")
 
 
 def requires_node_properties(framework):
-    return framework != "ucca"
+    return framework not in ("ucca", "drg")
 
 
 def requires_edge_attributes(framework):
@@ -549,8 +549,8 @@
 
 
 def requires_anchors(framework):
-    return framework != "amr"
+    return framework not in ("amr", "drg")
 
 
 def requires_tops(framework):
-    return framework in ("ucca", "amr")
+    return framework in ("ucca", "amr", "drg", "ptg")
diff --git a/tupa/constraints/ptg.py b/tupa/constraints/ptg.py
new file mode 100644
index 00000000..c09353b3
--- /dev/null
+++ b/tupa/constraints/ptg.py
@@ -0,0 +1,6 @@
+from .validation import Constraints
+
+
+class PtgConstraints(Constraints):
+    def __init__(self, **kwargs):
+        super().__init__(multigraph=True, **kwargs)
diff --git a/tupa/constraints/validation.py b/tupa/constraints/validation.py
index 24273fc0..ad7bb91c 100644
--- a/tupa/constraints/validation.py
+++ b/tupa/constraints/validation.py
@@ -155,12 +155,18 @@ def eds_constraints(**kwargs):
     return EdsConstraints(**kwargs)
 
 
+def ptg_constraints(**kwargs):
+    from .ptg import PtgConstraints
+    return PtgConstraints(**kwargs)
+
+
 CONSTRAINTS = {
     "ucca": ucca_constraints,
     "amr": amr_constraints,
     "dm": sdp_constraints,
     "psd": sdp_constraints,
     "eds": eds_constraints,
+    "ptg": ptg_constraints,
 }
 
 
diff --git a/tupa/oracle.py b/tupa/oracle.py
index 08e3aa57..941fe177 100644
--- a/tupa/oracle.py
+++ b/tupa/oracle.py
@@ -193,8 +193,12 @@ def get_node_label(self, state, node):
         return true_label, raw_true_label
 
     def get_node_property_value(self, state, node):
-        true_property_value = next((k, v) for k, v in node.ref_node.properties.items()
-                                   if k not in (node.properties or ()))
+        try:
+            true_property_value = next((k, v) for k, v in (node.ref_node.properties.items()
+                                                           if node.ref_node.properties else [])
+                                       if k not in (node.properties or ()))
+        except StopIteration:
+            return None
         if self.args.validate_oracle:
             try:
                 state.check_valid_property_value(true_property_value, message=True)
diff --git a/tupa/parse.py b/tupa/parse.py
index d4811631..ca3f8468 100644
--- a/tupa/parse.py
+++ b/tupa/parse.py
@@ -57,6 +57,7 @@ def tokens_per_second(self):
 
 class GraphParser(AbstractParser):
     """ Parser for a single graph, has a state and optionally an oracle """
+
     def __init__(self, graph, *args, target=None, **kwargs):
         """
         :param graph: gold Graph to get the correct nodes and edges from (in training), or just to get id from (in test)
@@ -74,7 +75,7 @@ def __init__(self, graph, *args, target=None, **kwargs):
             assert self.lang, "Attribute 'lang' is required per passage when using multilingual BERT"
         self.state_hash_history = set()
         self.state = self.oracle = None
-        if self.framework == "amr" and self.alignment:  # Copy alignments to anchors, updating graph
+        if self.framework in ("amr", "drg", "ptg") and self.alignment:  # Copy alignments to anchors, updating graph
            for alignment_node in self.alignment.nodes:
                node = self.graph.find_node(alignment_node.id)
                if node is None:
@@ -82,12 +83,46 @@ def __init__(self, graph, *args, target=None, **kwargs):
                    continue
                if node.anchors is None:
                    node.anchors = []
-                for conllu_node_id in (alignment_node.label or []) + list(chain(*alignment_node.values or [])):
-                    conllu_node = self.conllu.find_node(conllu_node_id)
-                    if conllu_node is None:
-                        raise ValueError("Alignments incompatible with tokenization: token %s "
-                                         "not found in graph %s" % (conllu_node_id, self.graph.id))
-                    node.anchors += conllu_node.anchors
+
+                conllu_node_id_list = None
+                alignment_node_anchor_char_range_list = None
+                if self.alignment.framework == "alignment":
+                    conllu_node_id_list = (alignment_node.label or []) + list(chain(*alignment_node.values or []))
+                elif self.alignment.framework == "anchoring" and self.framework in ("amr", "ptg"):
+                    conllu_node_id_list = set([alignment_dict["#"] for alignment_dict in
+                                               (alignment_node.anchors or []) +
+                                               ([anchor for anchor_list in (alignment_node.anchorings or []) for anchor in anchor_list])])
+                elif self.alignment.framework == "anchoring" and self.framework == "drg":
+                    alignment_node_anchor_char_range_list = [(int(alignment_dict["from"]),(int(alignment_dict["to"]))) for alignment_dict in
+                                                             (alignment_node.anchors or []) +
+                                                             ([anchor for anchor_list in (alignment_node.anchorings or []) for anchor in anchor_list])]
+                    assert all([len(conllu_node.anchors) == 1 for conllu_node in self.conllu.nodes])
+                    anchors_to_conllu_node = {(int(conllu_node.anchors[0]["from"]), int(conllu_node.anchors[0]["to"])):
+                                                  conllu_node
+                                              for conllu_node in self.conllu.nodes}
+                else:
+                    raise ValueError(f'Unknown alignments framework: {self.alignment.framework}')
+
+                if conllu_node_id_list is not None:
+                    assert self.framework in ("amr", "ptg")
+                    for conllu_node_id in conllu_node_id_list:
+                        conllu_node = self.conllu.find_node(conllu_node_id + 1)
+
+                        if conllu_node is None:
+                            raise ValueError("Alignments incompatible with tokenization: token %s "
+                                             "not found in graph %s" % (conllu_node_id, self.graph.id))
+                        node.anchors += conllu_node.anchors
+
+                elif alignment_node_anchor_char_range_list is not None:
+                    for alignment_node_char_range in alignment_node_anchor_char_range_list:
+                        for conllu_anchor_range in anchors_to_conllu_node:
+                            if alignment_node_char_range[0] <= conllu_anchor_range[0] \
+                                    and alignment_node_char_range[1] >= conllu_anchor_range[1]:
+                                conllu_node = anchors_to_conllu_node[conllu_anchor_range]
+                                if conllu_node is None:
+                                    raise ValueError("Alignments incompatible with tokenization: token %s "
+                                                     "not found in graph %s" % (conllu_anchor_range, self.graph.id))
+                                node.anchors += conllu_node.anchors
 
     def init(self):
         self.config.set_framework(self.framework)
@@ -320,6 +355,7 @@ def num_tokens(self, _):
 
 class BatchParser(AbstractParser):
     """ Parser for a single training iteration or single pass over dev/test graphs """
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.seen_per_framework = defaultdict(int)
@@ -335,14 +371,11 @@ def parse(self, graphs, display=True, write=False, accuracies=None):
             if conllu is None:
                 self.config.print("skipped '%s', no companion conllu data found" % graph.id)
                 continue
-            alignment = self.alignment.get(graph.id)
+            alignment = self.alignment.get(graph.id) if self.alignment else None
             for target in graph.targets() or [graph.framework]:
                 if not self.training and target not in self.model.classifier.labels:
                     self.config.print("skipped target '%s' for '%s': did not train on it" % (target, graph.id), level=1)
                     continue
-                if target == "amr" and alignment is None:
-                    self.config.print("skipped target 'amr' for '%s': no companion alignment found" % graph.id, level=1)
-                    continue
                 parser = GraphParser(
                     graph, self.config, self.model, self.training, conllu=conllu, alignment=alignment, target=target)
                 if self.config.args.verbose and display:
@@ -403,6 +436,7 @@ def time_per_graph(self):
 
 class Parser(AbstractParser):
     """ Main class to implement transition-based meaning representation parser """
+
     def __init__(self, model_file=None, config=None, training=None, conllu=None, alignment=None):
         super().__init__(config=config or Config(), model=Model(model_file or config.args.model),
                          training=config.args.train if training is None else training,
@@ -646,7 +680,7 @@ def read_graphs_with_progress_bar(file_handle_or_graphs):
     if isinstance(file_handle_or_graphs, IOBase):
         graphs, _ = read_graphs(
             tqdm(file_handle_or_graphs, desc="Reading " + getattr(file_handle_or_graphs, "name", "input"),
-                 unit=" graphs"), format="mrp")
+                 unit=" graphs"), format="mrp", robust=True)
         return graphs
     return file_handle_or_graphs
 
diff --git a/tupa/states/ref_graph.py b/tupa/states/ref_graph.py
index 511f3462..1f38894b 100644
--- a/tupa/states/ref_graph.py
+++ b/tupa/states/ref_graph.py
@@ -1,3 +1,6 @@
+import sys
+from typing import List, Tuple
+
 from .anchors import expand_anchors
 from .edge import StateEdge
 from .node import StateNode
@@ -5,6 +8,7 @@
 from ..constraints.validation import ROOT_ID, ROOT_LAB, ANCHOR_LAB
 from ..recategorization import resolve, compress_name
 
+import networkx as nx
 
 class RefGraph:
     def __init__(self, graph, conllu, framework):
@@ -27,6 +31,7 @@ def __init__(self, graph, conllu, framework):
         offset = len(conllu.nodes) + 1
         self.non_virtual_nodes = []
         self.edges = []
+        have_anchors = False
         for graph_node in graph.nodes:
             node_id = graph_node.id + offset
             id2node[node_id] = node = \
@@ -43,7 +48,21 @@ def __init__(self, graph, conllu, framework):
                     anchor_terminals = [min(self.terminals, key=lambda terminal: min(
                         x - y for x in terminal.anchors for y in node.anchors))]  # Must have anchors, get closest one
                 for terminal in anchor_terminals:
+                    have_anchors = True
                     self.edges.append(StateEdge(node, terminal, ANCHOR_LAB).add())
+
+        if not have_anchors:
+            print(f'framework {graph.framework} graph id {graph.id} have no anchors', file=sys.stderr)
+
+        cycle = find_cycle(graph)
+        while len(cycle) > 0:
+            edge_list = list(graph.edges)
+            first_edge_idx = \
+                [i for i, edge in enumerate(graph.edges) if edge.src == cycle[0][0] and edge.tgt == cycle[0][1]][0]
+            del edge_list[first_edge_idx]
+            graph.edges = set(edge_list)
+            cycle = find_cycle(graph)
+
         for edge in graph.edges:
             if edge.src != edge.tgt:  # Drop self-loops as the parser currently does not support them
                 self.edges.append(StateEdge(id2node[edge.src + offset],
@@ -55,4 +74,30 @@ def __init__(self, graph, conllu, framework):
                 node.properties = compress_name(node.properties)
             node.properties = {prop: resolve(node, value, introduce_placeholders=True)
                               for prop, value in node.properties.items()}
+
             node.label = resolve(node, node.label, introduce_placeholders=True)  # Must be after properties in case NAME
+
+
+def find_cycle(graph, plot_graph=False) -> List[Tuple[int, int]]:
+    edges_tuple = [(e.src, e.tgt) for e in graph.edges]
+    nx_graph = nx.DiGraph()
+    nx_graph.add_edges_from(edges_tuple)
+    try:
+        cycle = nx.find_cycle(nx_graph)
+    except nx.exception.NetworkXNoCycle:
+        cycle = []
+
+    if plot_graph:
+        import matplotlib.pyplot as plt
+        nx.draw(nx_graph, with_labels=True, font_weight='bold')
+        plt.show()
+
+    return cycle
+
+
+def is_directed_acyclic_graph(graph) -> bool:
+    edges_tuple = list(map(lambda x: (x.src, x.tgt), graph.edges))
+    nx_graph = nx.DiGraph()
+    nx_graph.add_edges_from(edges_tuple)
+
+    return nx.is_directed_acyclic_graph(nx_graph)