-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #43 from dgasmith/procedures
Procedures Base Models
- Loading branch information
Showing
10 changed files
with
284 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .base import get_procedure, list_all_procedures, list_available_procedures, register_procedure |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
Imports the various procedure backends | ||
""" | ||
|
||
from typing import List, Set | ||
|
||
from .geometric import GeometricProcedure | ||
|
||
__all__ = ["register_procedure", "get_procedure", "list_all_procedures", "list_available_procedures"] | ||
|
||
procedures = {} | ||
|
||
|
||
def register_procedure(entry_point: 'BaseProcedure') -> None: | ||
""" | ||
Register a new BaseProcedure with QCEngine | ||
""" | ||
|
||
name = entry_point.name | ||
if name.lower() in procedures.keys(): | ||
raise ValueError('{} is already a registered procedure.'.format(name)) | ||
|
||
procedures[name.lower()] = entry_point | ||
|
||
|
||
def get_procedure(name: str) -> 'BaseProcedure': | ||
""" | ||
Returns a procedures executor class | ||
""" | ||
return procedures[name.lower()] | ||
|
||
|
||
def list_all_procedures() -> Set[str]: | ||
""" | ||
List all procedures registered by QCEngine. | ||
""" | ||
return set(procedures.keys()) | ||
|
||
|
||
def list_available_procedures() -> Set[str]: | ||
""" | ||
List all procedures that can be exectued (found) by QCEngine. | ||
""" | ||
|
||
ret = set() | ||
for k, p in procedures.items(): | ||
if p.found(): | ||
ret.add(k) | ||
|
||
return ret | ||
|
||
|
||
register_procedure(GeometricProcedure()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from typing import Any, Dict, Union | ||
|
||
from qcelemental.models import ComputeError, FailedOperation, Optimization, OptimizationInput | ||
|
||
from .procedure_model import BaseProcedure | ||
|
||
|
||
class GeometricProcedure(BaseProcedure): | ||
|
||
_defaults = {"name": "geomeTRIC", "procedure": "optimization"} | ||
|
||
class Config(BaseProcedure.Config): | ||
pass | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**{**self._defaults, **kwargs}) | ||
|
||
def build_input_model(self, data: Union[Dict[str, Any], 'OptimizationInput']) -> 'OptimizationInput': | ||
return self._build_model(data, OptimizationInput) | ||
|
||
def compute(self, input_data: 'OptimizationInput', config: 'JobConfig') -> 'Optimization': | ||
try: | ||
import geometric | ||
except ModuleNotFoundError: | ||
raise ModuleNotFoundError("Could not find geomeTRIC in the Python path.") | ||
|
||
geometric_input = input_data.dict() | ||
|
||
# Older QCElemental compat, can be removed in v0.6 | ||
if "extras" not in geometric_input["input_specification"]: | ||
geometric_input["input_specification"]["extras"] = {} | ||
|
||
geometric_input["input_specification"]["extras"]["_qcengine_local_config"] = config.dict() | ||
|
||
# Run the program | ||
output_data = geometric.run_json.geometric_run_json(geometric_input) | ||
|
||
output_data["provenance"] = { | ||
"creator": "geomeTRIC", | ||
"routine": "geometric.run_json.geometric_run_json", | ||
"version": geometric.__version__ | ||
} | ||
|
||
output_data["schema_name"] = "qcschema_optimization_output" | ||
output_data["input_specification"]["extras"].pop("_qcengine_local_config", None) | ||
if output_data["success"]: | ||
output_data = Optimization(**output_data) | ||
|
||
return output_data | ||
|
||
def found(self) -> bool: | ||
try: | ||
import geometric | ||
return True | ||
except ModuleNotFoundError: | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import abc | ||
from typing import Any, Dict, Union | ||
|
||
from pydantic import BaseModel | ||
|
||
from ..util import model_wrapper | ||
|
||
|
||
class BaseProcedure(BaseModel, abc.ABC): | ||
|
||
name: str | ||
procedure: str | ||
|
||
class Config: | ||
allow_mutation: False | ||
extra: "forbid" | ||
|
||
@abc.abstractmethod | ||
def build_input_model(self, data: Union[Dict[str, Any], 'BaseModel'], raise_error: bool=True) -> 'BaseModel': | ||
""" | ||
Build and validate the input model, passes if the data was a normal BaseModel input. | ||
Parameters | ||
---------- | ||
data : Union[Dict[str, Any], 'BaseModel'] | ||
A data blob to construct the model from or the input model itself | ||
raise_error : bool, optional | ||
Raise an error or not if the operation failed. | ||
Returns | ||
------- | ||
BaseModel | ||
The input model for the procedure. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def compute(self, input_data: 'BaseModel', config: 'JobConfig') -> 'BaseModel': | ||
pass | ||
|
||
@abc.abstractmethod | ||
def found(self) -> bool: | ||
""" | ||
Checks if the program can be found. | ||
Returns | ||
------- | ||
bool | ||
If the proceudre was found or not. | ||
""" | ||
pass | ||
|
||
def _build_model(self, data: Dict[str, Any], model: 'BaseModel') -> 'BaseModel': | ||
""" | ||
Quick wrapper around util.model_wrapper for inherited classes | ||
""" | ||
|
||
return model_wrapper(data, model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.