Skip to content

Commit

Permalink
Refactor duplicate_id_map to its own class UniqueIdsManager
Browse files Browse the repository at this point in the history
Tests for UniqueIdsManager
Refactor Mailmerge.parts to be more flexible for handling different
categories of parts
Refactor settings with the new category of parts
  • Loading branch information
iulica committed Nov 9, 2023
1 parent a57da5c commit 22dc108
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 57 deletions.
109 changes: 62 additions & 47 deletions mailmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,15 @@
for tag in ATTACHMENT_TAGS
}

CONTENT_TYPES_PARTS = (
'application/vnd.openxmlformats-officedocument.wordprocessingml.document.main+xml',
'application/vnd.ms-word.document.macroEnabled.main+xml',
'application/vnd.openxmlformats-officedocument.wordprocessingml.header+xml',
'application/vnd.openxmlformats-officedocument.wordprocessingml.footer+xml',
'application/vnd.openxmlformats-officedocument.wordprocessingml.footnotes+xml',
'application/vnd.openxmlformats-officedocument.wordprocessingml.endnotes+xml',
)

CONTENT_TYPE_SETTINGS = 'application/vnd.openxmlformats-officedocument.wordprocessingml.settings+xml'
CONTENT_TYPES_PARTS = {
'application/vnd.openxmlformats-officedocument.wordprocessingml.document.main+xml':'main',
'application/vnd.ms-word.document.macroEnabled.main+xml':'main',
'application/vnd.openxmlformats-officedocument.wordprocessingml.header+xml':'rel_part',
'application/vnd.openxmlformats-officedocument.wordprocessingml.footer+xml':'rel_part',
'application/vnd.openxmlformats-officedocument.wordprocessingml.footnotes+xml':'notes',
'application/vnd.openxmlformats-officedocument.wordprocessingml.endnotes+xml':'notes',
'application/vnd.openxmlformats-officedocument.wordprocessingml.settings+xml':'settings'
}

VALID_SEPARATORS = {
'page_break', 'column_break', 'textWrapping_break',
Expand Down Expand Up @@ -359,6 +358,23 @@ class NextField(MergeField):
def fill_data(self, merge_data, row):
raise NextRecord()

class UniqueIdsManager(object):
""" handles different counters for various ids in the document """

def __init__(self):
self.id_type_map = {} # type of id -> {'max': max_id, 'ids': set(existing_ids)}

def register_id(self, id_type, obj_id=None):
""" registers an new object id or creates a new id for the type """
type_id_value = self.id_type_map.setdefault(id_type, {"max": 0, "ids": set()})
new_obj_id = None
if obj_id is None or obj_id in type_id_value['ids']:
obj_id = type_id_value['max'] + 1
new_obj_id = obj_id
type_id_value['ids'].add(obj_id)
type_id_value['max'] = max(type_id_value['max'], obj_id)
return new_obj_id

class MergeData(object):

""" prepare the MergeField objects and the data """
Expand All @@ -371,7 +387,7 @@ class MergeData(object):
def __init__(self, remove_empty_tables=False, keep_fields="none"):
self._merge_field_map = {} # merge_field.key: MergeField()
self._merge_field_next_id = 0
self.duplicate_id_map = {} # tag: {'max': max_id, 'ids': set(existing_ids)}
self.unique_id_manager = UniqueIdsManager()
self.has_nested_fields = False
self.remove_empty_tables = remove_empty_tables
self.keep_fields = keep_fields
Expand Down Expand Up @@ -400,21 +416,14 @@ def is_first(self):

def get_new_element_id(self, element):
""" Returns None if the existing id is new otherwise a new id """
tag = element.tag
id = element.get('id')
if id is None:
# tag = element.tag
elem_id = element.get('id')
if elem_id is None:
return None
id = int(id)
id_data = self.duplicate_id_map.setdefault(tag, {'max': id, 'values': set()})
if id in id_data['values']:
# it already exists
id = id_data['max'] + 1
id_data['values'].add(id)
id_data['max'] = id
return str(id)

id_data['values'].add(id)
id_data['max'] = max(id, id_data['max'])
elem_id = int(elem_id)
new_id = self.unique_id_manager.register_id("id", elem_id)
if new_id:
return str(new_id)
return None

def get_merge_fields(self, key):
Expand Down Expand Up @@ -707,9 +716,8 @@ def __init__(self, file, remove_empty_tables=False, auto_update_fields_on_open="
keep_fields : none - merge all fields even if no data, some - keep fields with no data, all - keep all fields
"""
self.zip = ZipFile(file)
self.parts = {} # part: ElementTree
self.settings = None
self._settings_info = None
self.parts = {} # zi_part: ElementTree
self.categories = {} # category: [zi, ...]
self.merge_data = MergeData(remove_empty_tables=remove_empty_tables, keep_fields=keep_fields)
self.remove_empty_tables = remove_empty_tables
self.auto_update_fields_on_open = auto_update_fields_on_open
Expand All @@ -719,14 +727,29 @@ def __init__(self, file, remove_empty_tables=False, auto_update_fields_on_open="
try:
self.__fill_parts()

for part in self.parts.values():
for part in self.get_parts().values():
self.__fill_simple_fields(part)
self.__fill_complex_fields(part)

except:
self.zip.close()
raise

def get_parts(self, categories=None):
""" return all the parts based on categories """
if categories is None:
categories = ["main", "rel_part", "notes"]
elif isinstance(categories, str):
categories = [categories]
return {
zi: self.parts[zi]
for category in categories
for zi in self.categories.get(category, [])
}

def get_settings(self):
""" returns the settings part """
return list(self.get_parts(['settings']).values())[0]

def __setattr__(self, __name, __value):
super(MailMerge, self).__setattr__(__name, __value)
Expand All @@ -736,11 +759,11 @@ def __setattr__(self, __name, __value):
def __fill_parts(self):
content_types = etree.parse(self.zip.open('[Content_Types].xml'))
for file in content_types.findall('{%(ct)s}Override' % NAMESPACES):
type = file.attrib['ContentType' % NAMESPACES]
if type in CONTENT_TYPES_PARTS:
part_type = file.attrib['ContentType' % NAMESPACES]
category = CONTENT_TYPES_PARTS.get(part_type)
if category:
zi, self.parts[zi] = self.__get_tree_of_file(file)
elif type == CONTENT_TYPE_SETTINGS:
self._settings_info, self.settings = self.__get_tree_of_file(file)
self.categories.setdefault(category, []).append(zi)

def __fill_simple_fields(self, part):
for fld_simple_elem in part.findall('.//{%(w)s}fldSimple' % NAMESPACES):
Expand Down Expand Up @@ -842,8 +865,8 @@ def __fill_complex_fields(self, part):

def __fix_settings(self):

if self.settings:
settings_root = self.settings.getroot()
for settings in self.get_parts(categories=['settings']).values():
settings_root = settings.getroot()
if not self._has_unmerged_fields:
mail_merge = settings_root.find('{%(w)s}mailMerge' % NAMESPACES)
if mail_merge is not None:
Expand Down Expand Up @@ -888,15 +911,12 @@ def write(self, file, empty_value=''):
if zi in self.parts:
xml = etree.tostring(self.parts[zi].getroot(), encoding='UTF-8', xml_declaration=True)
output.writestr(zi.filename, xml)
elif zi == self._settings_info:
xml = etree.tostring(self.settings.getroot(), encoding='UTF-8', xml_declaration=True)
output.writestr(zi.filename, xml)
else:
output.writestr(zi.filename, self.zip.read(zi))

def get_merge_fields(self, parts=None):
if not parts:
parts = self.parts.values()
parts = self.get_parts().values()

fields = set()
for part in parts:
Expand Down Expand Up @@ -926,13 +946,8 @@ def merge_templates(self, replacements, separator):
# Duplicate template. Creates a copy of the template, does a merge, and separates them by a new paragraph, a new break or a new section break.

#GET ROOT - WORK WITH DOCUMENT
for part in self.parts.values():
for part in self.get_parts(["main"]).values():
root = part.getroot()
tag = root.tag

# ignore header, footer, footnotes, endnotes, etc
if tag in ATTACHMENT_TAGS_WITH_NAMESPACE:
continue

# the mailmerge is done with the help of the MergeDocument class
# that handles the document duplication
Expand Down Expand Up @@ -964,13 +979,13 @@ def merge(self, **replacements):
self._merge(replacements)

def _merge(self, replacements):
for part in self.parts.values():
for part in self.get_parts().values():
self.merge_data.replace(part, replacements)

def merge_rows(self, anchor, rows):
""" anchor is one of the fields in the table """

for part in self.parts.values():
for part in self.get_parts().values():
self.merge_data.replace_table_rows(part, anchor, rows)

def __enter__(self):
Expand Down
Binary file modified tests/test_footnote_header_footer.docx
Binary file not shown.
6 changes: 3 additions & 3 deletions tests/test_keep_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_keep_fields_all(self):
)

self.assertListEqual(
document.settings.getroot().xpath(MERGE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
document.get_settings().getroot().xpath(MERGE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
[])
self.assertListEqual(
root_elem.xpath(TEXTS_XPATH, namespaces=NAMESPACES),
Expand All @@ -54,7 +54,7 @@ def test_keep_fields_some(self):
)

self.assertListEqual(
document.settings.getroot().xpath(MERGE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
document.get_settings().getroot().xpath(MERGE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
[])
self.assertListEqual(
root_elem.xpath(TEXTS_XPATH, namespaces=NAMESPACES),
Expand All @@ -77,7 +77,7 @@ def test_keep_fields_none(self):
)

self.assertListEqual(
document.settings.getroot().xpath(MERGE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
document.get_settings().getroot().xpath(MERGE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
[])
self.assertListEqual(
root_elem.xpath(TEXTS_XPATH, namespaces=NAMESPACES),
Expand Down
10 changes: 5 additions & 5 deletions tests/test_nested_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_outside_auto_update_fields(self):
# output="tests/output/test_output_nested_if_outside.docx"
)
self.assertListEqual(
document.settings.getroot().xpath(UPDATE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
document.get_settings().getroot().xpath(UPDATE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
[])

document, root_elem = self.merge(
Expand All @@ -55,7 +55,7 @@ def test_outside_auto_update_fields(self):
# output="tests/output/test_output_nested_if_outside.docx"
)
self.assertListEqual(
document.settings.getroot().xpath(UPDATE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
document.get_settings().getroot().xpath(UPDATE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
[])

document, root_elem = self.merge(
Expand All @@ -65,7 +65,7 @@ def test_outside_auto_update_fields(self):
# output="tests/output/test_output_nested_if_outside.docx"
)
self.assertListEqual(
document.settings.getroot().xpath(UPDATE_FIELDS_XPATH, namespaces=NAMESPACES),
document.get_settings().getroot().xpath(UPDATE_FIELDS_XPATH, namespaces=NAMESPACES),
["true"])

def test_field_inside(self):
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_inside_auto_update_fields(self):
# output="tests/output/test_output_nested_if_inside.docx"
)
self.assertListEqual(
document.settings.getroot().xpath(UPDATE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
document.get_settings().getroot().xpath(UPDATE_FIELDS_TRUE_XPATH, namespaces=NAMESPACES),
[])

document, root_elem = self.merge(
Expand All @@ -133,5 +133,5 @@ def test_inside_auto_update_fields(self):
# output="tests/output/test_output_nested_if_inside.docx"
)
self.assertListEqual(
document.settings.getroot().xpath(UPDATE_FIELDS_XPATH, namespaces=NAMESPACES),
document.get_settings().getroot().xpath(UPDATE_FIELDS_XPATH, namespaces=NAMESPACES),
["true"])
24 changes: 24 additions & 0 deletions tests/test_unique_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import unittest

from mailmerge import UniqueIdsManager

class UniqueIdsManagerTest(unittest.TestCase):
"""
Testing UniqueIdsManager class
"""

def test_unique_id_manager(self):
"""
Tests if the next record field works
"""
tests = [
("id", 2, None),
("id", 2, 3),
("id", None, 4),
("footer", 1, None),
("footer", None, 2)
]
id_man = UniqueIdsManager()

for type_id, obj_id, new_id in tests:
self.assertEqual(id_man.register_id(type_id, obj_id=obj_id), new_id)
2 changes: 1 addition & 1 deletion tests/test_winword2010.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,4 @@ def test(self):
)

self.assert_equal_tree(expected_tree, list(document.parts.values())[0].getroot())
self.assertIsNone(document.settings.getroot().find('{%(w)s}mailMerge' % NAMESPACES))
self.assertIsNone(document.get_settings().getroot().find('{%(w)s}mailMerge' % NAMESPACES))
1 change: 0 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from lxml import etree
from mailmerge import MailMerge, NAMESPACES, CONTENT_TYPES_PARTS

CONTENT_TYPE_DOCUMENT = CONTENT_TYPES_PARTS[0]
TEXTS_XPATH = "//w:t/text()"

class EtreeMixin(object):
Expand Down

0 comments on commit 22dc108

Please sign in to comment.