Skip to content

Commit

Permalink
Split torsion drive single-points into separate tasks (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored May 20, 2022
1 parent 5c60118 commit b7999db
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 232 deletions.
131 changes: 112 additions & 19 deletions openff/bespokefit/executor/services/qcgenerator/cache.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
import hashlib
from typing import TypeVar, Union
from typing import TYPE_CHECKING, Optional, TypeVar, Union

import redis
from openff.toolkit.topology import Molecule

from openff.bespokefit.executor.services.qcgenerator import worker
from openff.bespokefit.schema.tasks import HessianTask, OptimizationTask, Torsion1DTask
from openff.bespokefit.schema.tasks import (
HessianTask,
OptimizationTask,
QCGenerationTask,
Torsion1DTask,
)
from openff.bespokefit.utilities.molecule import canonical_order_atoms

if TYPE_CHECKING:
# Only use as a type hint. Use `celery_app.AsyncResult` to initialize
from celery.result import AsyncResult

_T = TypeVar("_T", HessianTask, OptimizationTask, Torsion1DTask)


Expand Down Expand Up @@ -50,6 +59,95 @@ def _canonicalize_task(task: _T) -> _T:
return task


def _hash_task(task: QCGenerationTask) -> str:
"""Returns a hashed representation of a QC task"""
return hashlib.sha512(task.json().encode()).hexdigest()


def _retrieve_cached_task_id(
task_hash: str, redis_connection: redis.Redis
) -> Optional[str]:
"""Retrieve the task ID of a cached QC task if present in the redis cache"""

task_id = redis_connection.hget("qcgenerator:task-ids", task_hash)

return None if task_id is None else task_id.decode()


def _cache_task_id(
task_id: str, task_type: str, task_hash: str, redis_connection: redis.Redis
):
"""Store the ID of a running QC task in the QC task cache."""

redis_connection.hset("qcgenerator:types", task_id, task_type)
# Make sure to only set the hash after the type is set in case the connection
# goes down before this information is entered and subsequently discarded.
redis_connection.hset("qcgenerator:task-ids", task_hash, task_id)


def _compute_torsion_drive_task(
    task: Torsion1DTask, redis_connection: redis.Redis
) -> str:
    """Submit a torsion drive to celery, optionally chaining together a torsion
    drive followed by a single point energy re-evaluation.

    Returns:
        The celery task ID whose result is the final output: the torsion drive
        itself when ``task.sp_specification`` is ``None``, otherwise the
        trailing single-point evaluation task at the end of the chain.
    """

    # ID of the full (drive + single point) chain, set only when a new chain
    # is submitted below.
    task_id = None

    # Cache the drive *without* its single-point specification so that tasks
    # differing only in their final single-point settings share one drive.
    torsion_drive_task = task.copy(deep=True)
    torsion_drive_task.sp_specification = None

    torsion_drive_hash = _hash_task(torsion_drive_task)
    torsion_drive_id = _retrieve_cached_task_id(torsion_drive_hash, redis_connection)

    if torsion_drive_id is None:

        # There are no cached torsion drives at the 'pre-optimise' level of theory
        # we need to run a torsion drive and then optionally a single point
        if task.sp_specification is None:

            torsion_drive_id = worker.compute_torsion_drive.delay(
                task_json=task.json()
            ).id

        else:

            # celery's ``|`` builds a chain (drive -> single point) and
            # ``.delay()`` submits the whole chain at once.
            task_future: AsyncResult = (
                worker.compute_torsion_drive.s(task_json=task.json())
                | worker.evaluate_torsion_drive.s(
                    model_json=task.sp_specification.model.json(),
                    program=task.sp_specification.program,
                )
            ).delay()

            # ``parent`` is the drive step of the chain; the chain's own ID
            # resolves to the final single-point result.
            torsion_drive_id = task_future.parent.id
            task_id = task_future.id

        # Only the bare drive is cached; single-point chains are keyed off it.
        _cache_task_id(
            torsion_drive_id, task.type, torsion_drive_hash, redis_connection
        )

    if task.sp_specification is None:
        return torsion_drive_id

    if task_id is None:

        # Handle the case where we have a running torsion drive that we need to
        # append a single point calculation to the end of.
        task_id = (
            (
                worker.wait_for_task.s(torsion_drive_id)
                | worker.evaluate_torsion_drive.s(
                    model_json=task.sp_specification.model.json(),
                    program=task.sp_specification.program,
                )
            )
            .delay()
            .id
        )

    return task_id


def cached_compute_task(
    task: Union[HessianTask, OptimizationTask, Torsion1DTask],
    redis_connection: redis.Redis,
) -> str:
    """Check whether an identical QC task has already been submitted and, if
    not, dispatch it to the redis queue to be computed by a celery worker.

    Args:
        task: The QC generation task to compute.
        redis_connection: The redis connection backing the task cache.

    Returns:
        The ID of the (possibly already running / completed) celery task.

    Raises:
        NotImplementedError: If ``task`` is not one of the supported task types.
    """

    # Canonicalize the task to improve the cache hit rate.
    task = _canonicalize_task(task)

    task_hash = _hash_task(task)
    task_id = _retrieve_cached_task_id(task_hash, redis_connection)

    if task_id is not None:
        # Cache hit - an identical task was already submitted.
        return task_id

    if isinstance(task, Torsion1DTask):
        # Torsion drives may expand into a chain of celery tasks (drive +
        # optional single point) and handle their own sub-task caching.
        task_id = _compute_torsion_drive_task(task, redis_connection)
    elif isinstance(task, OptimizationTask):
        task_id = worker.compute_optimization.delay(task_json=task.json()).id
    elif isinstance(task, HessianTask):
        task_id = worker.compute_hessian.delay(task_json=task.json()).id
    else:
        raise NotImplementedError()

    _cache_task_id(task_id, task.type, task_hash, redis_connection)
    return task_id
Loading

0 comments on commit b7999db

Please sign in to comment.