Improve by unknown affix classes
dizys committed Feb 17, 2022
1 parent c532df7 commit 077a7fd
Showing 2 changed files with 97 additions and 39 deletions.
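In short: the tagger used to fall back to a flat emission probability of 1/1000 for any word missing from the emission table. This commit instead collapses every training word seen exactly once into a pseudo-word class derived from its suffix (UNKNOWN_AFFIXED_WITH_ING, UNKNOWN_AFFIXED_WITH_LY, and so on, or plain UNKNOWN when no suffix matches), learns emission probabilities for those classes, and at tagging time backs an unseen word off to its class before resorting to the flat default. Below is a minimal, self-contained sketch of the classifier added in src/main.py; the loop is condensed (isinstance/any instead of the committed nested loops) but behaves the same, and the example words are invented:

from typing import List, Union

# Suffix table as committed in src/main.py; grouped suffixes share one class,
# named after the group's first member.
suffixes: List[Union[List[str], str]] = [['able', 'ible'], 'al', 'an', 'ar', 'ed', 'en', ['er', 'or'],
                                         'est', 'ing', ['ish', 'ous', 'ful', 'less'], 'ive', 'ly', ['ment', 'ness'], 'y']

def get_unknown_word_class_by_suffix(word: str) -> str:
    for suffix in suffixes:
        group = suffix if isinstance(suffix, list) else [suffix]
        if any(word.endswith(s) for s in group):
            return f"UNKNOWN_AFFIXED_WITH_{group[0].upper()}"
    return "UNKNOWN"

# Invented unseen words, not from the homework data:
print(get_unknown_word_class_by_suffix("frobbing"))   # UNKNOWN_AFFIXED_WITH_ING
print(get_unknown_word_class_by_suffix("glorbness"))  # UNKNOWN_AFFIXED_WITH_MENT
print(get_unknown_word_class_by_suffix("zarp"))       # UNKNOWN

The first matching suffix wins, so order in the table matters; 'glorbness' lands in the MENT class because 'ness' sits in the ['ment', 'ness'] group.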
70 changes: 41 additions & 29 deletions src/exp.ipynb
@@ -2,20 +2,20 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
     "import json\n",
     "from argparse import Namespace\n",
     "from pathlib import Path\n",
     "\n",
-    "from typing import Dict, List, Tuple, Set"
+    "from typing import Dict, List, Tuple, Set, Union"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
@@ -24,7 +24,7 @@
        "PosixPath('/Users/ziyang/Projects/aca/nyu-nlp-homework-3/src/..')"
       ]
      },
-     "execution_count": 10,
+     "execution_count": 2,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -36,7 +36,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 49,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -48,7 +48,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -57,7 +57,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -75,12 +75,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 40,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
+    "word_count: Dict[str, int] = {}\n",
     "word_tag_count: Dict[Tuple[str, str], int] = {}\n",
-    "word_set: Set[str] = set()\n",
     "tag_count: Dict[str, int] = {}\n",
     "tag_tag_count: Dict[Tuple[str, str], int] = {}\n",
     "\n",
@@ -100,7 +100,7 @@
     "        tag_count[tag] = tag_count.get(tag, 0) + 1\n",
     "        tag_tag_count[(last_tag, tag)] = tag_tag_count.get((last_tag, tag), 0) + 1\n",
     "        if word:\n",
-    "            word_set.add(word)\n",
+    "            word_count[word] = word_count.get(word, 0) + 1\n",
     "            word_tag_count[(word, tag)] = word_tag_count.get((word, tag), 0) + 1\n",
     "\n",
     "        if tag != 'E':\n",
@@ -111,27 +111,39 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 10,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "{('the', 'DT'): 2,\n",
-       " ('cat', 'NN'): 1,\n",
-       " ('sat', 'VBD'): 1,\n",
-       " ('on', 'IN'): 1,\n",
-       " ('mat', 'NN'): 1,\n",
-       " ('.', '.'): 1}"
-      ]
-     },
-     "execution_count": 23,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "word_tag_count"
+    "# Collapse words seen only once into unknown-word affix classes\n",
+    "affixes: List[Union[List[str], str]] = [['able', 'ible'], 'al', 'an', 'ar', 'ed', 'en', 'er', 'est', 'ful', 'ic', 'ing', 'ish', 'ive', 'less', 'ly', 'ment', 'ness', 'or', 'ous', 'y']\n",
+    "\n",
+    "new_word_count = word_count.copy()\n",
+    "\n",
+    "for word, count in word_count.items():\n",
+    "    if count > 1:\n",
+    "        continue\n",
+    "    word_class = \"UNKNOWN\"\n",
+    "    for affix in affixes:\n",
+    "        if isinstance(affix, list):\n",
+    "            word_affix_class = affix[0].upper()\n",
+    "            selected = False\n",
+    "            for affix_item in affix:\n",
+    "                if word.endswith(affix_item):\n",
+    "                    word_class = f\"UNKNOWN_AFFIXED_WITH_{word_affix_class}\"\n",
+    "                    selected = True\n",
+    "                    break\n",
+    "            if selected:\n",
+    "                break\n",
+    "        else:\n",
+    "            if word.endswith(affix):\n",
+    "                word_class = f\"UNKNOWN_AFFIXED_WITH_{affix.upper()}\"\n",
+    "                break\n",
+    "    new_word_count.pop(word)\n",
+    "    new_word_count[word_class] = new_word_count.get(word_class, 0) + 1\n",
+    "    for tag in tag_count.keys():\n",
+    "        original_word_tag_count = word_tag_count.pop((word, tag), 0)\n",
+    "        word_tag_count[(word_class, tag)] = word_tag_count.get((word_class, tag), 0) + original_word_tag_count\n"
    ]
   },
   {
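The new notebook cell above prototypes the training-side half of the change: every word seen exactly once is dropped from the vocabulary, and its word count, along with all of its (word, tag) counts, is folded into its affix class. A minimal sketch of that pass on hypothetical counts, with the suffix matching stubbed out for brevity:

# Hapax-collapsing pass from the cell above, on made-up counts.
word_count = {"the": 42, "frobbing": 1, "zarp": 1}
word_tag_count = {("the", "DT"): 42, ("frobbing", "VBG"): 1, ("zarp", "NN"): 1}
tags = ["DT", "VBG", "NN"]

def affix_class(word: str) -> str:
    # stand-in for the suffix matching shown above
    return "UNKNOWN_AFFIXED_WITH_ING" if word.endswith("ing") else "UNKNOWN"

new_word_count = word_count.copy()
for word, count in word_count.items():
    if count > 1:
        continue  # only hapaxes (count == 1) are collapsed
    word_class = affix_class(word)
    new_word_count.pop(word)
    new_word_count[word_class] = new_word_count.get(word_class, 0) + 1
    for tag in tags:
        moved = word_tag_count.pop((word, tag), 0)
        if moved:  # the committed loop also writes zero-count entries; skipped here
            word_tag_count[(word_class, tag)] = word_tag_count.get((word_class, tag), 0) + moved

print(new_word_count)   # {'the': 42, 'UNKNOWN_AFFIXED_WITH_ING': 1, 'UNKNOWN': 1}
print(word_tag_count)   # {('the', 'DT'): 42, ('UNKNOWN_AFFIXED_WITH_ING', 'VBG'): 1, ('UNKNOWN', 'NN'): 1}

Moving the (word, tag) mass means each class's emission probabilities are estimated from all the hapaxes sharing its suffix. Note the notebook's affix table is mostly flat (only 'able'/'ible' are grouped); src/main.py below merges more suffixes into shared classes.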
66 changes: 56 additions & 10 deletions src/main.py
@@ -9,7 +9,30 @@
 import argparse
 import pickle
 
-from typing import Dict, List, Tuple, Set, TypedDict
+from typing import Dict, List, Tuple, Set, Union, TypedDict
+
+suffixes: List[Union[List[str], str]] = [['able', 'ible'], 'al', 'an', 'ar', 'ed', 'en', ['er', 'or'],
+                                         'est', 'ing', ['ish', 'ous', 'ful', 'less'], 'ive', 'ly', ['ment', 'ness'], 'y']
+
+
+def get_unknown_word_class_by_suffix(word: str) -> str:
+    word_class = "UNKNOWN"
+    for suffix in suffixes:
+        if isinstance(suffix, list):
+            word_suffix_class = suffix[0].upper()
+            selected = False
+            for suffix_item in suffix:
+                if word.endswith(suffix_item):
+                    word_class = f"UNKNOWN_AFFIXED_WITH_{word_suffix_class}"
+                    selected = True
+                    break
+            if selected:
+                break
+        else:
+            if word.endswith(suffix):
+                word_class = f"UNKNOWN_AFFIXED_WITH_{suffix.upper()}"
+                break
+    return word_class
 
 
 class TrainedStates(TypedDict):
Expand Down Expand Up @@ -46,7 +69,11 @@ def tag(self, sentence) -> List[str]:
if (word, tag) in emission_tb:
emission_prob = emission_tb[(word, tag)]
else:
emission_prob = 1 / 1000
word_class = get_unknown_word_class_by_suffix(word)
if (word_class, tag) in emission_tb:
emission_prob = emission_tb[(word_class, tag)]
else:
emission_prob = 1 / 1000
prob = last_tag_prob * trans_prob * emission_prob
if prob > max_prob:
max_last_tag = last_tag
Expand Down Expand Up @@ -83,15 +110,15 @@ def tag(self, sentence) -> List[str]:
class BasicStatistics(TypedDict):
word_tag_count: Dict[str, Dict[str, int]]
tag_set: Set[str]
word_set: Set[str]
word_count: Dict[str, int]
tag_count: Dict[str, int]
tag_tag_count: Dict[str, Dict[str, int]]


def train_get_statistics(lines: List[str]) -> 'BasicStatistics':
word_tag_count: Dict[Tuple[str, str], int] = {}
tag_set: Set[str] = set(['B', 'E'])
word_set: Set[str] = set()
word_count: Dict[str, int] = {}
tag_count: Dict[str, int] = {}
tag_tag_count: Dict[Tuple[str, str], int] = {}

@@ -112,7 +139,7 @@ def train_get_statistics(lines: List[str]) -> 'BasicStatistics':
             tag_tag_count[(last_tag, tag)] = tag_tag_count.get(
                 (last_tag, tag), 0) + 1
             if word:
-                word_set.add(word)
+                word_count[word] = word_count.get(word, 0) + 1
                 tag_set.add(tag)
                 word_tag_count[(word, tag)] = word_tag_count.get(
                     (word, tag), 0) + 1
@@ -121,7 +148,25 @@ def train_get_statistics(lines: List[str]) -> 'BasicStatistics':
                 last_tag = tag
             else:
                 last_tag = 'B'
-    return BasicStatistics(word_tag_count=word_tag_count, tag_set=tag_set, word_set=word_set, tag_count=tag_count, tag_tag_count=tag_tag_count)
+    return BasicStatistics(word_tag_count=word_tag_count, tag_set=tag_set, word_count=word_count, tag_count=tag_count, tag_tag_count=tag_tag_count)
+
+
+def train_unknownify_statistics(statistics: BasicStatistics) -> BasicStatistics:
+    word_tag_count = statistics["word_tag_count"]
+    word_count = statistics["word_count"].copy()
+
+    for word, count in statistics["word_count"].items():
+        if count > 1:
+            continue
+        word_class = get_unknown_word_class_by_suffix(word)
+        word_count.pop(word)
+        word_count[word_class] = word_count.get(word_class, 0) + 1
+        for tag in statistics["tag_count"].keys():
+            original_word_tag_count = word_tag_count.pop((word, tag), 0)
+            word_tag_count[(word_class, tag)] = word_tag_count.get(
+                (word_class, tag), 0) + original_word_tag_count
+
+    return BasicStatistics(word_tag_count=word_tag_count, tag_set=statistics["tag_set"], word_count=word_count, tag_count=statistics["tag_count"], tag_tag_count=statistics["tag_tag_count"])
 
 
 def train_get_trans_prob(tag_count: Dict[str, int], tag_tag_count: Dict[Tuple[str, str], int]) -> Dict[Tuple[str, str], float]:
@@ -138,9 +183,9 @@ def train_get_trans_prob(tag_count: Dict[str, int], tag_tag_count: Dict[Tuple[str, str], int]) -> Dict[Tuple[str, str], float]:
     return trans_prob
 
 
-def train_get_emission_prob(word_set: Set[str], tag_count: Dict[str, int], word_tag_count: Dict[Tuple[str, str], int]) -> Dict[Tuple[str, str], float]:
+def train_get_emission_prob(word_count: Dict[str, int], tag_count: Dict[str, int], word_tag_count: Dict[Tuple[str, str], int]) -> Dict[Tuple[str, str], float]:
     emission_prob: Dict[Tuple[str, str], float] = {}  # (word, tag) -> prob
-    for word in word_set:
+    for word in word_count.keys():
         for tag in tag_count.keys():
             word_tag = word_tag_count.get((word, tag), 0)
             tag_total = tag_count.get(tag, 0)
@@ -156,14 +201,15 @@ def train(inputfile: str, statefile: str) -> None:
     with open(inputfile, 'r') as f:
         lines = f.readlines()
     statistics = train_get_statistics(lines)
+    statistics = train_unknownify_statistics(statistics)
     trans_prob = train_get_trans_prob(
         statistics["tag_count"], statistics["tag_tag_count"])
     emission_prob = train_get_emission_prob(
-        statistics["word_set"], statistics["tag_count"], statistics["word_tag_count"])
+        statistics["word_count"], statistics["tag_count"], statistics["word_tag_count"])
     states = TrainedStates(
         trans_prob=trans_prob, emission_prob=emission_prob, tags=statistics["tag_set"])
     print(
-        f"{len(statistics['word_set'])} distinct words, {len(statistics['tag_count'])} tags")
+        f"{len(statistics['word_count'])} distinct words, {len(statistics['tag_count'])} tags")
     print(f"{len(trans_prob)} trans_prob pairs, {len(emission_prob)} emission pairs.")
     with open(statefile, "wb") as f:
         pickle.dump(states, f)
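Taken together, the two changes give the emission lookup a three-level backoff at tagging time: exact (word, tag), then (affix class, tag), then the old flat 1/1000 default. A toy illustration with made-up probabilities, assuming the get_unknown_word_class_by_suffix sketch from the top of this page is in scope:

# Toy emission table; numbers are invented, not from the trained model.
emission_tb = {
    ("the", "DT"): 0.6,                        # seen word
    ("UNKNOWN_AFFIXED_WITH_ING", "VBG"): 0.2,  # class learned from hapaxes
}

def emission(word: str, tag: str) -> float:
    if (word, tag) in emission_tb:
        return emission_tb[(word, tag)]
    word_class = get_unknown_word_class_by_suffix(word)
    return emission_tb.get((word_class, tag), 1 / 1000)

print(emission("the", "DT"))        # 0.6   (exact hit)
print(emission("frobbing", "VBG"))  # 0.2   (backs off to UNKNOWN_AFFIXED_WITH_ING)
print(emission("zarp", "NN"))       # 0.001 (flat default)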
