This repository has been archived by the owner on Dec 14, 2020. It is now read-only.

Mrp 2020 (#106)
* check for cycles

* add support for 2020 mrp format

* merge changes

* fix

* CR fixes

* remove print

Co-authored-by: Ofir Arviv <[email protected]>
OfirArviv and Ofir Arviv authored Oct 14, 2020
1 parent 3c195eb commit e5536d8
Showing 7 changed files with 118 additions and 21 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
@@ -6,3 +6,5 @@ dynet==2.1
logbook>=1.5.2
word2number>=1.1
git+https://github.com/cfmrp/mtool.git#egg=mtool
networkx==2.4
matplotlib == 3.2.1
14 changes: 7 additions & 7 deletions tupa/config.py
@@ -22,9 +22,9 @@
COMPOUND = "compound"

# Required number of node/edge labels, node properties and edge attributes per framework
-NODE_LABELS_NUM = {"amr": 1000, "dm": 1000, "psd": 1000, "eds": 1000, "ucca": 0}
-NODE_PROPERTY_NUM = {"amr": 1000, "dm": 510, "psd": 1000, "eds": 1000, "ucca": 0}
-EDGE_LABELS_NUM = {"amr": 141, "dm": 59, "psd": 90, "eds": 10, "ucca": 15}
+NODE_LABELS_NUM = {"amr": 1000, "dm": 1000, "psd": 1000, "eds": 1000, "ucca": 0, "ptg": 1000}
+NODE_PROPERTY_NUM = {"amr": 1000, "dm": 510, "psd": 1000, "eds": 1000, "ucca": 0, "ptg": 1000}
+EDGE_LABELS_NUM = {"amr": 141, "dm": 59, "psd": 90, "eds": 10, "ucca": 15, "ptg": 150}
EDGE_ATTRIBUTE_NUM = {"amr": 0, "dm": 0, "psd": 0, "eds": 0, "ucca": 2}
NN_ARG_NAMES = set()
DYNET_ARG_NAMES = set()
@@ -537,20 +537,20 @@ def __str__(self):


def requires_node_labels(framework):
-    return framework != "ucca"
+    return framework not in ("ucca", "drg", "ptg")


def requires_node_properties(framework):
-    return framework != "ucca"
+    return framework not in ("ucca", "drg")


def requires_edge_attributes(framework):
    return framework == "ucca"


def requires_anchors(framework):
-    return framework != "amr"
+    return framework not in ("amr", "drg")


def requires_tops(framework):
-    return framework in ("ucca", "amr")
+    return framework in ("ucca", "amr", "drg", "ptg")
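
For orientation, here is a hypothetical helper (not part of the commit) showing how the per-framework tables and predicates above combine for the newly added "ptg" target; the real call sites in TUPA differ:

def framework_profile(framework):
    # Summarize the settings a framework implies, using the tables and predicates defined above.
    return {
        "node_labels": NODE_LABELS_NUM.get(framework, 0),
        "edge_labels": EDGE_LABELS_NUM.get(framework, 0),
        "needs_node_labels": requires_node_labels(framework),
        "needs_anchors": requires_anchors(framework),
        "needs_tops": requires_tops(framework),
    }

assert framework_profile("ptg")["needs_tops"] is True          # ptg now requires top nodes
assert framework_profile("ptg")["needs_node_labels"] is False  # but node labels are not predicted for it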
6 changes: 6 additions & 0 deletions tupa/constraints/ptg.py
@@ -0,0 +1,6 @@
from .validation import Constraints


class PtgConstraints(Constraints):
    def __init__(self, **kwargs):
        super().__init__(multigraph=True, **kwargs)
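
PtgConstraints only toggles multigraph support, presumably because PTG graphs may contain parallel edges between the same pair of nodes. A standalone illustration (not part of the commit) using networkx, which this commit adds to the requirements:

import networkx as nx

g = nx.MultiDiGraph()
g.add_edge("pred", "arg", label="PAT")
g.add_edge("pred", "arg", label="EFF")  # a second, parallel edge between the same nodes
assert g.number_of_edges("pred", "arg") == 2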
6 changes: 6 additions & 0 deletions tupa/constraints/validation.py
@@ -155,12 +155,18 @@ def eds_constraints(**kwargs):
    return EdsConstraints(**kwargs)


def ptg_constraints(**kwargs):
    from .ptg import PtgConstraints
    return PtgConstraints(**kwargs)


CONSTRAINTS = {
    "ucca": ucca_constraints,
    "amr": amr_constraints,
    "dm": sdp_constraints,
    "psd": sdp_constraints,
    "eds": eds_constraints,
    "ptg": ptg_constraints,
}
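
A hypothetical lookup helper (not in the commit) showing how the registry above is typically consumed; the actual call site in TUPA may differ:

def constraints_for(framework, **kwargs):
    # Resolve the factory registered for the framework and build its constraint set.
    try:
        factory = CONSTRAINTS[framework]
    except KeyError:
        raise ValueError("no constraints registered for framework %r" % framework)
    return factory(**kwargs)

ptg = constraints_for("ptg")  # PtgConstraints, built with multigraph=True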


8 changes: 6 additions & 2 deletions tupa/oracle.py
@@ -193,8 +193,12 @@ def get_node_label(self, state, node):
        return true_label, raw_true_label

    def get_node_property_value(self, state, node):
-        true_property_value = next((k, v) for k, v in node.ref_node.properties.items()
-                                   if k not in (node.properties or ()))
+        try:
+            true_property_value = next((k, v) for k, v in (node.ref_node.properties.items()
+                                                           if node.ref_node.properties else [])
+                                       if k not in (node.properties or ()))
+        except StopIteration:
+            return None
        if self.args.validate_oracle:
            try:
                state.check_valid_property_value(true_property_value, message=True)
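
The change guards against nodes whose gold reference has no properties at all, or none left to assign, returning None instead of raising StopIteration. A minimal standalone sketch of the same pattern (illustrative names, not TUPA's API), written with next()'s default argument instead of try/except:

def next_missing_property(gold_properties, predicted_properties):
    """Return the first gold (key, value) pair not yet predicted, or None if exhausted."""
    gold = gold_properties or {}
    predicted = predicted_properties or {}
    return next(((k, v) for k, v in gold.items() if k not in predicted), None)

assert next_missing_property({"pos": "NN", "frame": "v.1"}, {"pos": "NN"}) == ("frame", "v.1")
assert next_missing_property(None, {}) is None  # node without gold properties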
58 changes: 46 additions & 12 deletions tupa/parse.py
@@ -57,6 +57,7 @@ def tokens_per_second(self):

class GraphParser(AbstractParser):
""" Parser for a single graph, has a state and optionally an oracle """

def __init__(self, graph, *args, target=None, **kwargs):
"""
:param graph: gold Graph to get the correct nodes and edges from (in training), or just to get id from (in test)
@@ -74,20 +75,54 @@ def __init__(self, graph, *args, target=None, **kwargs):
        assert self.lang, "Attribute 'lang' is required per passage when using multilingual BERT"
        self.state_hash_history = set()
        self.state = self.oracle = None
-        if self.framework == "amr" and self.alignment:  # Copy alignments to anchors, updating graph
+        if self.framework in ("amr", "drg", "ptg") and self.alignment:  # Copy alignments to anchors, updating graph
            for alignment_node in self.alignment.nodes:
                node = self.graph.find_node(alignment_node.id)
                if node is None:
                    self.config.log("graph %s: invalid alignment node %s" % (self.graph.id, alignment_node.id))
                    continue
                if node.anchors is None:
                    node.anchors = []
-                for conllu_node_id in (alignment_node.label or []) + list(chain(*alignment_node.values or [])):
-                    conllu_node = self.conllu.find_node(conllu_node_id)
-                    if conllu_node is None:
-                        raise ValueError("Alignments incompatible with tokenization: token %s "
-                                         "not found in graph %s" % (conllu_node_id, self.graph.id))
-                    node.anchors += conllu_node.anchors

+                conllu_node_id_list = None
+                alignment_node_anchor_char_range_list = None
+                if self.alignment.framework == "alignment":
+                    conllu_node_id_list = (alignment_node.label or []) + list(chain(*alignment_node.values or []))
+                elif self.alignment.framework == "anchoring" and self.framework in ("amr", "ptg"):
+                    conllu_node_id_list = set([alignment_dict["#"] for alignment_dict in
+                                               (alignment_node.anchors or [])
+                                               + ([anchor for anchor_list in (alignment_node.anchorings or []) for anchor in anchor_list])])
+                elif self.alignment.framework == "anchoring" and self.framework == "drg":
+                    alignment_node_anchor_char_range_list = [(int(alignment_dict["from"]), int(alignment_dict["to"])) for alignment_dict in
+                                                             (alignment_node.anchors or [])
+                                                             + ([anchor for anchor_list in (alignment_node.anchorings or []) for anchor in anchor_list])]
+                    assert all([len(conllu_node.anchors) == 1 for conllu_node in self.conllu.nodes])
+                    anchors_to_conllu_node = {(int(conllu_node.anchors[0]["from"]), int(conllu_node.anchors[0]["to"])):
+                                              conllu_node
+                                              for conllu_node in self.conllu.nodes}
+                else:
+                    raise ValueError(f'Unknown alignments framework: {self.alignment.framework}')
+
+                if conllu_node_id_list is not None:
+                    assert self.framework in ("amr", "ptg")
+                    for conllu_node_id in conllu_node_id_list:
+                        conllu_node = self.conllu.find_node(conllu_node_id + 1)
+
+                        if conllu_node is None:
+                            raise ValueError("Alignments incompatible with tokenization: token %s "
+                                             "not found in graph %s" % (conllu_node_id, self.graph.id))
+                        node.anchors += conllu_node.anchors
+
+                elif alignment_node_anchor_char_range_list is not None:
+                    for alignment_node_char_range in alignment_node_anchor_char_range_list:
+                        for conllu_anchor_range in anchors_to_conllu_node:
+                            if alignment_node_char_range[0] <= conllu_anchor_range[0] \
+                                    and alignment_node_char_range[1] >= conllu_anchor_range[1]:
+                                conllu_node = anchors_to_conllu_node[conllu_anchor_range]
+                                if conllu_node is None:
+                                    raise ValueError("Alignments incompatible with tokenization: token %s "
+                                                     "not found in graph %s" % (conllu_anchor_range, self.graph.id))
+                                node.anchors += conllu_node.anchors

    def init(self):
        self.config.set_framework(self.framework)
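
For the "drg" branch in the alignment block above, anchors are taken as character ranges ("from"/"to") and matched to companion tokens by span containment, while "amr"/"ptg" use token indices ("#"). A minimal standalone sketch of the containment idea (illustrative names, not TUPA's API):

def tokens_covered_by(anchor_span, token_spans):
    # Return tokens whose (from, to) character span lies entirely inside the anchor span.
    start, end = anchor_span
    return [token for (tok_start, tok_end), token in token_spans.items()
            if start <= tok_start and tok_end <= end]

token_spans = {(0, 3): "The", (4, 7): "cat", (8, 13): "slept"}
assert tokens_covered_by((4, 13), token_spans) == ["cat", "slept"]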
@@ -320,6 +355,7 @@ def num_tokens(self, _):

class BatchParser(AbstractParser):
""" Parser for a single training iteration or single pass over dev/test graphs """

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.seen_per_framework = defaultdict(int)
@@ -335,14 +371,11 @@ def parse(self, graphs, display=True, write=False, accuracies=None):
            if conllu is None:
                self.config.print("skipped '%s', no companion conllu data found" % graph.id)
                continue
-            alignment = self.alignment.get(graph.id)
+            alignment = self.alignment.get(graph.id) if self.alignment else None
            for target in graph.targets() or [graph.framework]:
                if not self.training and target not in self.model.classifier.labels:
                    self.config.print("skipped target '%s' for '%s': did not train on it" % (target, graph.id), level=1)
                    continue
-                if target == "amr" and alignment is None:
-                    self.config.print("skipped target 'amr' for '%s': no companion alignment found" % graph.id, level=1)
-                    continue
                parser = GraphParser(
                    graph, self.config, self.model, self.training, conllu=conllu, alignment=alignment, target=target)
                if self.config.args.verbose and display:
@@ -403,6 +436,7 @@ def time_per_graph(self):

class Parser(AbstractParser):
""" Main class to implement transition-based meaning representation parser """

def __init__(self, model_file=None, config=None, training=None, conllu=None, alignment=None):
super().__init__(config=config or Config(), model=Model(model_file or config.args.model),
training=config.args.train if training is None else training,
@@ -646,7 +680,7 @@ def read_graphs_with_progress_bar(file_handle_or_graphs):
    if isinstance(file_handle_or_graphs, IOBase):
        graphs, _ = read_graphs(
            tqdm(file_handle_or_graphs, desc="Reading " + getattr(file_handle_or_graphs, "name", "input"),
-                 unit=" graphs"), format="mrp")
+                 unit=" graphs"), format="mrp", robust=True)
        return graphs
    return file_handle_or_graphs

45 changes: 45 additions & 0 deletions tupa/states/ref_graph.py
@@ -1,10 +1,14 @@
import sys
from typing import List, Tuple

from .anchors import expand_anchors
from .edge import StateEdge
from .node import StateNode
from ..constraints.amr import NAME
from ..constraints.validation import ROOT_ID, ROOT_LAB, ANCHOR_LAB
from ..recategorization import resolve, compress_name

import networkx as nx

class RefGraph:
    def __init__(self, graph, conllu, framework):
@@ -27,6 +31,7 @@ def __init__(self, graph, conllu, framework):
        offset = len(conllu.nodes) + 1
        self.non_virtual_nodes = []
        self.edges = []
        have_anchors = False
        for graph_node in graph.nodes:
            node_id = graph_node.id + offset
            id2node[node_id] = node = \
@@ -43,7 +48,21 @@ def __init__(self, graph, conllu, framework):
                anchor_terminals = [min(self.terminals, key=lambda terminal: min(
                    x - y for x in terminal.anchors for y in node.anchors))]  # Must have anchors, get closest one
                for terminal in anchor_terminals:
                    have_anchors = True
                    self.edges.append(StateEdge(node, terminal, ANCHOR_LAB).add())

        if not have_anchors:
            print(f'framework {graph.framework} graph id {graph.id} has no anchors', file=sys.stderr)

        cycle = find_cycle(graph)
        while len(cycle) > 0:
            edge_list = list(graph.edges)
            first_edge_idx = \
                [i for i, edge in enumerate(graph.edges) if edge.src == cycle[0][0] and edge.tgt == cycle[0][1]][0]
            del edge_list[first_edge_idx]
            graph.edges = set(edge_list)
            cycle = find_cycle(graph)

        for edge in graph.edges:
            if edge.src != edge.tgt:  # Drop self-loops as the parser currently does not support them
                self.edges.append(StateEdge(id2node[edge.src + offset],
@@ -55,4 +74,30 @@
                node.properties = compress_name(node.properties)
            node.properties = {prop: resolve(node, value, introduce_placeholders=True)
                               for prop, value in node.properties.items()}

            node.label = resolve(node, node.label, introduce_placeholders=True)  # Must be after properties in case NAME


def find_cycle(graph, plot_graph=False) -> List[Tuple[int, int]]:
    edges_tuple = [(e.src, e.tgt) for e in graph.edges]
    nx_graph = nx.DiGraph()
    nx_graph.add_edges_from(edges_tuple)
    try:
        cycle = nx.find_cycle(nx_graph)
    except nx.exception.NetworkXNoCycle:
        cycle = []

    if plot_graph:
        import matplotlib.pyplot as plt
        nx.draw(nx_graph, with_labels=True, font_weight='bold')
        plt.show()

    return cycle


def is_directed_acyclic_graph(graph) -> bool:
    edges_tuple = [(e.src, e.tgt) for e in graph.edges]
    nx_graph = nx.DiGraph()
    nx_graph.add_edges_from(edges_tuple)

    return nx.is_directed_acyclic_graph(nx_graph)
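
The cycle-breaking strategy in RefGraph.__init__ above (find a cycle, drop one of its edges, repeat until none remain) can be reproduced in isolation with networkx; this sketch is not part of the commit:

import networkx as nx

edges = {(1, 2), (2, 3), (3, 1), (3, 4)}  # 1 -> 2 -> 3 -> 1 forms a cycle
g = nx.DiGraph(edges)
while True:
    try:
        cycle = nx.find_cycle(g)
    except nx.NetworkXNoCycle:
        break
    g.remove_edge(*cycle[0][:2])  # drop the first edge of the detected cycle

assert nx.is_directed_acyclic_graph(g)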
