diff --git a/mailmerge.py b/mailmerge.py
index bbc1bc9..74712a1 100644
--- a/mailmerge.py
+++ b/mailmerge.py
@@ -1,6 +1,7 @@
import warnings
import shlex
import re
+import os
import datetime
# import locale
from zipfile import ZipFile, ZIP_DEFLATED
@@ -15,6 +16,8 @@
'wp': 'http://schemas.openxmlformats.org/drawingml/2006/wordprocessingDrawing',
'mc': 'http://schemas.openxmlformats.org/markup-compatibility/2006',
'ct': 'http://schemas.openxmlformats.org/package/2006/content-types',
+ 'rr': 'http://schemas.openxmlformats.org/package/2006/relationships',
+ 'r': 'http://schemas.openxmlformats.org/officeDocument/2006/relationships',
'xml': 'http://www.w3.org/XML/1998/namespace'
}
@@ -33,8 +36,8 @@
CONTENT_TYPES_PARTS = {
'application/vnd.openxmlformats-officedocument.wordprocessingml.document.main+xml':'main',
'application/vnd.ms-word.document.macroEnabled.main+xml':'main',
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.header+xml':'rel_part',
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.footer+xml':'rel_part',
+ 'application/vnd.openxmlformats-officedocument.wordprocessingml.header+xml':'header_footer',
+ 'application/vnd.openxmlformats-officedocument.wordprocessingml.footer+xml':'header_footer',
'application/vnd.openxmlformats-officedocument.wordprocessingml.footnotes+xml':'notes',
'application/vnd.openxmlformats-officedocument.wordprocessingml.endnotes+xml':'notes',
'application/vnd.openxmlformats-officedocument.wordprocessingml.settings+xml':'settings'
@@ -44,7 +47,9 @@
'page_break', 'column_break', 'textWrapping_break',
'continuous_section', 'evenPage_section', 'nextColumn_section', 'nextPage_section', 'oddPage_section'}
-NUMBERFORMAT_RE = r"([^0.,'#PN]+)?(P\d+|N\d+|[0.,'#]+%?)([^0.,'#%].*)?"
+PARTFILENAME_RE = re.compile(r'([A-Za-z_]+)(\d+)\.xml')
+IDSTR_RE = re.compile(r'([A-Za-z_]+)(\d+)')
+NUMBERFORMAT_RE = re.compile(r"([^0.,'#PN]+)?(P\d+|N\d+|[0.,'#]+%?)([^0.,'#%].*)?")
DATEFORMAT_RE = "|".join([r"{}+".format(switch) for switch in "yYmMdDhHsS"] + [r"am/pm", r"AM/PM"])
DATEFORMAT_MAP = {
"M": "{d.month}",
@@ -183,7 +188,7 @@ def _format_text(self, value, flag, option):
return value
def _format_number(self, value, flag, option):
- format_match = re.match(NUMBERFORMAT_RE, option)
+ format_match = NUMBERFORMAT_RE.match(option)
if value is None:
value = 0
if format_match is None:
@@ -373,8 +378,21 @@ def register_id(self, id_type, obj_id=None):
new_obj_id = obj_id
type_id_value['ids'].add(obj_id)
type_id_value['max'] = max(type_id_value['max'], obj_id)
+ # print("registered", id_type, obj_id, new_obj_id, "max", type_id_value['max'])
return new_obj_id
+ def register_id_str(self, id_str):
+ """ registers directly a string of format 'type1231' where the id_type is before the id """
+ match = IDSTR_RE.match(id_str)
+ assert match
+ id_type, obj_id = match.groups()
+ new_obj_id = self.register_id(id_type, obj_id=int(obj_id))
+ if new_obj_id is not None:
+ # print(id_type, obj_id, new_obj_id, self.id_type_map[id_type])
+ return "%s%d" % (id_type, new_obj_id)
+
+
+
class MergeData(object):
""" prepare the MergeField objects and the data """
@@ -580,6 +598,105 @@ def get_field_object(self, field_element, row):
field_obj = self._merge_field_map[field_key]
return field_obj
+ def fix_id(self, element, attr_gen):
+ """ will replace an id with a new unique id """
+ new_id = self.get_new_element_id(element)
+ if new_id is not None:
+ element.attrib['id'] = new_id
+ for attr_name, attr_value in attr_gen.items():
+ element.attrib[attr_name] = attr_value.format(id=new_id)
+
+ def fix_ids(self, current_body):
+ """ will fix all ids in the current body """
+ for tag, attr_gen in TAGS_WITH_ID.items():
+ for elem in current_body.xpath('//{}'.format(tag), namespaces=NAMESPACES):
+ self.fix_id(elem, attr_gen)
+
+
+class RelationsDocument(object):
+ """ handling relations document """
+
+ def __init__(self, rel_part):
+ self.rel_part = rel_part
+
+ def replace_relation(self, merge_data, old_relation_elem, new_target):
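+ """ appends a copy of old_relation_elem with a new unique Id and the given target, returning the new Id """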
+ root = self.rel_part.getroot()
+ new_relation = deepcopy(old_relation_elem)
+ # print(etree.tostring(new_relation))
+ new_relation.attrib['Id'] = merge_data.unique_id_manager.register_id_str(new_relation.attrib['Id'])
+ # print(old_relation_elem.attrib['Id'], "->", new_relation.attrib['Id'])
+ new_relation.attrib['Target'] = new_target
+ root.append(new_relation)
+ return new_relation.attrib['Id']
+
+ def get_relation_elem(self, target):
+ """ returns the relation element for the """
+ return self.rel_part.getroot().find('rr:Relationship[@Target="%s"]' % target, namespaces=NAMESPACES)
+
+ def get_all(self):
+ """ returns all relations """
+ return self.rel_part.getroot().xpath('rr:Relationship', namespaces=NAMESPACES)
+
+class MergeHeaderFooterDocument(object):
+ """ prepare and merge one Header/Footer document for merge_templates
+
+ Helper class that handles the actual merging of a single header/footer part.
+ Headers and footers are separate parts referenced through relations, so a
+ copy of the part has to be created for each record and the relations updated.
+ """
+
+ def __init__(self, part_info, relations, separator):
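+ """ part_info is the dict built in __get_tree_of_file: 'zi' (ZipInfo), 'file' (content-type Override element) and 'part' (ElementTree) """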
+ self.part_content_type = part_info['file']
+ self.zi = part_info['zi']
+ self.part = part_info['part']
+ self.relations = relations
+ self.sep_type = None
+ self.target, self.id_type, self.part_id = self._parse_part_filename(self.zi.filename)
+ self.new_parts = [] # list of (filename, root) parts
+ self._current_part = None
+ self.has_fields = bool(self.part.findall('.//MergeField'))
+ self._prepare_data(separator)
+
+ def _parse_part_filename(self, filename):
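+ """ splits a part filename such as 'word/footer2.xml' into ('footer2.xml', 'footer', '2') """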
+ filename = os.path.basename(filename)
+ match = PARTFILENAME_RE.match(filename)
+ assert match
+ return filename, *match.groups()
+
+ def _prepare_data(self, separator):
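+ """ validates the separator argument and remembers its type (e.g. 'nextPage') """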
+ if separator not in VALID_SEPARATORS:
+ raise ValueError("Invalid separator argument")
+ self.sep_type, _sep_class = separator.split("_")
+
+ def prepare(self, merge_data, first=False):
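+ """ creates a fresh copy of the part for the next record (only if it contains merge fields) """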
+ if self.has_fields:
+ self._current_part = deepcopy(self.part)
+
+ def merge(self, merge_data, row, first=False):
+ """ Merges one row into the current prepared body """
+ if self.has_fields:
+ assert self._current_part is not None
+ merge_data.replace(self._current_part.getroot(), row)
+
+ def finish(self, merge_data, abort=False):
+ """ finishes the current merge, by updating the relations """
+
+ if abort: # for skipping the record
+ self._current_part = None
+
+ if self._current_part is not None:
+ # @TODO use the existing header/footers for the first section
+ new_id = merge_data.unique_id_manager.register_id(self.id_type)
+ new_target = self.target.replace(self.part_id, str(new_id))
+ new_filename = self.zi.filename.replace(self.part_id, str(new_id))
+ new_part_content_type = deepcopy(self.part_content_type)
+ new_part_content_type.attrib['PartName'] = self.part_content_type.attrib['PartName'].replace(self.target, new_target)
+ self.new_parts.append((new_filename, new_part_content_type, self._current_part))
+ self._current_part = None
+ return [(self.target, new_target)]
+
+ return []
+
class MergeDocument(object):
""" prepare and merge one document
@@ -590,8 +707,10 @@ class MergeDocument(object):
It prepares the body, sections, separators
"""
- def __init__(self, root, separator):
+ def __init__(self, merge_data, root, relations, separator):
+ self.merge_data = merge_data
self.root = root
+ self.relations = relations
# self.sep_type = sep_type
# self.sep_class = sep_class
# if sep_class == 'section':
@@ -601,6 +720,8 @@ def __init__(self, root, separator):
self._body = None # the document body, where all the documents are appended
self._body_copy = None # a deep copy of the original body without ending section
self._current_body = None # the current document body where all the changes are merged
+ self._current_separator = None
+ self._finish_rels = []
self._prepare_data(separator)
def _prepare_data(self, separator):
@@ -634,6 +755,7 @@ def _prepare_data(self, separator):
self._body = self._last_section.getparent()
self._body.remove(self._last_section)
+ self._last_section = deepcopy(self._last_section) # fix a bug
self._body_copy = deepcopy(self._body)
@@ -655,27 +777,36 @@ def prepare(self, merge_data, first=False):
assert self._current_body is None
# add separator if not the first document
if not first:
- self._body.append(deepcopy(self._separator))
+ # @TODO replace the relation references in the full body, not only in the
+ # separator
+ # @TODO refactor the whole preparation process, so it is straightforward
+ # and doesn't look like a hack
+ for old_target, new_target in self._finish_rels:
+ self.replace_relation_reference(merge_data, old_target, new_target)
+ self._body.append(self._current_separator)
+ self._current_separator = deepcopy(self._separator)
self._current_body = deepcopy(self._body_copy)
- for tag, attr_gen in TAGS_WITH_ID.items():
- for elem in self._current_body.xpath('//{}'.format(tag), namespaces=NAMESPACES):
- self._fix_id(merge_data, elem, attr_gen)
+ merge_data.fix_ids(self._current_body)
def merge(self, merge_data, row, first=False):
""" Merges one row into the current prepared body """
merge_data.replace(self._current_body, row)
- def _fix_id(self, merge_data, element, attr_gen):
- new_id = merge_data.get_new_element_id(element)
- if new_id is not None:
- element.attrib['id'] = new_id
- for attr_name, attr_value in attr_gen.items():
- element.attrib[attr_name] = attr_value.format(id=new_id)
+ def replace_relation_reference(self, merge_data, old_target, new_target, sep=None):
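+ """ duplicates the relation for old_target so that it points to new_target and rewrites the matching r:id references in sep """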
+ # assert self._current_body is not None
+ if sep is None:
+ sep = self._current_separator
+
+ old_relation = self.relations.get_relation_elem(old_target)
+ new_rel_id = self.relations.replace_relation(merge_data, old_relation, new_target)
- def finish(self, abort=False):
- """ finishes the current body by saving it into the main body or into a file (future feature) """
+ for elem in sep.xpath('//*[@r:id="%s"]' % old_relation.attrib['Id'], namespaces=NAMESPACES):
+ elem.attrib['{%(r)s}id' % NAMESPACES] = new_rel_id
+ def finish(self, finish_rels, abort=False):
+ """ finishes the current body by saving it into the main body or into a file (future feature) """
+ self._finish_rels = finish_rels
if abort: # for skipping the record
self._current_body = None
@@ -689,7 +820,9 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
- self.finish(True)
+ # self.finish(True)
+ for old_target, new_target in self._finish_rels:
+ self.replace_relation_reference(self.merge_data, old_target, new_target, sep=self._last_section)
self._body.append(self._last_section)
class MailMerge(object):
@@ -717,6 +850,7 @@ def __init__(self, file, remove_empty_tables=False, auto_update_fields_on_open="
"""
self.zip = ZipFile(file)
self.parts = {} # zi_part: ElementTree
+ self.new_parts = [] # list of [(filename, part)]
self.categories = {} # category: [zi, ...]
self.merge_data = MergeData(remove_empty_tables=remove_empty_tables, keep_fields=keep_fields)
self.remove_empty_tables = remove_empty_tables
@@ -727,9 +861,9 @@ def __init__(self, file, remove_empty_tables=False, auto_update_fields_on_open="
try:
self.__fill_parts()
- for part in self.get_parts().values():
- self.__fill_simple_fields(part)
- self.__fill_complex_fields(part)
+ for part_info in self.get_parts():
+ self.__fill_simple_fields(part_info['part'])
+ self.__fill_complex_fields(part_info['part'])
except:
self.zip.close()
@@ -738,18 +872,36 @@ def __init__(self, file, remove_empty_tables=False, auto_update_fields_on_open="
def get_parts(self, categories=None):
""" return all the parts based on categories """
if categories is None:
- categories = ["main", "rel_part", "notes"]
+ categories = ["main", "header_footer", "notes"]
elif isinstance(categories, str):
categories = [categories]
- return {
- zi: self.parts[zi]
+ return [
+ self.parts[zi]
for category in categories
for zi in self.categories.get(category, [])
- }
+ ]
def get_settings(self):
""" returns the settings part """
- return list(self.get_parts(['settings']).values())[0]
+ settings_parts = self.categories.get('settings')
+ return self.parts[settings_parts[0]]['part'] if settings_parts else None
+
+ def get_content_types(self):
+ """" returns the content types part """
+ return self.parts[self.categories['content_types'][0]]['part']
+
+ def get_relations(self, part_zi):
+ """ returns the """
+ rel_fn = "word/_rels/%s.rels" % os.path.basename(part_zi.filename)
+ if rel_fn in self.zip.namelist():
+ zi = self.zip.getinfo(rel_fn)
+ rel_root = etree.parse(self.zip.open(zi))
+ self.parts[zi] = dict(zi=zi, part=rel_root)
+ relations = RelationsDocument(rel_root)
+ for relation in relations.get_all():
+ self.merge_data.unique_id_manager.register_id_str(relation.attrib['Id'])
+ return relations
+ # else:
+ # print(rel_fn, self.zip.namelist())
def __setattr__(self, __name, __value):
super(MailMerge, self).__setattr__(__name, __value)
@@ -757,7 +909,10 @@ def __setattr__(self, __name, __value):
self.merge_data.remove_empty_tables = __value
def __fill_parts(self):
- content_types = etree.parse(self.zip.open('[Content_Types].xml'))
+ content_types_zi = self.zip.getinfo('[Content_Types].xml')
+ content_types = etree.parse(self.zip.open(content_types_zi))
+ self.categories['content_types'] = [content_types_zi]
+ self.parts[content_types_zi] = dict(part=content_types)
for file in content_types.findall('{%(ct)s}Override' % NAMESPACES):
part_type = file.attrib['ContentType' % NAMESPACES]
category = CONTENT_TYPES_PARTS.get(part_type)
@@ -865,7 +1020,8 @@ def __fill_complex_fields(self, part):
def __fix_settings(self):
- for settings in self.get_parts(categories=['settings']).values():
+ settings = self.get_settings()
+ if settings:
settings_root = settings.getroot()
if not self._has_unmerged_fields:
mail_merge = settings_root.find('{%(w)s}mailMerge' % NAMESPACES)
@@ -885,7 +1041,7 @@ def __fix_settings(self):
def __get_tree_of_file(self, file):
fn = file.attrib['PartName' % NAMESPACES].split('/', 1)[1]
zi = self.zip.getinfo(fn)
- return zi, etree.parse(self.zip.open(zi))
+ return zi, dict(zi=zi, file=file, part=etree.parse(self.zip.open(zi)))
def write(self, file, empty_value=''):
self._has_unmerged_fields = bool(self.get_merge_fields())
@@ -906,21 +1062,34 @@ def write(self, file, empty_value=''):
# Remove mail merge settings to avoid error messages when opening document in Winword
self.__fix_settings()
+ # add the new files in the content types
+ content_types = self.get_content_types().getroot()
+ for _filename, part_content_type, _part in self.new_parts:
+ content_types.append(part_content_type)
+
with ZipFile(file, 'w', ZIP_DEFLATED) as output:
for zi in self.zip.filelist:
if zi in self.parts:
- xml = etree.tostring(self.parts[zi].getroot(), encoding='UTF-8', xml_declaration=True)
+ xml = etree.tostring(self.parts[zi]['part'].getroot(), encoding='UTF-8', xml_declaration=True)
output.writestr(zi.filename, xml)
else:
output.writestr(zi.filename, self.zip.read(zi))
- def get_merge_fields(self, parts=None):
+ for filename, _part_content_type, part in self.new_parts:
+ xml = etree.tostring(part.getroot(), encoding='UTF-8', xml_declaration=True)
+ output.writestr(filename, xml)
+
+ def get_merge_fields(self):
+ """" get the fields from the document """
+ return self._get_merge_fields()
+
+ def _get_merge_fields(self, parts=None):
if not parts:
- parts = self.get_parts().values()
+ parts = self.get_parts()
fields = set()
for part in parts:
- for mf in part.findall('.//MergeField'):
+ for mf in part['part'].findall('.//MergeField'):
fields.add(mf.attrib['name'])
# for name in self.merge_data.get_merge_fields(mf.attrib['merge_key']):
# fields.add(name)
@@ -943,24 +1112,45 @@ def merge_templates(self, replacements, separator):
assert replacements, "empty data"
#TYPE PARAM CONTROL AND SPLIT
+ # prepare the side documents, like headers, footers, etc
+ rel_docs = []
+ for part_info in self.get_parts(["header_footer"]):
+ relations = self.get_relations(part_info['zi'])
+ merge_header_footer_doc = MergeHeaderFooterDocument(part_info, relations, separator)
+ rel_docs.append(merge_header_footer_doc)
+ self.merge_data.unique_id_manager.register_id(merge_header_footer_doc.id_type, int(merge_header_footer_doc.part_id))
+
# Duplicate template. Creates a copy of the template, does a merge, and separates them by a new paragraph, a new break or a new section break.
#GET ROOT - WORK WITH DOCUMENT
- for part in self.get_parts(["main"]).values():
- root = part.getroot()
+ for part_info in self.get_parts(["main"]):
+ root = part_info['part'].getroot()
+ relations = self.get_relations(part_info['zi'])
# the mailmerge is done with the help of the MergeDocument class
# that handles the document duplication
- with MergeDocument(root, separator) as merge_doc:
+ with MergeDocument(self.merge_data, root, relations, separator) as merge_doc:
row = self.merge_data.start_merge(replacements)
while row is not None:
merge_doc.prepare(self.merge_data, first=self.merge_data.is_first())
+
+ finish_rels = []
+ for rel_doc in rel_docs:
+ rel_doc.prepare(self.merge_data, first=self.merge_data.is_first())
+ rel_doc.merge(self.merge_data, row)
+ finish_rels.extend(rel_doc.finish(self.merge_data))
+
try:
merge_doc.merge(self.merge_data, row)
- merge_doc.finish()
+ merge_doc.finish(finish_rels)
except SkipRecord:
- merge_doc.finish(abort=True)
+ merge_doc.finish(finish_rels, abort=True)
+
row = self.merge_data.next_row()
+
+ # add all new files in the zip
+ for rel_doc in rel_docs:
+ self.new_parts.extend(rel_doc.new_parts)
def merge_pages(self, replacements):
"""
@@ -979,14 +1169,17 @@ def merge(self, **replacements):
self._merge(replacements)
def _merge(self, replacements):
- for part in self.get_parts().values():
+ for part_info in self.get_parts():
+ self.merge_data.replace(part_info['part'], replacements)
+
+ for _filename, _part_content_type, part in self.new_parts:
self.merge_data.replace(part, replacements)
def merge_rows(self, anchor, rows):
""" anchor is one of the fields in the table """
- for part in self.get_parts().values():
- self.merge_data.replace_table_rows(part, anchor, rows)
+ for part_info in self.get_parts():
+ self.merge_data.replace_table_rows(part_info['part'], anchor, rows)
def __enter__(self):
return self
diff --git a/tests/test_footer.docx b/tests/test_footer.docx
new file mode 100644
index 0000000..8c68964
Binary files /dev/null and b/tests/test_footer.docx differ
diff --git a/tests/test_footnote_header_footer.docx b/tests/test_footnote_header_footer.docx
index 72178f2..6805a52 100644
Binary files a/tests/test_footnote_header_footer.docx and b/tests/test_footnote_header_footer.docx differ
diff --git a/tests/test_footnote_header_footer.py b/tests/test_footnote_header_footer.py
index b67d2d0..bfb7bee 100644
--- a/tests/test_footnote_header_footer.py
+++ b/tests/test_footnote_header_footer.py
@@ -1,18 +1,14 @@
import unittest
-import tempfile
-import warnings
-import datetime
-import locale
-from os import path
-from lxml import etree
from mailmerge import NAMESPACES
-from tests.utils import EtreeMixin, get_document_body_part
+from tests.utils import EtreeMixin, get_document_body_part, TEXTS_XPATH, get_document_body_parts
FOOTNOTE_XPATH = "//w:footnote[@w:id = '1']/w:p/w:r/w:t/text()"
class FootnoteHeaderFooterTest(EtreeMixin, unittest.TestCase):
-
+ # @TODO test missing values
+ # @TODO test if separator isn't section
+ # @TODO test headers/footers with relations
@unittest.expectedFailure
def test_all(self):
values = ["one", "two", "three"]
@@ -20,12 +16,54 @@ def test_all(self):
# fix this when it is implemented
document, root_elem = self.merge_templates(
'test_footnote_header_footer.docx',
- [{"fieldname": value, "footerfield": "f_" + value, "headerfield": "h_" + value}
+ [
+ {
+ "fieldname": value,
+ "footerfield": "f_" + value,
+ "headerfield": "h_" + value,
+ "footerfirst": "ff_" + value,
+ "headerfirst": "hf_" + value,
+ "footereven": "fe_" + value,
+ "headereven": "he_" + value,
+ }
for value in values
],
- # output="tests/test_output_footnote_header_footer.docx"
+ separator="nextPage_section",
+ # output="tests/output/test_output_footnote_header_footer.docx"
)
+ footers = sorted([
+ "".join(footer_doc_tree.getroot().xpath(TEXTS_XPATH, namespaces=NAMESPACES))
+ for footer_doc_tree in get_document_body_parts(document, endswith="ftr")
+ ])
+ self.assertListEqual(footers, sorted([
+ footer
+ for value in values
+ for footer in [
+ 'Footer on even page fe_%s' % value,
+ 'Footer on every page f_%s' % value,
+ 'Footer on first page ff_%s' % value]] + [
+ 'Footer on even page ',
+ 'Footer on every page ',
+ 'Footer on first page ']
+ ))
+
+ headers = sorted([
+ "".join(header_doc_tree.getroot().xpath(TEXTS_XPATH, namespaces=NAMESPACES))
+ for header_doc_tree in get_document_body_parts(document, endswith="hdr")
+ ])
+ self.assertListEqual(headers, sorted([
+ header
+ for value in values
+ for header in [
+ 'Header even: he_%s' % value,
+ 'Header on every page: h_%s' % value,
+ 'Header on first page: hf_%s' % value]] + [
+ 'Header even: ',
+ 'Header on every page: ',
+ 'Header on first page: ']
+ ))
+
footnote_root_elem = get_document_body_part(document, "footnotes").getroot()
footnote = "".join(footnote_root_elem.xpath(FOOTNOTE_XPATH, namespaces=NAMESPACES))
correct_footnote = " Merge : one "
@@ -35,12 +73,57 @@ def test_only_merge(self):
values = ["one", "two", "three"]
document, root_elem = self.merge(
'test_footnote_header_footer.docx',
- [{"fieldname": value, "footerfield": "f_" + value, "headerfield": "h_" + value}
+ [
+ {
+ "fieldname": value,
+ "footerfield": "f_" + value,
+ "headerfield": "h_" + value,
+ "footerfirst": "ff_" + value,
+ "headerfirst": "hf_" + value,
+ "footereven": "fe_" + value,
+ "headereven": "he_" + value,
+ }
for value in values
][0],
- # output="tests/test_output_one_footnote_header_footer.docx"
+ # output="tests/output/test_output_one_footnote_header_footer.docx"
)
footnote_root_elem = get_document_body_part(document, "footnotes").getroot()
footnote = "".join(footnote_root_elem.xpath(FOOTNOTE_XPATH, namespaces=NAMESPACES))
self.assertEqual(footnote, " Merge : one")
+
+ footers = sorted([
+ "".join(footer_doc_tree.getroot().xpath(TEXTS_XPATH, namespaces=NAMESPACES))
+ for footer_doc_tree in get_document_body_parts(document, endswith="ftr")
+ ])
+ value = values[0]
+ self.assertListEqual(footers, [
+ 'Footer on even page fe_%s' % value,
+ 'Footer on every page f_%s' % value,
+ 'Footer on first page ff_%s' % value])
+
+ headers = sorted([
+ "".join(header_doc_tree.getroot().xpath(TEXTS_XPATH, namespaces=NAMESPACES))
+ for header_doc_tree in get_document_body_parts(document, endswith="hdr")
+ ])
+ self.assertListEqual(headers, [
+ 'Header even: he_%s' % value,
+ 'Header on every page: h_%s' % value,
+ 'Header on first page: hf_%s' % value])
+
+
+ def test_footer(self):
+ values = ["one", "two"]
+ # header/footer/footnotes don't work with multiple replacements, only with merge
+ # fix this when it is implemented
+ document, root_elem = self.merge_templates(
+ 'test_footer.docx',
+ [
+ {
+ "footer": value,
+ }
+ for value in values
+ ],
+ separator="nextPage_section",
+ # output="tests/output/test_footer.docx"
+ )
diff --git a/tests/test_macword2011.py b/tests/test_macword2011.py
index 64e33c0..72409a1 100644
--- a/tests/test_macword2011.py
+++ b/tests/test_macword2011.py
@@ -4,7 +4,7 @@
from lxml import etree
from mailmerge import MailMerge
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part
class MacWord2011Test(EtreeMixin, unittest.TestCase):
@@ -57,4 +57,4 @@ def test(self):
'/>')
- self.assert_equal_tree(expected_tree, list(document.parts.values())[0].getroot())
+ self.assert_equal_tree(expected_tree, get_document_body_part(document).getroot())
diff --git a/tests/test_merge_table_multipart.py b/tests/test_merge_table_multipart.py
index 14860b2..eabcbc0 100644
--- a/tests/test_merge_table_multipart.py
+++ b/tests/test_merge_table_multipart.py
@@ -4,7 +4,7 @@
from lxml import etree
from mailmerge import MailMerge
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part
class MergeTableRowsMultipartTest(EtreeMixin, unittest.TestCase):
@@ -32,10 +32,7 @@ def test_merge_rows_on_multipart_file(self):
with tempfile.TemporaryFile() as outfile:
self.document.write(outfile)
- for part in self.document.parts.values():
- # only check the document part
- if (part.getroot().tag == '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}document'):
- self.assert_equal_tree(self.expected_tree, part.getroot())
+ self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())
def test_merge_unified_on_multipart_file(self):
self.document.merge(
@@ -52,10 +49,7 @@ def test_merge_unified_on_multipart_file(self):
with tempfile.TemporaryFile() as outfile:
self.document.write(outfile)
- for part in self.document.parts.values():
- # only check the document part
- if (part.getroot().tag == '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}document'):
- self.assert_equal_tree(self.expected_tree, part.getroot())
+ self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())
def tearDown(self):
self.document.close()
diff --git a/tests/test_merge_table_rows.py b/tests/test_merge_table_rows.py
index 399b8a9..575018e 100644
--- a/tests/test_merge_table_rows.py
+++ b/tests/test_merge_table_rows.py
@@ -4,7 +4,7 @@
from lxml import etree
from mailmerge import MailMerge, NAMESPACES
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part
class MergeTableRowsTest(EtreeMixin, unittest.TestCase):
@@ -31,8 +31,7 @@ def test_merge_rows(self):
with tempfile.TemporaryFile() as outfile:
self.document.write(outfile)
- self.assert_equal_tree(self.expected_tree,
- list(self.document.parts.values())[0].getroot())
+ self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())
def test_merge_rows_no_table(self):
"""
@@ -56,8 +55,7 @@ def test_merge_rows_no_table(self):
with tempfile.TemporaryFile() as outfile:
self.document.write(outfile)
- self.assert_equal_tree(self.expected_tree,
- list(self.document.parts.values())[0].getroot())
+ self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())
def test_merge_rows_remove_table(self):
"""
@@ -74,7 +72,7 @@ def test_merge_rows_remove_table(self):
with tempfile.TemporaryFile() as outfile:
self.document.write(outfile)
self.assertIsNone(
- list(self.document.parts.values())[0].getroot().find('.//{%(w)s}tbl' % NAMESPACES)
+ get_document_body_part(self.document).getroot().find('.//{%(w)s}tbl' % NAMESPACES)
)
def test_merge_unified(self):
@@ -92,8 +90,7 @@ def test_merge_unified(self):
with tempfile.TemporaryFile() as outfile:
self.document.write(outfile)
- self.assert_equal_tree(self.expected_tree,
- list(self.document.parts.values())[0].getroot())
+ self.assert_equal_tree(self.expected_tree, get_document_body_part(self.document).getroot())
def test_merge_template_with_rows(self):
"""
@@ -125,8 +122,7 @@ def test_merge_template_with_rows(self):
with tempfile.TemporaryFile() as outfile:
self.document.write(outfile)
expected_tree = etree.fromstring('GradesBouke Haarsma received the grades for in the table below.Class CodeClass NameGradeECON101Economics 101AECONADVEconomics AdvancedBOPRESOperations ResearchATHESISFinal thesisAGradesJon Snow received the grades for in the table below.Class CodeClass NameGradeECON101Economics 101EECONADVEconomics AdvancedFTHESISFinal thesisB')
- self.assert_equal_tree(expected_tree,
- list(self.document.parts.values())[0].getroot())
+ self.assert_equal_tree(expected_tree, get_document_body_part(self.document).getroot())
def tearDown(self):
self.document.close()
diff --git a/tests/test_multiple_elements.py b/tests/test_multiple_elements.py
index 21f5250..8ece76b 100644
--- a/tests/test_multiple_elements.py
+++ b/tests/test_multiple_elements.py
@@ -4,7 +4,7 @@
from lxml import etree
from mailmerge import MailMerge
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part
class MultipleElementsTest(EtreeMixin, unittest.TestCase):
@@ -49,4 +49,4 @@ def test(self):
'w:bottom="1440" w:left="1440" w:header="708" w:footer="708" w:gutter="0"/>')
- self.assert_equal_tree(expected_tree, list(document.parts.values())[0].getroot())
\ No newline at end of file
+ self.assert_equal_tree(expected_tree, get_document_body_part(document).getroot())
diff --git a/tests/test_unique_id.py b/tests/test_unique_id.py
index 6f10173..f4d4457 100644
--- a/tests/test_unique_id.py
+++ b/tests/test_unique_id.py
@@ -7,7 +7,7 @@ class UniqueIdsManagerTest(unittest.TestCase):
Testing UniqueIdsManager class
"""
- def test_unique_id_manager(self):
+ def test_unique_id_manager_register_id(self):
"""
Tests if the next record field works
"""
@@ -22,3 +22,5 @@ def test_unique_id_manager(self):
for type_id, obj_id, new_id in tests:
self.assertEqual(id_man.register_id(type_id, obj_id=obj_id), new_id)
+
+ self.assertEqual(id_man.register_id_str("footer2"), "footer3")
diff --git a/tests/test_winword2010.py b/tests/test_winword2010.py
index b261e7e..f74d140 100644
--- a/tests/test_winword2010.py
+++ b/tests/test_winword2010.py
@@ -4,7 +4,7 @@
from lxml import etree
from mailmerge import MailMerge, NAMESPACES
-from tests.utils import EtreeMixin
+from tests.utils import EtreeMixin, get_document_body_part
class Windword2010Test(EtreeMixin, unittest.TestCase):
@@ -178,5 +178,5 @@ def test(self):
'' # noqa
)
- self.assert_equal_tree(expected_tree, list(document.parts.values())[0].getroot())
+ self.assert_equal_tree(expected_tree, get_document_body_part(document).getroot())
self.assertIsNone(document.get_settings().getroot().find('{%(w)s}mailMerge' % NAMESPACES))
diff --git a/tests/utils.py b/tests/utils.py
index f1385fd..f7bf272 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -100,7 +100,18 @@ def get_new_docx(self, replacement_parts):
def get_document_body_part(document, endswith="document"):
for part in document.parts.values():
- if part.getroot().tag.endswith('}%s' % endswith):
- return part
+ if part['part'].getroot().tag.endswith('}%s' % endswith):
+ return part['part']
raise AssertionError("main document body not found in document.parts")
+
+def get_document_body_parts(document, endswith="document"):
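+ """ returns all parts, existing and newly created, whose root tag ends with the given suffix """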
+ parts = []
+ for part in document.parts.values():
+ if part['part'].getroot().tag.endswith('}%s' % endswith):
+ parts.append(part['part'])
+ for _, _, part in document.new_parts:
+ if part.getroot().tag.endswith('}%s' % endswith):
+ parts.append(part)
+
+ return parts