diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index 7d66d2639f..19e63a6b5c 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -422,5 +422,65 @@ class User(Base): self.assertRaises(LookUpError, Base.objects.exclude, "made_up") + def test_exclude_and_modify(self): + # Make sure a document with missing fields from exclude won't + # try to overwrite them + + class EmbeddedDoc(EmbeddedDocument): + field = StringField(default='default') + + class Doc(Document): + excluded_field = StringField() + excluded_with_default_field = StringField(default='default') + excluded_embedded_field = EmbeddedDocumentField(EmbeddedDoc) + present_field = StringField() + + Doc.drop_collection() + doc = Doc(excluded_field='v1', present_field='v1', + excluded_with_default_field='v1', + excluded_embedded_field=EmbeddedDoc(field='v1')).save() + + doc_ex = Doc.objects.exclude( + 'excluded_field', + 'excluded_with_default_field', + 'excluded_embedded_field').get(id=doc.id) + doc_ex.present_field = 'v2' + doc_ex.save() + + doc.reload() + self.assertEqual(doc.present_field, 'v2') + self.assertEqual(doc.excluded_field, 'v1') + self.assertEqual(doc.excluded_with_default_field, 'v1') + self.assertEqual(doc.excluded_embedded_field, EmbeddedDoc(field='v1')) + + def test_only_and_modify(self): + # Make sure a document with missing fields from only won't + # try to overwrite them + + class EmbeddedDoc(EmbeddedDocument): + field = StringField(default='default') + + class Doc(Document): + excluded_field = StringField() + excluded_with_default_field = StringField(default='default') + excluded_embedded_field = EmbeddedDocumentField(EmbeddedDoc) + present_field = StringField() + + Doc.drop_collection() + doc = Doc(excluded_field='v1', present_field='v1', + excluded_with_default_field='v1', + excluded_embedded_field=EmbeddedDoc(field='v1')).save() + + doc_ex = Doc.objects.only('present_field').get(id=doc.id) + doc_ex.present_field = 'v2' + doc_ex.save() + + doc.reload() + self.assertEqual(doc.present_field, 'v2') + self.assertEqual(doc.excluded_field, 'v1') + self.assertEqual(doc.excluded_with_default_field, 'v1') + self.assertEqual(doc.excluded_embedded_field, EmbeddedDoc(field='v1')) + + if __name__ == '__main__': unittest.main()