Merge pull request #129 from materialsproject/fix_async
Fix async
Shyam Dwaraknath authored Apr 6, 2020
2 parents 7e2d398 + aee3e54 commit 9a8f5cc
Showing 8 changed files with 205 additions and 93 deletions.
2 changes: 1 addition & 1 deletion src/maggma/api/APIManager.py
@@ -54,7 +54,7 @@ def load(self, endpoint, prefix: str = "/"):
class_name = endpoint.split(".")[-1]
new_endpoint = dynamic_import(module_path, class_name)
self.__setitem__(prefix, new_endpoint)
pass

elif isclass(endpoint) and issubclass(endpoint, Resource):
self.__setitem__(prefix, endpoint)
else:
2 changes: 1 addition & 1 deletion src/maggma/api/query_operator.py
@@ -1,4 +1,4 @@
from typing import List, Dict, Optional, Any, Type, Tuple, Mapping
from typing import List, Dict, Optional, Any, Mapping
from pydantic import BaseModel
from fastapi import Query
from monty.json import MSONable
1 change: 0 additions & 1 deletion src/maggma/api/resource.py
@@ -16,7 +16,6 @@
)
from fastapi import FastAPI, APIRouter, Path, HTTPException, Depends
from maggma.api.models import Response, Meta
from starlette.responses import RedirectResponse


class Resource(MSONable):
192 changes: 136 additions & 56 deletions src/maggma/cli/multiprocessing.py
@@ -3,8 +3,19 @@

from logging import getLogger
from types import GeneratorType
from asyncio import BoundedSemaphore, get_running_loop, gather
from aioitertools import zip_longest
from asyncio import (
BoundedSemaphore,
get_running_loop,
Queue,
create_task,
Condition,
wait,
FIRST_COMPLETED,
Event,
gather,
)

from aioitertools import enumerate
from concurrent.futures import ProcessPoolExecutor
from maggma.utils import primed
from tqdm import tqdm
@@ -13,75 +13,133 @@
logger = getLogger("MultiProcessor")


class ProcessItemsSemaphore(BoundedSemaphore):
class BackPressure:
"""
Modified BoundedSemaphore to update a TQDM bar
for process_items
Wrapper for an iterator to provide
async access with backpressure
"""

def __init__(self, total=None, *args, **kwargs):
self.tqdm = tqdm(total=total, desc="Process Items")
super().__init__(*args, **kwargs)
def __init__(self, iterator, n):
self.iterator = iter(iterator)
self.back_pressure = BoundedSemaphore(n)

def release(self):
self.tqdm.update(1)
super().release()
def __aiter__(self):
return self

async def __anext__(self):
await self.back_pressure.acquire()

def safe_dispatch(val):
func, item = val
try:
return func(item)
except Exception as e:
logger.error(e)
return None
try:
return next(self.iterator)
except StopIteration:
raise StopAsyncIteration

async def release(self, async_iterator):
"""
pass items through, releasing back-pressure slots as they go
"""
async for item in async_iterator:
try:
self.back_pressure.release()
except ValueError:
pass

yield item
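
A minimal usage sketch of the new wrapper (in the real pipeline a processing stage sits between the source and release(), as in multi() below; here the stream is fed straight back into its own release() just to show the mechanics):

import asyncio

async def demo():
    # allow at most 2 un-released items in flight
    stream = BackPressure(iterator=range(5), n=2)
    async for item in stream.release(stream):
        print(item)

asyncio.run(demo())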

class AsyncBackPressuredMap:

class AsyncUnorderedMap:
"""
Wrapper for an iterator to provide
async access with backpressure
Async iterator that maps a function over an async iterator
using an executor and returns items as they are done.
This does not guarantee order
"""

def __init__(self, iterator, func, max_run, executor, total=None):
self.iterator = iter(iterator)
def __init__(self, func, async_iterator, executor):
self.iterator = async_iterator
self.func = func
self.executor = executor
self.back_pressure = ProcessItemsSemaphore(value=max_run, total=total)
self.fill_task = create_task(self.get_from_iterator())

self.done_sentinel = object()
self.results = Queue()
self.tasks = {}

async def process_and_release(self, idx):
future = self.tasks[idx]
try:
item = await future
self.results.put_nowait(item)
except Exception:
pass
finally:
self.tasks.pop(idx)

async def get_from_iterator(self):
loop = get_running_loop()
async for idx, item in enumerate(self.iterator):
future = loop.run_in_executor(
self.executor, safe_dispatch, (self.func, item)
)

self.tasks[idx] = future

loop.create_task(self.process_and_release(idx))

await gather(*self.tasks.values())
self.results.put_nowait(self.done_sentinel)

def __aiter__(self):
return self

async def __anext__(self):
await self.back_pressure.acquire()
loop = get_running_loop()
item = await self.results.get()

try:
item = next(self.iterator)
except StopIteration:
if item == self.done_sentinel:
raise StopAsyncIteration
else:
return item

future = loop.run_in_executor(self.executor, safe_dispatch, (self.func, item))

async def process_and_release():
await future
self.back_pressure.release()
return future
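
A sketch of how the new class can be driven on its own (square and ints are hypothetical helpers; with a ProcessPoolExecutor the mapped function must be picklable, hence the module-level definition and the __main__ guard):

import asyncio
from concurrent.futures import ProcessPoolExecutor

def square(x):
    return x * x

async def ints():
    for i in range(5):
        yield i

async def demo():
    with ProcessPoolExecutor() as executor:
        # results arrive as workers finish, not in submission order
        async for result in AsyncUnorderedMap(square, ints(), executor):
            print(result)

if __name__ == "__main__":
    asyncio.run(demo())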
async def atqdm(async_iterator, *args, **kwargs):
"""
Wrapper around tqdm for async generators
"""
_tqdm = tqdm(*args, **kwargs)
async for item in async_iterator:
_tqdm.update()
yield item

return process_and_release()
_tqdm.close()
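
For illustration, wrapping any async generator gives a live progress bar (numbers() is a stand-in source):

import asyncio

async def numbers():
    for i in range(100):
        yield i

async def demo():
    # the bar advances once per yielded item
    async for _ in atqdm(numbers(), total=100, desc="Counting"):
        pass

asyncio.run(demo())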


async def grouper(iterable, n, fillvalue=None):
async def grouper(async_iterator, n: int):
"""
Collect data from an async iterator into fixed-length chunks or blocks:
grouper('ABCDEFG', 3) --> [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
Updated from:
https://stackoverflow.com/questions/31164731/python-chunking-csv-file-multiproccessing/31170795#31170795
Modified for async
"""
# grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
args = [iterable] * n
iterator = zip_longest(*args, fillvalue=fillvalue)
chunk = []
async for item in async_iterator:
chunk.append(item)
if len(chunk) >= n:
yield chunk
chunk.clear()
if chunk != []:
yield chunk

async for group in iterator:
group = [g for g in group if g is not None]
yield group
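
A quick sketch of the async chunking behaviour (letters() is a hypothetical source; note that the new implementation reuses the same list object between yields, so copy the chunk if it needs to outlive the iteration):

import asyncio

async def letters():
    for ch in "ABCDEFG":
        yield ch

async def demo():
    async for chunk in grouper(letters(), n=3):
        # prints ['A', 'B', 'C'], then ['D', 'E', 'F'], then ['G']
        print(list(chunk))

asyncio.run(demo())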

def safe_dispatch(val):
func, item = val
try:
return func(item)
except Exception as e:
logger.error(e)
return None
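
safe_dispatch is what keeps one bad record from killing the pipeline: the exception is logged and surfaces as None, which the chunk filter in multi() below drops. A tiny sketch (boom is hypothetical):

def boom(x):
    raise ValueError(x)

safe_dispatch((boom, 1))  # logs the ValueError, returns None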


async def multi(builder, num_workers):
@@ -107,15 +176,6 @@ async def multi(builder, num_workers):
elif hasattr(cursor, "count"):
total = cursor.count()

mapper = AsyncBackPressuredMap(
iterator=tqdm(cursor, desc="Get", total=total),
func=builder.process_item,
max_run=builder.chunk_size,
executor=executor,
total=total,
)
update_items = tqdm(total=total, desc="Update Targets")

logger.info(
f"Starting multiprocessing: {builder.__class__.__name__}",
extra={
@@ -128,9 +188,29 @@ async def multi(builder, num_workers):
}
},
)
async for chunk in grouper(mapper, builder.chunk_size, fillvalue=None):

back_pressured_get = BackPressure(
iterator=tqdm(cursor, desc="Get", total=total), n=builder.chunk_size
)

processed_items = atqdm(
async_iterator=AsyncUnorderedMap(
func=builder.process_item,
async_iterator=back_pressured_get,
executor=executor,
),
total=total,
desc="Process Items",
)

back_pressure_relief = back_pressured_get.release(processed_items)

update_items = tqdm(total=total, desc="Update Targets")

async for chunk in grouper(back_pressure_relief, n=builder.chunk_size):

logger.info(
"Processing batch of {} items".format(builder.chunk_size),
"Processed batch of {} items".format(builder.chunk_size),
extra={
"maggma": {
"event": "UPDATE",
@@ -141,9 +221,7 @@ async def multi(builder, num_workers):
}
},
)
chunk = await gather(*chunk)
processed_chunk = [c.result() for c in chunk if c is not None]
processed_items = [item for item in processed_chunk if item is not None]
processed_items = [item for item in chunk if item is not None]
builder.update_targets(processed_items)
update_items.update(len(processed_items))

@@ -158,4 +236,6 @@ async def multi(builder, num_workers):
}
},
)

update_items.close()
builder.finalize()
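
Read end to end, the rewritten multi() is now a linear pipeline. A compressed sketch of the data flow (all names as in the diff above; not standalone code, since cursor, builder, executor, and total come from the surrounding function):

# bounded number of raw items pulled from the source
gets = BackPressure(iterator=cursor, n=builder.chunk_size)
# fan work out to the executor; completions arrive unordered
processed = AsyncUnorderedMap(builder.process_item, gets, executor)
# progress bar, then free one back-pressure slot per finished item
relieved = gets.release(atqdm(processed, total=total, desc="Process Items"))
# batch finished items and hand them to the builder
async for chunk in grouper(relieved, n=builder.chunk_size):
    builder.update_targets([i for i in chunk if i is not None])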
3 changes: 0 additions & 3 deletions src/maggma/core/store.py
@@ -337,9 +337,6 @@ def updated_keys(self, target, criteria=None):
def __ne__(self, other):
return not self == other

def __hash__(self):
return hash((self.last_updated_field,))

def __getstate__(self):
return self.as_dict()

1 change: 0 additions & 1 deletion src/maggma/stores/gridfs.py
@@ -18,7 +18,6 @@
from pymongo import MongoClient
from monty.json import jsanitize
from monty.dev import deprecated
from maggma.utils import confirm_field_index
from maggma.core import Store, Sort
from maggma.stores import MongoStore

18 changes: 13 additions & 5 deletions src/maggma/stores/mongolike.py
@@ -70,7 +70,9 @@ def name(self) -> str:
"""
return f"mongo://{self.host}/{self.database}/{self.collection_name}"

def connect(self, force_reset: bool = False, ssh_tunnel: SSHTunnelForwarder = None):
def connect(
self, force_reset: bool = False, ssh_tunnel: SSHTunnelForwarder = None
): # lgtm[py/conflicting-attributes]
"""
Connect to the source data
"""
@@ -341,7 +343,7 @@ def __init__(self, uri: str, database: str, collection_name: str, **kwargs):
self.collection_name = collection_name
self.kwargs = kwargs
self._collection = None
super(MongoStore, self).__init__(**kwargs)
super(MongoStore, self).__init__(**kwargs) # lgtm

@property
def name(self) -> str:
@@ -351,7 +353,9 @@ def name(self) -> str:
# TODO: This is not very safe since it exposes the username/password info
return self.uri

def connect(self, force_reset: bool = False, ssh_tunnel: SSHTunnelForwarder = None):
def connect(
self, force_reset: bool = False, ssh_tunnel: SSHTunnelForwarder = None
): # lgtm[py/conflicting-attributes]
"""
Connect to the source data
"""
@@ -380,7 +384,9 @@ def __init__(self, collection_name: str = "memory_db", **kwargs):
self.kwargs = kwargs
super(MongoStore, self).__init__(**kwargs) # noqa

def connect(self, force_reset: bool = False, ssh_tunnel: SSHTunnelForwarder = None):
def connect(
self, force_reset: bool = False, ssh_tunnel: SSHTunnelForwarder = None
): # lgtm[py/conflicting-attributes]
"""
Connect to the source data
"""
@@ -465,7 +471,9 @@ def __init__(self, paths: Union[str, List[str]], **kwargs):
self.kwargs = kwargs
super().__init__(collection_name="collection", **kwargs)

def connect(self, force_reset=False, ssh_tunnel=None):
def connect(
self, force_reset=False, ssh_tunnel=None
): # lgtm[py/conflicting-attributes]
"""
Loads the files into the collection in memory
"""