diff --git a/.travis.yml b/.travis.yml
index 636bdb6c1..8aa99393c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,5 @@
 language: python
 cache: pip
-services: mongodb
-sudo: required
 python:
   - "3.6"
 install:
@@ -13,17 +11,23 @@ install:
 before_script:
   - python setup.py develop
   - cd $HOME
-  - curl -O https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-3.6.5.tgz
-  - tar -zxvf mongodb-linux-x86_64-3.6.5.tgz
+  - curl -O https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-3.6.8.tgz
+  - tar -zxvf mongodb-linux-x86_64-3.6.8.tgz
   - mkdir -p mongodbdata
   - touch mongodblog
   - |
-    mongodb-linux-x86_64-3.6.5/bin/mongod \
+    mongodb-linux-x86_64-3.6.8/bin/mongod \
      --port 27020 --dbpath mongodbdata --logpath mongodblog \
      --auth --bind_ip_all --fork
   - |
-    mongodb-linux-x86_64-3.6.5/bin/mongo 127.0.0.1:27020/admin --eval \
+    mongodb-linux-x86_64-3.6.8/bin/mongo 127.0.0.1:27020/admin --eval \
      'db.createUser({user:"mongoadmin",pwd:"mongoadminpass",roles:["root"]});'
+  - mkdir -p localdbdata
+  - touch localdblog
+  - |
+    mongodb-linux-x86_64-3.6.8/bin/mongod \
+     --port 27017 --dbpath localdbdata --logpath localdblog \
+     --noauth --bind_ip_all --fork
   - cd -
 script:
   - mpiexec -n 2 python $PWD/maggma/tests/mpi_test.py
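
The CI change above pins MongoDB 3.6.8 and forks two daemons: the existing authenticated instance on port 27020 and a new auth-less instance on the default port 27017, for tests that construct stores without credentials. A minimal sketch of reaching both from a test, assuming pymongo and the credentials created in `before_script`:

```python
from pymongo import MongoClient

# Auth-less daemon on the default port; no credentials needed.
local = MongoClient("127.0.0.1", 27017)
print(local.server_info()["version"])  # expect "3.6.8"

# Authenticated daemon on 27020, using the root user created in before_script.
admin = MongoClient("mongodb://mongoadmin:mongoadminpass@127.0.0.1:27020/admin")
print(admin.server_info()["version"])
```
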
diff --git a/maggma/advanced_stores.py b/maggma/advanced_stores.py
index da26b05fe..9e6060e5a 100644
--- a/maggma/advanced_stores.py
+++ b/maggma/advanced_stores.py
@@ -37,8 +37,8 @@ class MongograntStore(Mongolike, Store):
         mongogrant documentation: https://github.com/materialsproject/mongogrant
     """
-    def __init__(self, mongogrant_spec, collection_name,
-                 mgclient_config_path=None, **kwargs):
+
+    def __init__(self, mongogrant_spec, collection_name, mgclient_config_path=None, **kwargs):
         """
 
         Args:
@@ -54,7 +54,7 @@ def __init__(self, mongogrant_spec, collection_name,
         self.collection_name = collection_name
         self.mgclient_config_path = mgclient_config_path
         self._collection = None
-        if set(("username", "password","database", "host")) & set(kwargs):
+        if set(("username", "password", "database", "host")) & set(kwargs):
             raise StoreError("MongograntStore does not accept "
                              "username, password, database, or host "
                              "arguments. Use `mongogrant_spec`.")
@@ -75,8 +75,7 @@ def __hash__(self):
         return hash((self.mongogrant_spec, self.collection_name, self.lu_field))
 
     def groupby(self, keys, criteria=None, properties=None, **kwargs):
-        return MongoStore.groupby(
-            self, keys, criteria=None, properties=None, **kwargs)
+        return MongoStore.groupby(self, keys, criteria=None, properties=None, **kwargs)
 
 
 class VaultStore(MongoStore):
@@ -247,10 +246,11 @@ def __init__(self, store, sandbox, exclusive=False):
         self.store = store
         self.sandbox = sandbox
         self.exclusive = exclusive
-        super().__init__(key=self.store.key,
-                         lu_field=self.store.lu_field,
-                         lu_type=self.store.lu_type,
-                         validator=self.store.validator)
+        super().__init__(
+            key=self.store.key,
+            lu_field=self.store.lu_field,
+            lu_type=self.store.lu_type,
+            validator=self.store.validator)
 
     @property
     @lru_cache(maxsize=1)
@@ -258,8 +258,7 @@ def sbx_criteria(self):
         if self.exclusive:
             return {"sbxn": self.sandbox}
         else:
-            return {"$or": [{"sbxn": {"$in": [self.sandbox]}},
-                            {"sbxn": {"$exists": False}}]}
+            return {"$or": [{"sbxn": {"$in": [self.sandbox]}}, {"sbxn": {"$exists": False}}]}
 
     def query(self, criteria=None, properties=None, **kwargs):
         criteria = dict(**criteria, **self.sbx_criteria) if criteria else self.sbx_criteria
@@ -315,8 +314,7 @@ def __init__(self, index, bucket, **kwargs):
             bucket (str) : name of the bucket
         """
         if not boto_import:
-            raise ValueError("boto not available, please install boto3 to "
-                             "use AmazonS3Store")
+            raise ValueError("boto not available, please install boto3 to " "use AmazonS3Store")
         self.index = index
         self.bucket = bucket
         self.s3 = None
@@ -523,8 +521,17 @@ def rebuild_index_from_s3_data(self):
 
 class JointStore(Store):
     """Store corresponding to multiple collections, uses lookup to join"""
-    def __init__(self, database, collection_names, host="localhost",
-                 port=27017, username="", password="", master=None, **kwargs):
+
+    def __init__(self,
+                 database,
+                 collection_names,
+                 host="localhost",
+                 port=27017,
+                 username="",
+                 password="",
+                 master=None,
+                 merge_at_root=False,
+                 **kwargs):
         self.database = database
         self.collection_names = collection_names
         self.host = host
@@ -533,6 +540,7 @@ def __init__(self, database, collection_names, host="localhost",
         self.password = password
         self._collection = None
         self.master = master or collection_names[0]
+        self.merge_at_root = merge_at_root
         self.kwargs = kwargs
         super(JointStore, self).__init__(**kwargs)
 
@@ -542,6 +550,7 @@ def connect(self, force_reset=False):
         if self.username is not "":
             db.authenticate(self.username, self.password)
         self._collection = db[self.master]
+        self._has_merge_objects = self._collection.database.client.server_info()["version"] > "3.6"
 
     def close(self):
         self.collection.database.client.close()
@@ -554,17 +563,11 @@ def collection(self):
     def nonmaster_names(self):
         return list(set(self.collection_names) - {self.master})
 
-    def query(self, criteria=None, properties=None, **kwargs):
-        pipeline = self._get_pipeline(criteria=criteria, properties=properties)
-        return self.collection.aggregate(pipeline, **kwargs)
-
     @property
     def last_updated(self):
         lus = []
         for cname in self.collection_names:
-            lu = MongoStore.from_collection(
-                self.collection.database[cname],
-                lu_field=self.lu_field).last_updated
+            lu = MongoStore.from_collection(self.collection.database[cname], lu_field=self.lu_field).last_updated
             lus.append(lu)
         return max(lus)
 
@@ -579,8 +582,7 @@ def distinct(self, key, criteria=None, all_exist=True, **kwargs):
         g_key = key if isinstance(key, list) else [key]
         if all_exist:
             criteria = criteria or {}
-            criteria.update({k: {"$exists": True} for k in g_key
-                             if k not in criteria})
+            criteria.update({k: {"$exists": True} for k in g_key if k not in criteria})
         cursor = self.groupby(g_key, criteria=criteria, **kwargs)
         if isinstance(key, list):
             return [d['_id'] for d in cursor]
@@ -605,16 +607,34 @@ def _get_pipeline(self, criteria=None, properties=None):
         for cname in self.collection_names:
             if cname is not self.master:
                 pipeline.append({
-                    "$lookup": {"from": cname, "localField": self.key,
-                                "foreignField": self.key, "as": cname}})
-                pipeline.append({
-                    "$unwind": {"path": "${}".format(cname),
-                                "preserveNullAndEmptyArrays": True}})
+                    "$lookup": {
+                        "from": cname,
+                        "localField": self.key,
+                        "foreignField": self.key,
+                        "as": cname
+                    }
+                })
+
+                if self.merge_at_root:
+                    if not self._has_merge_objects:
+                        raise Exception(
+                            "MongoDB server version too low to use $mergeObjects.")
+
+                    pipeline.append({
+                        "$replaceRoot": {
+                            "newRoot": {
+                                "$mergeObjects": [{
+                                    "$arrayElemAt": ["${}".format(cname), 0]
+                                }, "$$ROOT"]
+                            }
+                        }
+                    })
+                else:
+                    pipeline.append({"$unwind": {"path": "${}".format(cname), "preserveNullAndEmptyArrays": True}})
 
         # Do projection for max last_updated
         lu_max_fields = ["${}".format(self.lu_field)]
-        lu_max_fields.extend(["${}.{}".format(cname, self.lu_field)
-                              for cname in self.collection_names])
+        lu_max_fields.extend(["${}.{}".format(cname, self.lu_field) for cname in self.collection_names])
         lu_proj = {self.lu_field: {"$max": lu_max_fields}}
         pipeline.append({"$addFields": lu_proj})
 
@@ -624,8 +644,14 @@ def _get_pipeline(self, criteria=None, properties=None):
             properties = {k: 1 for k in properties}
         if properties:
             pipeline.append({"$project": properties})
+
         return pipeline
 
+    def query(self, criteria=None, properties=None, **kwargs):
+        pipeline = self._get_pipeline(criteria=criteria, properties=properties)
+        agg = self.collection.aggregate(pipeline, **kwargs)
+        return agg
+
     def groupby(self, keys, criteria=None, properties=None, **kwargs):
         pipeline = self._get_pipeline(criteria=criteria, properties=properties)
         if not isinstance(keys, list):
@@ -633,10 +659,11 @@ def groupby(self, keys, criteria=None, properties=None, **kwargs):
         group_id = {}
         for key in keys:
             set_(group_id, key, "${}".format(key))
-        pipeline.append({"$group": {"_id": group_id,
-                                    "docs": {"$push": "$$ROOT"}}})
+        pipeline.append({"$group": {"_id": group_id, "docs": {"$push": "$$ROOT"}}})
+
+        agg = self.collection.aggregate(pipeline, **kwargs)
 
-        return self.collection.aggregate(pipeline, **kwargs)
+        return agg
 
     def query_one(self, criteria=None, properties=None, **kwargs):
         """
@@ -655,7 +682,7 @@ def query_one(self, criteria=None, properties=None, **kwargs):
         # pipeline.append({"$limit": 1})
         query = self.query(criteria=criteria, properties=properties, **kwargs)
         try:
-            doc = query.next()
+            doc = next(query)
             return doc
         except StopIteration:
             return None
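
With `merge_at_root=True`, `_get_pipeline` follows each `$lookup` with a `$replaceRoot`/`$mergeObjects` stage instead of `$unwind`, lifting the first joined document (via `$arrayElemAt ... 0`) to the root. `$mergeObjects` keeps the last value on key collisions, so listing `$$ROOT` second means the master document's own fields win. The operator landed in MongoDB 3.6, hence the `server_info()` version check in `connect()`. A sketch of the emitted stages for a hypothetical join of `test2` onto master `test1` keyed on `task_id`:

```python
# Hypothetical pipeline for JointStore(..., master="test1", merge_at_root=True)
# joining collection "test2" on key "task_id".
pipeline = [
    {"$lookup": {"from": "test2", "localField": "task_id",
                 "foreignField": "task_id", "as": "test2"}},
    # Promote the first joined doc to the root; $$ROOT comes last,
    # so existing root fields override joined ones on collision.
    {"$replaceRoot": {"newRoot": {"$mergeObjects": [
        {"$arrayElemAt": ["$test2", 0]}, "$$ROOT"]}}},
]
# db.test1.aggregate(pipeline) then matches flattened criteria such as
# {"your_prop": {"$gt": 6}}, as the JointStore tests below exercise.
```
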
diff --git a/maggma/cli/mrun.py b/maggma/cli/mrun.py
index e9cab1330..dbc6ddf15 100644
--- a/maggma/cli/mrun.py
+++ b/maggma/cli/mrun.py
@@ -48,7 +48,9 @@ def main():
         # This is a runner:
         root.info("Changing number of workers from default in input file")
         runner = Runner(objects.builders, args.num_workers, mpi=args.mpi)
-    else:
+    elif isinstance(objects, Builder):
+        runner = Runner([objects], args.num_workers, mpi=args.mpi)
+    else:
         root.error("Couldn't properly read the builder file.")
 
     if not args.dry_run:
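
For context, the full dispatch in `main()` after this change (the branches above the hunk are reconstructed from its context lines, so treat them as an approximation): a list of builders, a `Runner`, or now a bare `Builder` all produce a `Runner`; anything else is reported as an unreadable builder file.

```python
if isinstance(objects, list):
    runner = Runner(objects, args.num_workers, mpi=args.mpi)
elif isinstance(objects, Runner):
    root.info("Changing number of workers from default in input file")
    runner = Runner(objects.builders, args.num_workers, mpi=args.mpi)
elif isinstance(objects, Builder):
    runner = Runner([objects], args.num_workers, mpi=args.mpi)
else:
    root.error("Couldn't properly read the builder file.")
```
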
diff --git a/maggma/examples/builders.py b/maggma/examples/builders.py
index 7d9494ee6..a7e2a99ac 100644
--- a/maggma/examples/builders.py
+++ b/maggma/examples/builders.py
@@ -6,10 +6,10 @@
 from datetime import datetime
 
 from maggma.builder import Builder
-from maggma.utils import confirm_field_index, total_size
+from maggma.utils import confirm_field_index, grouper
 
 
-def source_keys_updated(source, target):
+def source_keys_updated(source, target, query=None):
     """
     Utility for incremental building. Gets a list of source.key values.
 
@@ -22,17 +22,15 @@ def source_keys_updated(source, target):
     """
     keys_updated = set()  # Handle non-unique keys, e.g. for GroupBuilder.
     cursor_source = source.query(
-        properties=[source.key, source.lu_field], sort=[(source.lu_field, -1), (source.key, 1)])
+        criteria=query, properties=[source.key, source.lu_field], sort=[(source.lu_field, -1), (source.key, 1)])
     cursor_target = target.query(
         properties=[target.key, target.lu_field], sort=[(target.lu_field, -1), (target.key, 1)])
     tdoc = next(cursor_target, None)
     for sdoc in cursor_source:
         if tdoc is None:
             keys_updated.add(sdoc[source.key])
-            continue
-
-        if tdoc[target.key] == sdoc[source.key]:
-            if tdoc[target.lu_field] < source.lu_func[0](sdoc[source.lu_field]):
+        elif tdoc[target.key] == sdoc[source.key]:
+            if target.lu_func[0](tdoc[target.lu_field]) < source.lu_func[0](sdoc[source.lu_field]):
                 keys_updated.add(sdoc[source.key])
             tdoc = next(cursor_target, None)
         else:
@@ -40,8 +38,8 @@
     return list(keys_updated)
 
 
-def get_criteria(source, target, query=None, incremental=True, logger=None):
-    """Return criteria to pass to `source.query` to get items."""
+def get_keys(source, target, query=None, incremental=True, logger=None):
+    """Return keys to pass to `source.query` to get items."""
     index_checks = [confirm_field_index(target, target.key)]
     if incremental:
         # Ensure [(lu_field, -1), (key, 1)] index on both source and target
@@ -58,42 +56,9 @@
     if logger:
         logger.warning(index_warning)
 
-    criteria = {}
-    if query:
-        criteria.update(query)
-    if incremental:
-        if logger:
-            logger.info("incremental mode: finding new/updated source keys")
-        keys_updated = source_keys_updated(source, target)
-        # Merge existing criteria and {source.key: {"$in": keys_updated}}.
-        if "$and" in criteria:
-            criteria["$and"].append({source.key: {"$in": keys_updated}})
-        elif source.key in criteria:
-            # XXX could go deeper and check for $in, but this is fine.
-            criteria["$and"] = [
-                {
-                    source.key: criteria[source.key].copy()
-                },
-                {
-                    source.key: {
-                        "$in": keys_updated
-                    }
-                },
-            ]
-            del criteria[source.key]
-        else:
-            criteria.update({source.key: {"$in": keys_updated}})
-    # Check ratio of criteria size to 16 MB MongoDB document size limit.
-    # Overestimates ratio via 1000 * 1000 instead of 1024 * 1024.
-    # If criteria is > 16MB, even cursor.count() will fail with a
-    # "DocumentTooLarge: "command document too large" error.
-    if (total_size(criteria) / (16 * 1000 * 1000)) >= 1:
-        raise RuntimeError("`get_items` query criteria too large. This can happen if "
-                           "trying to run incremental=True for the initial build of a "
-                           "very large source store, or if `query` is too large. You "
-                           "can use maggma.utils.total_size to ensure `query` is smaller "
-                           "than 16,000,000 bytes.")
-    return criteria
+    keys_updated = source_keys_updated(source, target, query)
+
+    return keys_updated
 
 
 class MapBuilder(Builder, metaclass=ABCMeta):
@@ -131,21 +96,31 @@ def __init__(self, source, target, ufn, query=None, incremental=True, projection
         self.query = query
         self.ufn = ufn
         self.projection = projection if projection else []
+        self.kwargs = kwargs
         super().__init__(sources=[source], targets=[target], **kwargs)
-        self.kwargs = kwargs.copy()
-        self.kwargs.update(query=query, incremental=incremental)
 
     def get_items(self):
-        criteria = get_criteria(
-            self.source, self.target, query=self.query, incremental=self.incremental, logger=self.logger)
+
+        self.logger.info("Starting {} Builder".format(self.__class__.__name__))
+
+        keys = get_keys(source=self.source, target=self.target, query=self.query, logger=self.logger)
+
+        self.logger.info("Processing {} items".format(len(keys)))
+
         if self.projection:
             projection = list(set(self.projection + [self.source.key, self.source.lu_field]))
         else:
             projection = None
 
-        return self.source.query(criteria=criteria, properties=projection)
+        self.total = len(keys)
+        for chunked_keys in grouper(keys, self.chunk_size, None):
+            chunked_keys = list(filter(None.__ne__, chunked_keys))
+            for doc in list(self.source.query(criteria={self.source.key: {"$in": chunked_keys}}, properties=projection)):
+                yield doc
 
     def process_item(self, item):
+
+        self.logger.debug("Processing: {}".format(item[self.source.key]))
+
         try:
             processed = self.ufn.__call__(item)
         except Exception as e:
@@ -153,7 +128,7 @@ def process_item(self, item):
             processed = {"error": str(e)}
         key, lu_field = self.source.key, self.source.lu_field
         out = {self.target.key: item[key]}
-        out[self.target.lu_field] = self.source.lu_func[0](item[self.source.lu_field])
+        out[self.target.lu_field] = self.source.lu_func[0](item[self.source.lu_field])
         out.update(processed)
         return out
 
@@ -167,7 +142,9 @@ def update_targets(self, items):
             item["_bt"] = datetime.utcnow()
             if "_id" in item:
                 del item["_id"]
-        target.update(items, update_lu=False)
+
+        if len(items) > 0:
+            target.update(items, update_lu=False)
 
 
 class GroupBuilder(MapBuilder, metaclass=ABCMeta):
@@ -197,7 +174,7 @@ def __init__(self, source, target, query=None, incremental=True, **kwargs):
         self.total = None
 
     def get_items(self):
-        criteria = get_criteria(
+        criteria = get_keys(
             self.source, self.target, query=self.query, incremental=self.incremental, logger=self.logger)
         if all(isinstance(entry, str) for entry in self.grouping_properties()):
             properties = {entry: 1 for entry in self.grouping_properties()}
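
`MapBuilder.get_items` is now a generator that pages through the updated keys in `chunk_size` batches rather than issuing one unbounded query, which also sidesteps the old 16 MB criteria-size concern that `total_size` guarded against. `grouper` (now imported from `maggma.utils`) is the itertools-recipe chunker: it pads the final chunk with a fill value, which is why the `None`s are filtered out before building the `$in` criteria. A small sketch of the pattern, with a hypothetical key list and chunk size:

```python
from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    """Collect data into fixed-length chunks, padding the last one."""
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

keys = ["mp-1", "mp-2", "mp-3", "mp-4"]  # hypothetical source keys
for chunked_keys in grouper(keys, 3, None):
    chunked_keys = list(filter(None.__ne__, chunked_keys))  # drop padding
    criteria = {"task_id": {"$in": chunked_keys}}  # one bounded query per chunk
    print(criteria)
# {'task_id': {'$in': ['mp-1', 'mp-2', 'mp-3']}}
# {'task_id': {'$in': ['mp-4']}}
```
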
You " - "can use maggma.utils.total_size to ensure `query` is smaller " - "than 16,000,000 bytes.") - return criteria + keys_updated = source_keys_updated(source, target, query) + + return keys_updated class MapBuilder(Builder, metaclass=ABCMeta): @@ -131,21 +96,31 @@ def __init__(self, source, target, ufn, query=None, incremental=True, projection self.query = query self.ufn = ufn self.projection = projection if projection else [] + self.kwargs = kwargs super().__init__(sources=[source], targets=[target], **kwargs) - self.kwargs = kwargs.copy() - self.kwargs.update(query=query, incremental=incremental) def get_items(self): - criteria = get_criteria( - self.source, self.target, query=self.query, incremental=self.incremental, logger=self.logger) + + self.logger.info("Starting {} Builder".format(self.__class__.__name__)) + keys = get_keys(source=self.source, target=self.target, query=self.query, logger=self.logger) + + self.logger.info("Processing {} items".format(len(keys))) + if self.projection: projection = list(set(self.projection + [self.source.key, self.source.lu_field])) else: projection = None - return self.source.query(criteria=criteria, properties=projection) + self.total = len(keys) + for chunked_keys in grouper(keys, self.chunk_size, None): + chunked_keys = list(filter(None.__ne__, chunked_keys)) + for doc in list(self.source.query(criteria={self.source.key: {"$in": chunked_keys}}, properties=projection)): + yield doc def process_item(self, item): + + self.logger.debug("Processing: {}".format(item[self.source.key])) + try: processed = self.ufn.__call__(item) except Exception as e: @@ -153,7 +128,7 @@ def process_item(self, item): processed = {"error": str(e)} key, lu_field = self.source.key, self.source.lu_field out = {self.target.key: item[key]} - out[self.target.lu_field] = self.source.lu_func[0](item[self.source.lu_field]) + out[self.target.lu_field] = self.source.lu_func[0](item[self.source.lu_field]) out.update(processed) return out @@ -167,7 +142,9 @@ def update_targets(self, items): item["_bt"] = datetime.utcnow() if "_id" in item: del item["_id"] - target.update(items, update_lu=False) + + if len(items) > 0: + target.update(items, update_lu=False) class GroupBuilder(MapBuilder, metaclass=ABCMeta): @@ -197,7 +174,7 @@ def __init__(self, source, target, query=None, incremental=True, **kwargs): self.total = None def get_items(self): - criteria = get_criteria( + criteria = get_keys( self.source, self.target, query=self.query, incremental=self.incremental, logger=self.logger) if all(isinstance(entry, str) for entry in self.grouping_properties()): properties = {entry: 1 for entry in self.grouping_properties()} diff --git a/maggma/examples/tests/test_copybuilder.py b/maggma/examples/tests/test_copybuilder.py index 3f126b266..6babd8c04 100644 --- a/maggma/examples/tests/test_copybuilder.py +++ b/maggma/examples/tests/test_copybuilder.py @@ -1,6 +1,7 @@ """Test maggma.examples.builders.CopyBuilder.""" import logging +import unittest from datetime import datetime, timedelta from unittest import TestCase from uuid import uuid4 @@ -74,9 +75,8 @@ def test_index_warning(self): """Should log warning when recommended store indexes are not present.""" self.source.collection.drop_index("lu_-1_k_1") with self.assertLogs(level=logging.WARNING) as cm: - self.builder.get_items() + list(self.builder.get_items()) self.assertIn("Ensure indices", "\n".join(cm.output)) - self.source.collection.create_index("lu_-1_k_1") def test_runner(self): self.source.collection.insert_many(self.old_docs) @@ -96,3 
diff --git a/maggma/stores.py b/maggma/stores.py
index ed9bbabf2..3f14a535a 100644
--- a/maggma/stores.py
+++ b/maggma/stores.py
@@ -9,6 +9,7 @@
 from datetime import datetime
 import json
 import zlib
+import logging
 
 import mongomock
 import pymongo
@@ -44,6 +45,8 @@ def __init__(self, key="task_id", lu_field='last_updated', lu_type="datetime", v
         self.lu_type = lu_type
         self.lu_func = LU_KEY_ISOFORMAT if lu_type == "isoformat" else (identity, identity)
         self.validator = validator
+        self.logger = logging.getLogger(type(self).__name__)
+        self.logger.addHandler(logging.NullHandler())
 
     @property
     @abstractmethod
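
Giving the base `Store` a logger named after the concrete class, with a `NullHandler` attached, follows the standard library-logging convention: stores stay silent by default and applications opt in per class. A sketch of opting in from application code, assuming a store class named `MongoStore`:

```python
import logging

# Route all log records to stderr at INFO and above.
logging.basicConfig(level=logging.INFO)

# Turn up a single store class; Store names its logger type(self).__name__.
logging.getLogger("MongoStore").setLevel(logging.DEBUG)
```
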
"writer") @@ -113,11 +103,14 @@ def _create_vault_store(self): instance.read.return_value = { 'wrap_info': None, 'request_id': '2c72c063-2452-d1cd-19a2-91163c7395f7', - 'data': {'value': '{"db": "mg_core_prod", "host": "matgen2.lbl.gov", "username": "test", "password": "pass"}'}, + 'data': { + 'value': '{"db": "mg_core_prod", "host": "matgen2.lbl.gov", "username": "test", "password": "pass"}' + }, 'auth': None, 'warnings': None, 'renewable': False, - 'lease_duration': 2764800, 'lease_id': '' + 'lease_duration': 2764800, + 'lease_id': '' } v = VaultStore("test_coll", "secret/matgen/maggma") @@ -160,7 +153,6 @@ def test_vault_missing_env(self): class TestS3Store(unittest.TestCase): - def setUp(self): self.index = MemoryStore("index'") with patch("boto3.resource") as mock_resource: @@ -202,31 +194,35 @@ def test_update_compression(self): class TestAliasingStore(unittest.TestCase): - def setUp(self): self.memorystore = MemoryStore("test") self.memorystore.connect() - self.aliasingstore = AliasingStore( - self.memorystore, {"a": "b", "c.d": "e", "f": "g.h"}) + self.aliasingstore = AliasingStore(self.memorystore, {"a": "b", "c.d": "e", "f": "g.h"}) def test_query(self): d = [{"b": 1}, {"e": 2}, {"g": {"h": 3}}] self.memorystore.collection.insert_many(d) - self.assertTrue("a" in list(self.aliasingstore.query( - criteria={"a": {"$exists": 1}}))[0]) - self.assertTrue("c" in list(self.aliasingstore.query( - criteria={"c.d": {"$exists": 1}}))[0]) - self.assertTrue("d" in list(self.aliasingstore.query( - criteria={"c.d": {"$exists": 1}}))[0].get("c", {})) - self.assertTrue("f" in list(self.aliasingstore.query( - criteria={"f": {"$exists": 1}}))[0]) + self.assertTrue("a" in list(self.aliasingstore.query(criteria={"a": {"$exists": 1}}))[0]) + self.assertTrue("c" in list(self.aliasingstore.query(criteria={"c.d": {"$exists": 1}}))[0]) + self.assertTrue("d" in list(self.aliasingstore.query(criteria={"c.d": {"$exists": 1}}))[0].get("c", {})) + self.assertTrue("f" in list(self.aliasingstore.query(criteria={"f": {"$exists": 1}}))[0]) def test_update(self): - self.aliasingstore.update([{"task_id": "mp-3", "a": 4}, {"task_id": "mp-4", - "c": {"d": 5}}, {"task_id": "mp-5", "f": 6}]) + self.aliasingstore.update([{ + "task_id": "mp-3", + "a": 4 + }, { + "task_id": "mp-4", + "c": { + "d": 5 + } + }, { + "task_id": "mp-5", + "f": 6 + }]) self.assertEqual(list(self.aliasingstore.query(criteria={"task_id": "mp-3"}))[0]["a"], 4) self.assertEqual(list(self.aliasingstore.query(criteria={"task_id": "mp-4"}))[0]["c"]["d"], 5) self.assertEqual(list(self.aliasingstore.query(criteria={"task_id": "mp-5"}))[0]["f"], 6) @@ -257,7 +253,6 @@ def test_substitute(self): class TestSandboxStore(unittest.TestCase): - def setUp(self): self.store = MemoryStore() self.sandboxstore = SandboxStore(self.store, sandbox="test") @@ -265,7 +260,7 @@ def setUp(self): def test_connect(self): with self.assertRaises(Exception): self.sandboxstore.collection - + self.sandboxstore.connect() self.assertIsInstance(self.sandboxstore.collection, mongomock.collection.Collection) @@ -275,12 +270,10 @@ def test_query(self): self.assertEqual(self.sandboxstore.query_one(properties=["a"])['a'], 1) self.sandboxstore.collection.insert_one({"a": 2, "b": 2, "sbxn": ["test"]}) - self.assertEqual(self.sandboxstore.query_one(properties=["b"], - criteria={"a": 2})['b'], 2) + self.assertEqual(self.sandboxstore.query_one(properties=["b"], criteria={"a": 2})['b'], 2) self.sandboxstore.collection.insert_one({"a": 3, "b": 2, "sbxn": ["not_test"]}) - 
diff --git a/maggma/tests/test_utils.py b/maggma/tests/test_utils.py
index 7c63726b4..5d74fd41a 100644
--- a/maggma/tests/test_utils.py
+++ b/maggma/tests/test_utils.py
@@ -3,31 +3,11 @@
 Tests utillities
 """
 import unittest
-from maggma.utils import get_mongolike, make_mongolike, put_mongolike, recursive_update
+from maggma.utils import recursive_update
 
 
 class UtilsTests(unittest.TestCase):
-    def test_get_mongolike(self):
-        d = {"a": [{"b": 1}, {"c": {"d": 2}}], "e": {"f": {"g": 3}}, "g": 4, "h": [5, 6]}
-
-        self.assertEqual(get_mongolike(d, "g"), 4)
-        self.assertEqual(get_mongolike(d, "e.f.g"), 3)
-        self.assertEqual(get_mongolike(d, "a.0.b"), 1)
-        self.assertEqual(get_mongolike(d, "a.1.c.d"), 2)
-        self.assertEqual(get_mongolike(d, "h.-1"), 6)
-
-    def test_put_mongolike(self):
-        self.assertEqual(put_mongolike("e", 1), {"e": 1})
-        self.assertEqual(put_mongolike("e.f.g", 1), {"e": {"f": {"g": 1}}})
-
-    def test_make_mongolike(self):
-        d = {"a": [{"b": 1}, {"c": {"d": 2}}], "e": {"f": {"g": 3}}, "g": 4, "h": [5, 6]}
-
-        self.assertEqual(make_mongolike(d, "e.f.g", "a"), {"a": 3})
-        self.assertEqual(make_mongolike(d, "e.f.g", "a.b"), {"a": {"b": 3}})
-        self.assertEqual(make_mongolike(d, "a.0.b", "e.f"), {"e": {"f": 1}})
-
     def test_recursiveupdate(self):
         d = {"a": {"b": 3}, "c": [4]}
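
The removed `get_mongolike`/`put_mongolike`/`make_mongolike` helpers (and their tests above) duplicate functionality in pydash, which maggma already uses (`set_` appears in `JointStore.groupby`). A sketch of the equivalents, assuming pydash is the intended replacement:

```python
from pydash import get, set_

d = {"a": [{"b": 1}, {"c": {"d": 2}}], "e": {"f": {"g": 3}}}
assert get(d, "e.f.g") == 3                            # get_mongolike(d, "e.f.g")
assert get(d, "a.0.b") == 1                            # dotted paths index lists too
assert set_({}, "e.f.g", 1) == {"e": {"f": {"g": 1}}}  # put_mongolike("e.f.g", 1)
```
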
"a.1.c.d"), 2) - self.assertEqual(get_mongolike(d, "h.-1"), 6) - - def test_put_mongolike(self): - self.assertEqual(put_mongolike("e", 1), {"e": 1}) - self.assertEqual(put_mongolike("e.f.g", 1), {"e": {"f": {"g": 1}}}) - - def test_make_mongolike(self): - d = {"a": [{"b": 1}, {"c": {"d": 2}}], "e": {"f": {"g": 3}}, "g": 4, "h": [5, 6]} - - self.assertEqual(make_mongolike(d, "e.f.g", "a"), {"a": 3}) - self.assertEqual(make_mongolike(d, "e.f.g", "a.b"), {"a": {"b": 3}}) - self.assertEqual(make_mongolike(d, "a.0.b", "e.f"), {"e": {"f": 1}}) - def test_recursiveupdate(self): d = {"a": {"b": 3}, "c": [4]} diff --git a/maggma/utils.py b/maggma/utils.py index 3c86766d8..5db786c2b 100644 --- a/maggma/utils.py +++ b/maggma/utils.py @@ -96,57 +96,6 @@ def isostr_to_dt(s): LU_KEY_ISOFORMAT = (isostr_to_dt, dt_to_isoformat_ceil_ms) -def get_mongolike(d, key): - """ - Grab a dict value using dot-notation like "a.b.c" from dict {"a":{"b":{"c": 3}}} - Args: - d (dict): the dictionary to search - key (str): the key we want to grab with dot notation, e.g., "a.b.c" - - Returns: - value from desired dict (whatever is stored at the desired key) - - """ - lead_key = key.split(".", 1)[0] - try: - lead_key = int(lead_key) # for searching array data - except: - pass - - if "." in key: - remainder = key.split(".", 1)[1] - return get_mongolike(d[lead_key], remainder) - return d[lead_key] - - -def put_mongolike(key, value): - """ - Builds a dictionary with a value using mongo dot-notation - - Args: - key (str): the key to put into using mongo notation, doesn't support arrays - value: object - """ - lead_key = key.split(".", 1)[0] - - if "." in key: - remainder = key.split(".", 1)[1] - return {lead_key: put_mongolike(remainder, value)} - return {lead_key: value} - - -def make_mongolike(d, get_key, put_key): - """ - Builds a dictionary with a value from another dictionary using mongo dot-notation - - Args: - d (dict)L the dictionary to search - get_key (str): the key to grab using mongo notation - put_key (str): the key to put into using mongo notation, doesn't support arrays - """ - return put_mongolike(put_key, get_mongolike(d, get_key)) - - def recursive_update(d, u): """ Recursive updates d with values from u