diff --git a/mailmerge.py b/mailmerge.py
index bbc1bc9..74712a1 100644
--- a/mailmerge.py
+++ b/mailmerge.py
@@ -1,6 +1,7 @@
 import warnings
 import shlex
 import re
+import os
 import datetime
 # import locale
 from zipfile import ZipFile, ZIP_DEFLATED
@@ -15,6 +16,8 @@
     'wp': 'http://schemas.openxmlformats.org/drawingml/2006/wordprocessingDrawing',
     'mc': 'http://schemas.openxmlformats.org/markup-compatibility/2006',
     'ct': 'http://schemas.openxmlformats.org/package/2006/content-types',
+    'rr': 'http://schemas.openxmlformats.org/package/2006/relationships',
+    'r': 'http://schemas.openxmlformats.org/officeDocument/2006/relationships',
     'xml': 'http://www.w3.org/XML/1998/namespace'
 }

@@ -33,8 +36,8 @@
 CONTENT_TYPES_PARTS = {
     'application/vnd.openxmlformats-officedocument.wordprocessingml.document.main+xml':'main',
     'application/vnd.ms-word.document.macroEnabled.main+xml':'main',
-    'application/vnd.openxmlformats-officedocument.wordprocessingml.header+xml':'rel_part',
-    'application/vnd.openxmlformats-officedocument.wordprocessingml.footer+xml':'rel_part',
+    'application/vnd.openxmlformats-officedocument.wordprocessingml.header+xml':'header_footer',
+    'application/vnd.openxmlformats-officedocument.wordprocessingml.footer+xml':'header_footer',
     'application/vnd.openxmlformats-officedocument.wordprocessingml.footnotes+xml':'notes',
     'application/vnd.openxmlformats-officedocument.wordprocessingml.endnotes+xml':'notes',
     'application/vnd.openxmlformats-officedocument.wordprocessingml.settings+xml':'settings'
@@ -44,7 +47,9 @@
     'page_break', 'column_break', 'textWrapping_break',
     'continuous_section', 'evenPage_section', 'nextColumn_section',
     'nextPage_section', 'oddPage_section'}
-NUMBERFORMAT_RE = r"([^0.,'#PN]+)?(P\d+|N\d+|[0.,'#]+%?)([^0.,'#%].*)?"
+PARTFILENAME_RE = re.compile(r'([A-Za-z_]+)(\d+)\.xml')
+IDSTR_RE = re.compile(r'([A-Za-z_]+)(\d+)')
+NUMBERFORMAT_RE = re.compile(r"([^0.,'#PN]+)?(P\d+|N\d+|[0.,'#]+%?)([^0.,'#%].*)?")
 DATEFORMAT_RE = "|".join([r"{}+".format(switch) for switch in "yYmMdDhHsS"] + [r"am/pm", r"AM/PM"])
 DATEFORMAT_MAP = {
     "M": "{d.month}",
@@ -183,7 +188,7 @@ def _format_text(self, value, flag, option):
         return value

     def _format_number(self, value, flag, option):
-        format_match = re.match(NUMBERFORMAT_RE, option)
+        format_match = NUMBERFORMAT_RE.match(option)
         if value is None:
             value = 0
         if format_match is None:
@@ -373,8 +378,21 @@ def register_id(self, id_type, obj_id=None):
             new_obj_id = obj_id
             type_id_value['ids'].add(obj_id)
             type_id_value['max'] = max(type_id_value['max'], obj_id)
+        # print("registered", id_type, obj_id, new_obj_id, "max", type_id_value['max'])
         return new_obj_id


+    def register_id_str(self, id_str):
+        """ registers an id given as a single string like 'type1231', where the id type precedes the numeric id """
+        match = IDSTR_RE.match(id_str)
+        assert match
+        id_type, obj_id = match.groups()
+        new_obj_id = self.register_id(id_type, obj_id=int(obj_id))
+        if new_obj_id is not None:
+            # print(id_type, obj_id, new_obj_id, self.id_type_map[id_type])
+            return "%s%d" % (id_type, new_obj_id)
+
+
+
 class MergeData(object):
     """ prepare the MergeField objects and the data """
@@ -580,6 +598,105 @@ def get_field_object(self, field_element, row):
             field_obj = self._merge_field_map[field_key]
         return field_obj

+    def fix_id(self, element, attr_gen):
+        """ replaces an element id with a new unique id """
+        new_id = self.get_new_element_id(element)
+        if new_id is not None:
+            element.attrib['id'] = new_id
+            for attr_name, attr_value in attr_gen.items():
+                element.attrib[attr_name] = attr_value.format(id=new_id)
+
+    def fix_ids(self, current_body):
+        """ fixes all ids in the current body """
+        for tag, attr_gen in TAGS_WITH_ID.items():
+            for elem in current_body.xpath('//{}'.format(tag), namespaces=NAMESPACES):
+                self.fix_id(elem, attr_gen)
+
+
+class RelationsDocument(object):
+    """ handles a relationships (.rels) part """
+
+    def __init__(self, rel_part):
+        self.rel_part = rel_part
+
+    def replace_relation(self, merge_data, old_relation_elem, new_target):
+        root = self.rel_part.getroot()
+        new_relation = deepcopy(old_relation_elem)
+        # print(etree.tostring(new_relation))
+        new_relation.attrib['Id'] = merge_data.unique_id_manager.register_id_str(new_relation.attrib['Id'])
+        # print(old_relation_elem.attrib['Id'], "->", new_relation.attrib['Id'])
+        new_relation.attrib['Target'] = new_target
+        root.append(new_relation)
+        return new_relation.attrib['Id']
+
+    def get_relation_elem(self, target):
+        """ returns the relation element for the given target """
+        return self.rel_part.getroot().find('rr:Relationship[@Target="%s"]' % target, namespaces=NAMESPACES)
+
+    def get_all(self):
+        """ returns all relation elements """
+        return self.rel_part.getroot().xpath('rr:Relationship', namespaces=NAMESPACES)
+
+
+class MergeHeaderFooterDocument(object):
+    """ prepare and merge one header/footer document for merge_templates
+
+    Helper class that handles the actual merging of one header/footer part.
+    Headers and footers are referenced through relations, so for every merged
+    record a copy of the part has to be created and the relations updated.
+    """
+
+    def __init__(self, part_info, relations, separator):
+        self.part_content_type = part_info['file']
+        self.zi = part_info['zi']
+        self.part = part_info['part']
+        self.relations = relations
+        self.sep_type = None
+        self.target, self.id_type, self.part_id = self._parse_part_filename(self.zi.filename)
+        self.new_parts = []  # list of (filename, content_type, part) tuples
+        self._current_part = None
+        self.has_fields = bool(self.part.findall('.//MergeField'))
+        self._prepare_data(separator)
+
+    def _parse_part_filename(self, filename):
+        filename = os.path.basename(filename)
+        match = PARTFILENAME_RE.match(filename)
+        assert match
+        return filename, *match.groups()
+
+    def _prepare_data(self, separator):
+        if separator not in VALID_SEPARATORS:
+            raise ValueError("Invalid separator argument")
+        self.sep_type, _sep_class = separator.split("_")
+
+    def prepare(self, merge_data, first=False):
+        if self.has_fields:
+            self._current_part = deepcopy(self.part)
+
+    def merge(self, merge_data, row, first=False):
+        """ Merges one row into the current prepared part """
+        if self.has_fields:
+            assert self._current_part is not None
+            merge_data.replace(self._current_part.getroot(), row)
+
+    def finish(self, merge_data, abort=False):
+        """ finishes the current merge by registering the new part copy and updating the relations """
+
+        if abort:  # for skipping the record
+            self._current_part = None
+
+        if self._current_part is not None:
+            # @TODO use the existing headers/footers for the first section
+            new_id = merge_data.unique_id_manager.register_id(self.id_type)
+            new_target = self.target.replace(self.part_id, str(new_id))
+            new_filename = self.zi.filename.replace(self.part_id, str(new_id))
+            new_part_content_type = deepcopy(self.part_content_type)
+            new_part_content_type.attrib['PartName'] = self.part_content_type.attrib['PartName'].replace(self.target, new_target)
+            self.new_parts.append((new_filename, new_part_content_type, self._current_part))
+            self._current_part = None
+            return [(self.target, new_target)]
+
+        return []
+
+
 class MergeDocument(object):
""" prepare and merge one document @@ -590,8 +707,10 @@ class MergeDocument(object): It prepares the body, sections, separators """ - def __init__(self, root, separator): + def __init__(self, merge_data, root, relations, separator): + self.merge_data = merge_data self.root = root + self.relations = relations # self.sep_type = sep_type # self.sep_class = sep_class # if sep_class == 'section': @@ -601,6 +720,8 @@ def __init__(self, root, separator): self._body = None # the document body, where all the documents are appended self._body_copy = None # a deep copy of the original body without ending section self._current_body = None # the current document body where all the changes are merged + self._current_separator = None + self._finish_rels = [] self._prepare_data(separator) def _prepare_data(self, separator): @@ -634,6 +755,7 @@ def _prepare_data(self, separator): self._body = self._last_section.getparent() self._body.remove(self._last_section) + self._last_section = deepcopy(self._last_section) # fix a bug self._body_copy = deepcopy(self._body) @@ -655,27 +777,36 @@ def prepare(self, merge_data, first=False): assert self._current_body is None # add separator if not the first document if not first: - self._body.append(deepcopy(self._separator)) + # @TODO replace the relation references in the full body, not only in the + # separator + # @TODO refactor the whole preparation process, so it is straightforward + # and doesn't look like a hack + for old_target, new_target in self._finish_rels: + self.replace_relation_reference(merge_data, old_target, new_target) + self._body.append(self._current_separator) + self._current_separator = deepcopy(self._separator) self._current_body = deepcopy(self._body_copy) - for tag, attr_gen in TAGS_WITH_ID.items(): - for elem in self._current_body.xpath('//{}'.format(tag), namespaces=NAMESPACES): - self._fix_id(merge_data, elem, attr_gen) + merge_data.fix_ids(self._current_body) def merge(self, merge_data, row, first=False): """ Merges one row into the current prepared body """ merge_data.replace(self._current_body, row) - def _fix_id(self, merge_data, element, attr_gen): - new_id = merge_data.get_new_element_id(element) - if new_id is not None: - element.attrib['id'] = new_id - for attr_name, attr_value in attr_gen.items(): - element.attrib[attr_name] = attr_value.format(id=new_id) + def replace_relation_reference(self, merge_data, old_target, new_target, sep=None): + # assert self._current_body is not None + if sep is None: + sep = self._current_separator + + old_relation = self.relations.get_relation_elem(old_target) + new_rel_id = self.relations.replace_relation(merge_data, old_relation, new_target) - def finish(self, abort=False): - """ finishes the current body by saving it into the main body or into a file (future feature) """ + for elem in sep.xpath('//*[@r:id="%s"]' % old_relation.attrib['Id'], namespaces=NAMESPACES): + elem.attrib['{%(r)s}id' % NAMESPACES] = new_rel_id + def finish(self, finish_rels, abort=False): + """ finishes the current body by saving it into the main body or into a file (future feature) """ + self._finish_rels = finish_rels if abort: # for skipping the record self._current_body = None @@ -689,7 +820,9 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: - self.finish(True) + # self.finish(True) + for old_target, new_target in self._finish_rels: + self.replace_relation_reference(self.merge_data, old_target, new_target, sep=self._last_section) self._body.append(self._last_section) class 
MailMerge(object): @@ -717,6 +850,7 @@ def __init__(self, file, remove_empty_tables=False, auto_update_fields_on_open=" """ self.zip = ZipFile(file) self.parts = {} # zi_part: ElementTree + self.new_parts = [] # list of [(filename, part)] self.categories = {} # category: [zi, ...] self.merge_data = MergeData(remove_empty_tables=remove_empty_tables, keep_fields=keep_fields) self.remove_empty_tables = remove_empty_tables @@ -727,9 +861,9 @@ def __init__(self, file, remove_empty_tables=False, auto_update_fields_on_open=" try: self.__fill_parts() - for part in self.get_parts().values(): - self.__fill_simple_fields(part) - self.__fill_complex_fields(part) + for part_info in self.get_parts(): + self.__fill_simple_fields(part_info['part']) + self.__fill_complex_fields(part_info['part']) except: self.zip.close() @@ -738,18 +872,36 @@ def __init__(self, file, remove_empty_tables=False, auto_update_fields_on_open=" def get_parts(self, categories=None): """ return all the parts based on categories """ if categories is None: - categories = ["main", "rel_part", "notes"] + categories = ["main", "header_footer", "notes"] elif isinstance(categories, str): categories = [categories] - return { - zi: self.parts[zi] + return [ + self.parts[zi] for category in categories for zi in self.categories.get(category, []) - } + ] def get_settings(self): """ returns the settings part """ - return list(self.get_parts(['settings']).values())[0] + return self.parts[self.categories['settings'][0]]['part'] + + def get_content_types(self): + """" returns the content types part """ + return self.parts[self.categories['content_types'][0]]['part'] + + def get_relations(self, part_zi): + """ returns the """ + rel_fn = "word/_rels/%s.rels" % os.path.basename(part_zi.filename) + if rel_fn in self.zip.namelist(): + zi = self.zip.getinfo(rel_fn) + rel_root = etree.parse(self.zip.open(zi)) + self.parts[zi] = dict(zi=zi, part=rel_root) + relations = RelationsDocument(rel_root) + for relation in relations.get_all(): + self.merge_data.unique_id_manager.register_id_str(relation.attrib['Id']) + return relations + # else: + # print(rel_fn, self.zip.namelist()) def __setattr__(self, __name, __value): super(MailMerge, self).__setattr__(__name, __value) @@ -757,7 +909,10 @@ def __setattr__(self, __name, __value): self.merge_data.remove_empty_tables = __value def __fill_parts(self): - content_types = etree.parse(self.zip.open('[Content_Types].xml')) + content_types_zi = self.zip.getinfo('[Content_Types].xml') + content_types = etree.parse(self.zip.open(content_types_zi)) + self.categories['content_types'] = [content_types_zi] + self.parts[content_types_zi] = dict(part=content_types) for file in content_types.findall('{%(ct)s}Override' % NAMESPACES): part_type = file.attrib['ContentType' % NAMESPACES] category = CONTENT_TYPES_PARTS.get(part_type) @@ -865,7 +1020,8 @@ def __fill_complex_fields(self, part): def __fix_settings(self): - for settings in self.get_parts(categories=['settings']).values(): + settings = self.get_settings() + if settings: settings_root = settings.getroot() if not self._has_unmerged_fields: mail_merge = settings_root.find('{%(w)s}mailMerge' % NAMESPACES) @@ -885,7 +1041,7 @@ def __fix_settings(self): def __get_tree_of_file(self, file): fn = file.attrib['PartName' % NAMESPACES].split('/', 1)[1] zi = self.zip.getinfo(fn) - return zi, etree.parse(self.zip.open(zi)) + return zi, dict(zi=zi, file=file, part=etree.parse(self.zip.open(zi))) def write(self, file, empty_value=''): self._has_unmerged_fields = 
bool(self.get_merge_fields()) @@ -906,21 +1062,34 @@ def write(self, file, empty_value=''): # Remove mail merge settings to avoid error messages when opening document in Winword self.__fix_settings() + # add the new files in the content types + content_types = self.get_content_types().getroot() + for _filename, part_content_type, _part in self.new_parts: + content_types.append(part_content_type) + with ZipFile(file, 'w', ZIP_DEFLATED) as output: for zi in self.zip.filelist: if zi in self.parts: - xml = etree.tostring(self.parts[zi].getroot(), encoding='UTF-8', xml_declaration=True) + xml = etree.tostring(self.parts[zi]['part'].getroot(), encoding='UTF-8', xml_declaration=True) output.writestr(zi.filename, xml) else: output.writestr(zi.filename, self.zip.read(zi)) - def get_merge_fields(self, parts=None): + for filename, _part_content_type, part in self.new_parts: + xml = etree.tostring(part.getroot(), encoding='UTF-8', xml_declaration=True) + output.writestr(filename, xml) + + def get_merge_fields(self): + """" get the fields from the document """ + return self._get_merge_fields() + + def _get_merge_fields(self, parts=None): if not parts: - parts = self.get_parts().values() + parts = self.get_parts() fields = set() for part in parts: - for mf in part.findall('.//MergeField'): + for mf in part['part'].findall('.//MergeField'): fields.add(mf.attrib['name']) # for name in self.merge_data.get_merge_fields(mf.attrib['merge_key']): # fields.add(name) @@ -943,24 +1112,45 @@ def merge_templates(self, replacements, separator): assert replacements, "empty data" #TYPE PARAM CONTROL AND SPLIT + # prepare the side documents, like headers, footers, etc + rel_docs = [] + for part_info in self.get_parts(["header_footer"]): + relations = self.get_relations(part_info['zi']) + merge_header_footer_doc = MergeHeaderFooterDocument(part_info, relations, separator) + rel_docs.append(merge_header_footer_doc) + self.merge_data.unique_id_manager.register_id(merge_header_footer_doc.id_type, int(merge_header_footer_doc.part_id)) + # Duplicate template. Creates a copy of the template, does a merge, and separates them by a new paragraph, a new break or a new section break. 
         #GET ROOT - WORK WITH DOCUMENT
-        for part in self.get_parts(["main"]).values():
-            root = part.getroot()
+        for part_info in self.get_parts(["main"]):
+            root = part_info['part'].getroot()
+            relations = self.get_relations(part_info['zi'])

             # the mailmerge is done with the help of the MergeDocument class
             # that handles the document duplication
-            with MergeDocument(root, separator) as merge_doc:
+            with MergeDocument(self.merge_data, root, relations, separator) as merge_doc:
                 row = self.merge_data.start_merge(replacements)
                 while row is not None:
                     merge_doc.prepare(self.merge_data, first=self.merge_data.is_first())
+
+                    finish_rels = []
+                    for rel_doc in rel_docs:
+                        rel_doc.prepare(self.merge_data, first=self.merge_data.is_first())
+                        rel_doc.merge(self.merge_data, row)
+                        finish_rels.extend(rel_doc.finish(self.merge_data))
+
                     try:
                         merge_doc.merge(self.merge_data, row)
-                        merge_doc.finish()
+                        merge_doc.finish(finish_rels)
                     except SkipRecord:
-                        merge_doc.finish(abort=True)
+                        merge_doc.finish(finish_rels, abort=True)
+
                     row = self.merge_data.next_row()

+        # add all new files in the zip
+        for rel_doc in rel_docs:
+            self.new_parts.extend(rel_doc.new_parts)

     def merge_pages(self, replacements):
         """
@@ -979,14 +1169,17 @@ def merge(self, **replacements):
         self._merge(replacements)

     def _merge(self, replacements):
-        for part in self.get_parts().values():
+        for part_info in self.get_parts():
+            self.merge_data.replace(part_info['part'], replacements)
+
+        for _filename, _part_content_type, part in self.new_parts:
             self.merge_data.replace(part, replacements)

     def merge_rows(self, anchor, rows):
         """ anchor is one of the fields in the table """
-        for part in self.get_parts().values():
-            self.merge_data.replace_table_rows(part, anchor, rows)
+        for part_info in self.get_parts():
+            self.merge_data.replace_table_rows(part_info['part'], anchor, rows)

     def __enter__(self):
         return self
diff --git a/tests/test_footer.docx b/tests/test_footer.docx
new file mode 100644
index 0000000..8c68964
Binary files /dev/null and b/tests/test_footer.docx differ
diff --git a/tests/test_footnote_header_footer.docx b/tests/test_footnote_header_footer.docx
index 72178f2..6805a52 100644
Binary files a/tests/test_footnote_header_footer.docx and b/tests/test_footnote_header_footer.docx differ
diff --git a/tests/test_footnote_header_footer.py b/tests/test_footnote_header_footer.py
index b67d2d0..bfb7bee 100644
--- a/tests/test_footnote_header_footer.py
+++ b/tests/test_footnote_header_footer.py
@@ -1,18 +1,14 @@
 import unittest
-import tempfile
-import warnings
-import datetime
-import locale
-from os import path
-from lxml import etree

 from mailmerge import NAMESPACES
-from tests.utils import EtreeMixin, get_document_body_part
+from tests.utils import EtreeMixin, get_document_body_part, TEXTS_XPATH, get_document_body_parts

 FOOTNOTE_XPATH = "//w:footnote[@w:id = '1']/w:p/w:r/w:t/text()"


 class FootnoteHeaderFooterTest(EtreeMixin, unittest.TestCase):
-
+    # @TODO test missing values
+    # @TODO test if separator isn't section
+    # @TODO test headers/footers with relations
     @unittest.expectedFailure
     def test_all(self):
@@ -20,12 +16,54 @@ def test_all(self):
         values = ["one", "two", "three"]
         # fix this when it is implemented
         document, root_elem = self.merge_templates(
             'test_footnote_header_footer.docx',
-            [{"fieldname": value, "footerfield": "f_" + value, "headerfield": "h_" + value}
+            [
+                {
+                    "fieldname": value,
+                    "footerfield": "f_" + value,
+                    "headerfield": "h_" + value,
+                    "footerfirst": "ff_" + value,
+                    "headerfirst": "hf_" + value,
+                    "footereven": "fe_" + value,
+                    "headereven": "he_" + value,
+                }
              for value in values
             ],
-            # output="tests/test_output_footnote_header_footer.docx"
+            separator="nextPage_section",
+            # output="tests/output/test_output_footnote_header_footer.docx"
         )
+
+        footers = sorted([
+            "".join(footer_doc_tree.getroot().xpath(TEXTS_XPATH, namespaces=NAMESPACES))
+            for footer_doc_tree in get_document_body_parts(document, endswith="ftr")
+        ])
+        self.assertListEqual(footers, sorted([
+            footer
+            for value in values
+            for footer in [
+                'Footer on even page fe_%s' % value,
+                'Footer on every page f_%s' % value,
+                'Footer on first page ff_%s' % value]] + [
+            'Footer on even page ',
+            'Footer on every page ',
+            'Footer on first page ']
+        ))
+
+        headers = sorted([
+            "".join(header_doc_tree.getroot().xpath(TEXTS_XPATH, namespaces=NAMESPACES))
+            for header_doc_tree in get_document_body_parts(document, endswith="hdr")
+        ])
+        self.assertListEqual(headers, sorted([
+            header
+            for value in values
+            for header in [
+                'Header even: he_%s' % value,
+                'Header on every page: h_%s' % value,
+                'Header on first page: hf_%s' % value]] + [
+            'Header even: ',
+            'Header on every page: ',
+            'Header on first page: ']
+        ))
+
         footnote_root_elem = get_document_body_part(document, "footnotes").getroot()
         footnote = "".join(footnote_root_elem.xpath(FOOTNOTE_XPATH, namespaces=NAMESPACES))
         correct_footnote = " Merge : one "
@@ -35,12 +73,57 @@ def test_only_merge(self):
         values = ["one", "two", "three"]
         document, root_elem = self.merge(
             'test_footnote_header_footer.docx',
-            [{"fieldname": value, "footerfield": "f_" + value, "headerfield": "h_" + value}
+            [
+                {
+                    "fieldname": value,
+                    "footerfield": "f_" + value,
+                    "headerfield": "h_" + value,
+                    "footerfirst": "ff_" + value,
+                    "headerfirst": "hf_" + value,
+                    "footereven": "fe_" + value,
+                    "headereven": "he_" + value,
+                }
              for value in values
             ][0],
-            # output="tests/test_output_one_footnote_header_footer.docx"
+            # output="tests/output/test_output_one_footnote_header_footer.docx"
         )

         footnote_root_elem = get_document_body_part(document, "footnotes").getroot()
         footnote = "".join(footnote_root_elem.xpath(FOOTNOTE_XPATH, namespaces=NAMESPACES))
         self.assertEqual(footnote, " Merge : one")
+
+        footers = sorted([
+            "".join(footer_doc_tree.getroot().xpath(TEXTS_XPATH, namespaces=NAMESPACES))
+            for footer_doc_tree in get_document_body_parts(document, endswith="ftr")
+        ])
+        value = values[0]
+        self.assertListEqual(footers, [
+            'Footer on even page fe_%s' % value,
+            'Footer on every page f_%s' % value,
+            'Footer on first page ff_%s' % value])
+
+        headers = sorted([
+            "".join(header_doc_tree.getroot().xpath(TEXTS_XPATH, namespaces=NAMESPACES))
+            for header_doc_tree in get_document_body_parts(document, endswith="hdr")
+        ])
+        self.assertListEqual(headers, [
+            'Header even: he_%s' % value,
+            'Header on every page: h_%s' % value,
+            'Header on first page: hf_%s' % value])
+
+
+    def test_footer(self):
+        values = ["one", "two"]
+        # header/footer/footnotes don't work with multiple replacements, only with merge
+        # fix this when it is implemented
+        document, root_elem = self.merge_templates(
+            'test_footer.docx',
+            [
+                {
+                    "footer": value,
+                }
+                for value in values
+            ],
+            separator="nextPage_section",
+            # output="tests/output/test_footer.docx"
+        )
diff --git a/tests/test_macword2011.py b/tests/test_macword2011.py
index 64e33c0..72409a1 100644
--- a/tests/test_macword2011.py
+++ b/tests/test_macword2011.py
@@ -4,7 +4,7 @@
 from lxml import etree

 from mailmerge import MailMerge
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part


 class MacWord2011Test(EtreeMixin, unittest.TestCase):
@@ -57,4 +57,4 @@ def test(self):
             '/>')

-        self.assert_equal_tree(expected_tree, list(document.parts.values())[0].getroot())
+        self.assert_equal_tree(expected_tree, get_document_body_part(document).getroot())
diff --git a/tests/test_merge_table_multipart.py b/tests/test_merge_table_multipart.py
index 14860b2..eabcbc0 100644
--- a/tests/test_merge_table_multipart.py
+++ b/tests/test_merge_table_multipart.py
@@ -4,7 +4,7 @@
 from lxml import etree

 from mailmerge import MailMerge
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part


 class MergeTableRowsMultipartTest(EtreeMixin, unittest.TestCase):
@@ -32,10 +32,7 @@ def test_merge_rows_on_multipart_file(self):
         with tempfile.TemporaryFile() as outfile:
             self.document.write(outfile)

-        for part in self.document.parts.values():
-            # only check the document part
-            if (part.getroot().tag == '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}document'):
-                self.assert_equal_tree(self.expected_tree, part.getroot())
+        self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())

     def test_merge_unified_on_multipart_file(self):
         self.document.merge(
@@ -52,10 +49,7 @@ def test_merge_unified_on_multipart_file(self):
         with tempfile.TemporaryFile() as outfile:
             self.document.write(outfile)

-        for part in self.document.parts.values():
-            # only check the document part
-            if (part.getroot().tag == '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}document'):
-                self.assert_equal_tree(self.expected_tree, part.getroot())
+        self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())

     def tearDown(self):
         self.document.close()
diff --git a/tests/test_merge_table_rows.py b/tests/test_merge_table_rows.py
index 399b8a9..575018e 100644
--- a/tests/test_merge_table_rows.py
+++ b/tests/test_merge_table_rows.py
@@ -4,7 +4,7 @@
 from lxml import etree

 from mailmerge import MailMerge, NAMESPACES
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part


 class MergeTableRowsTest(EtreeMixin, unittest.TestCase):
@@ -31,8 +31,7 @@ def test_merge_rows(self):
         with tempfile.TemporaryFile() as outfile:
             self.document.write(outfile)

-        self.assert_equal_tree(self.expected_tree,
-                               list(self.document.parts.values())[0].getroot())
+        self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())

     def test_merge_rows_no_table(self):
         """
@@ -56,8 +55,7 @@ def test_merge_rows_no_table(self):
         with tempfile.TemporaryFile() as outfile:
             self.document.write(outfile)

-        self.assert_equal_tree(self.expected_tree,
-                               list(self.document.parts.values())[0].getroot())
+        self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())

     def test_merge_rows_remove_table(self):
         """
@@ -74,7 +72,7 @@ def test_merge_rows_remove_table(self):
         with tempfile.TemporaryFile() as outfile:
             self.document.write(outfile)
         self.assertIsNone(
-            list(self.document.parts.values())[0].getroot().find('.//{%(w)s}tbl' % NAMESPACES)
+            get_document_body_part(self.document).getroot().find('.//{%(w)s}tbl' % NAMESPACES)
         )

     def test_merge_unified(self):
@@ -92,8 +90,7 @@ def test_merge_unified(self):
         with tempfile.TemporaryFile() as outfile:
             self.document.write(outfile)

-        self.assert_equal_tree(self.expected_tree,
-                               list(self.document.parts.values())[0].getroot())
+        self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())

     def test_merge_template_with_rows(self):
         """
@@ -125,8 +122,7 @@ def test_merge_template_with_rows(self):
         with tempfile.TemporaryFile() as outfile:
             self.document.write(outfile)
         expected_tree = etree.fromstring('GradesBouke Haarsma received the grades for in the table below.Class CodeClass NameGradeECON101Economics 101AECONADVEconomics AdvancedBOPRESOperations ResearchATHESISFinal thesisAGradesJon Snow received the grades for in the table below.Class CodeClass NameGradeECON101Economics 101EECONADVEconomics AdvancedFTHESISFinal thesisB')
-        self.assert_equal_tree(expected_tree,
-                               list(self.document.parts.values())[0].getroot())
+        self.assert_equal_tree(expected_tree, get_document_body_part(self.document).getroot())

     def tearDown(self):
         self.document.close()
diff --git a/tests/test_multiple_elements.py b/tests/test_multiple_elements.py
index 21f5250..8ece76b 100644
--- a/tests/test_multiple_elements.py
+++ b/tests/test_multiple_elements.py
@@ -4,7 +4,7 @@
 from lxml import etree

 from mailmerge import MailMerge
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part


 class MultipleElementsTest(EtreeMixin, unittest.TestCase):
@@ -49,4 +49,4 @@ def test(self):
             'w:bottom="1440" w:left="1440" w:header="708" w:footer="708" w:gutter="0"/>')

-        self.assert_equal_tree(expected_tree, list(document.parts.values())[0].getroot())
\ No newline at end of file
+        self.assert_equal_tree(expected_tree, get_document_body_part(document).getroot())
diff --git a/tests/test_unique_id.py b/tests/test_unique_id.py
index 6f10173..f4d4457 100644
--- a/tests/test_unique_id.py
+++ b/tests/test_unique_id.py
@@ -7,7 +7,7 @@ class UniqueIdsManagerTest(unittest.TestCase):
     Testing UniqueIdsManager class
     """

-    def test_unique_id_manager(self):
+    def test_unique_id_manager_register_id(self):
        """
        Tests if the next record field works
        """
@@ -22,3 +22,5 @@ def test_unique_id_manager(self):

         for type_id, obj_id, new_id in tests:
             self.assertEqual(id_man.register_id(type_id, obj_id=obj_id), new_id)
+
+        self.assertEqual(id_man.register_id_str("footer2"), "footer3")
diff --git a/tests/test_winword2010.py b/tests/test_winword2010.py
index b261e7e..f74d140 100644
--- a/tests/test_winword2010.py
+++ b/tests/test_winword2010.py
@@ -4,7 +4,7 @@
 from lxml import etree

 from mailmerge import MailMerge, NAMESPACES
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part


 class Windword2010Test(EtreeMixin, unittest.TestCase):
@@ -178,5 +178,5 @@ def test(self):
             '' # noqa
         )

-        self.assert_equal_tree(expected_tree, list(document.parts.values())[0].getroot())
+        self.assert_equal_tree(expected_tree, get_document_body_part(document).getroot())
         self.assertIsNone(document.get_settings().getroot().find('{%(w)s}mailMerge' % NAMESPACES))
diff --git a/tests/utils.py b/tests/utils.py
index f1385fd..f7bf272 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -100,7 +100,18 @@ def get_new_docx(self, replacement_parts):

 def get_document_body_part(document, endswith="document"):
     for part in document.parts.values():
-        if part.getroot().tag.endswith('}%s' % endswith):
-            return part
+        if part['part'].getroot().tag.endswith('}%s' % endswith):
+            return part['part']

     raise AssertionError("main document body not found in document.parts")
+
+
+def get_document_body_parts(document, endswith="document"):
+    parts = []
+    for part in document.parts.values():
+        if part['part'].getroot().tag.endswith('}%s' % endswith):
+            parts.append(part['part'])
+    for _, _, part in document.new_parts:
+        if part.getroot().tag.endswith('}%s' % endswith):
+            parts.append(part)
+
+    return parts
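
For reference, a minimal usage sketch of the header/footer-aware merge_templates path that the new tests exercise. This is not part of the patch; the template path, output path, and field names ("letter_template.docx", "fieldname", "footer") are assumptions for illustration only — any .docx whose headers or footers contain MERGEFIELD placeholders should behave the same way.

    from mailmerge import MailMerge

    # Hypothetical template: a .docx with a "footer" merge field in its footer
    # and a "fieldname" merge field in the body (names are illustrative only).
    with MailMerge("letter_template.docx") as document:
        document.merge_templates(
            [
                {"fieldname": "one", "footer": "footer for one"},
                {"fieldname": "two", "footer": "footer for two"},
            ],
            separator="nextPage_section",  # each record starts in its own section
        )
        # The per-record header/footer copies collected in document.new_parts
        # are written out together with the rest of the package.
        document.write("letters_merged.docx")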