From bb61adef9d86e6f377bce84f8b8b884cb8a82d4e Mon Sep 17 00:00:00 2001 From: Brady Johnston Date: Sat, 25 May 2024 09:24:07 +0200 Subject: [PATCH] begin strong typing with mypy --- .github/workflows/mypy.yml | 32 +++ build.py | 2 +- docs/install.py | 2 +- molecularnodes/blender/coll.py | 15 +- molecularnodes/blender/nodes.py | 352 ++++++++++++++++++---------- molecularnodes/blender/obj.py | 101 +++++--- molecularnodes/color.py | 40 ++-- molecularnodes/io/parse/__init__.py | 11 +- molecularnodes/io/parse/cellpack.py | 12 +- molecularnodes/io/parse/molecule.py | 262 +++++++++++++-------- molecularnodes/io/parse/pdb.py | 37 ++- molecularnodes/io/parse/pdbx.py | 117 ++++----- molecularnodes/io/retrieve.py | 50 ++-- molecularnodes/io/wwpdb.py | 43 ++-- molecularnodes/props.py | 26 +- molecularnodes/utils.py | 10 +- pyproject.toml | 42 +++- tests/python.py | 2 +- tests/run.py | 6 +- tests/test_nodes.py | 67 ++++-- tests/test_select.py | 9 +- 21 files changed, 785 insertions(+), 453 deletions(-) create mode 100644 .github/workflows/mypy.yml diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 00000000..dd774122 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,32 @@ +name: mypy +on: + push: [push, pull_request] +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: chartboost/ruff-action@v1 + + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.11.7 + cache: pip + + - name: Build docs using blender + run: | + wget -nv https://download.blender.org/release/Blender4.1/blender-4.1.0-linux-x64.tar.xz + tar -xf blender-4.1.0-linux-x64.tar.xz + + blender-4.1.0-linux-x64/blender --version + blender-4.1.0-linux-x64/blender -b --python tests/python.py -- -m pip install poetry + blender-4.1.0-linux-x64/blender -b --python tests/python.py -- -m poetry install --with dev + blender-4.1.0-linux-x64/blender -b --python tests/python.py -- -m mypy . + + diff --git a/build.py b/build.py index cc4773d0..82005f29 100644 --- a/build.py +++ b/build.py @@ -3,7 +3,7 @@ # zips up the template file -def zip_template(): +def zip_template() -> None: # Define the directory and zip file paths dir_path = "molecularnodes/assets/template/Molecular Nodes" zip_file_path = "molecularnodes/assets/template/Molecular Nodes.zip" diff --git a/docs/install.py b/docs/install.py index 3bf5bd38..70b309c3 100644 --- a/docs/install.py +++ b/docs/install.py @@ -3,7 +3,7 @@ import os -def main(): +def main() -> None: python = os.path.realpath(sys.executable) commands = [f"{python} -m pip install .", f"{python} -m pip install quartodoc"] diff --git a/molecularnodes/blender/coll.py b/molecularnodes/blender/coll.py index 950de7d6..1bd983d0 100644 --- a/molecularnodes/blender/coll.py +++ b/molecularnodes/blender/coll.py @@ -1,7 +1,8 @@ import bpy +from typing import Optional -def mn(): +def mn() -> bpy.types.Collection: """Return the MolecularNodes Collection The collection called 'MolecularNodes' inside the Blender scene is returned. If the @@ -14,7 +15,7 @@ def mn(): return coll -def armature(name="MN_armature"): +def armature(name: str = "MN_armature") -> bpy.types.Collection: coll = bpy.data.collections.get(name) if not coll: coll = bpy.data.collections.new(name) @@ -22,7 +23,7 @@ def armature(name="MN_armature"): return coll -def data(suffix=""): +def data(suffix: str = "") -> bpy.types.Collection: """A collection for storing MN related data objects.""" name = f"MN_data{suffix}" @@ -32,13 +33,11 @@ def data(suffix=""): mn().children.link(collection) # disable the view of the data collection - bpy.context.view_layer.layer_collection.children["MolecularNodes"].children[ - name - ].exclude = True + bpy.context.view_layer.layer_collection.children["MolecularNodes"].children[name].exclude = True return collection -def frames(name="", parent=None, suffix="_frames"): +def frames(name: str = "", parent: Optional[bpy.types.Object] = None, suffix: str = "_frames") -> bpy.types.Collection: """Create a Collection for Frames of a Trajectory Args: @@ -55,7 +54,7 @@ def frames(name="", parent=None, suffix="_frames"): return coll_frames -def cellpack(name="", parent=None, fallback=False): +def cellpack(name: str = "", parent: Optional[bpy.types.Object] = None, fallback: bool = False) -> bpy.types.Collection: full_name = f"cellpack_{name}" coll = bpy.data.collections.get(full_name) if coll and fallback: diff --git a/molecularnodes/blender/nodes.py b/molecularnodes/blender/nodes.py index 75cfe262..aaa818b8 100644 --- a/molecularnodes/blender/nodes.py +++ b/molecularnodes/blender/nodes.py @@ -4,6 +4,7 @@ import math import warnings import itertools +from typing import List, Optional, Union, Dict, Tuple from .. import utils from .. import color from .. import pkg @@ -51,7 +52,11 @@ ("surface", "Surface", "Solvent-accsible surface."), ("cartoon", "Cartoon", "Secondary structure cartoons"), ("ribbon", "Ribbon", "Continuous backbone ribbon."), - ("ball_and_stick", "Ball and Stick", "Spheres for atoms, sticks for bonds"), + ( + "ball_and_stick", + "Ball and Stick", + "Spheres for atoms, sticks for bonds", + ), ) bpy.types.Scene.MN_import_style = bpy.props.EnumProperty( @@ -66,12 +71,14 @@ class NodeGroupCreationError(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: self.message = message super().__init__(self.message) -def inputs(node): +def inputs( + node: bpy.types.GeometryNodeGroup, +) -> Dict[str, bpy.types.NodeSocket]: items = {} for item in node.interface.items_tree: if item.item_type == "SOCKET": @@ -80,7 +87,9 @@ def inputs(node): return items -def outputs(node): +def outputs( + node: bpy.types.GeometryNodeGroup, +) -> Dict[str, bpy.types.NodeSocket]: items = {} for item in node.interface.items_tree: if item.item_type == "SOCKET": @@ -89,34 +98,44 @@ def outputs(node): return items -def get_output_type(node, type="INT"): - for output in node.outputs: - if output.type == type: - return output - - -def set_selection(group, node, selection): +def set_selection( + tree: bpy.types.GeometryNodeTree, + node: bpy.types.GeometryNodeGroup, + selection: bpy.types.GeometryNodeGroup, +) -> bpy.types.GeometryNodeGroup: pos = node.location pos = [pos[0] - 200, pos[1] - 200] selection.location = pos - group.links.new(selection.outputs[0], node.inputs["Selection"]) + tree.links.new(selection.outputs[0], node.inputs["Selection"]) return selection -def create_debug_group(name="MolecularNodesDebugGroup"): - group = new_group(name=name, fallback=False) +def node_tree_debug( + name: str = "MolecularNodesTree", +) -> bpy.types.GeometryNodeTree: + group = new_tree(name=name, fallback=False) info = group.nodes.new("GeometryNodeObjectInfo") - group.links.new(info.outputs["Geometry"], group.nodes["Group Output"].inputs[0]) + group.links.new( + info.outputs["Geometry"], group.nodes["Group Output"].inputs[0] + ) return group -def add_selection(group, sel_name, input_list, field="chain_id"): +def add_selection( + group: bpy.types.GeometryNodeGroup, + sel_name: str, + input_list: List[str], + field: str = "chain_id", +) -> bpy.types.Node: style = style_node(group) sel_node = add_custom( group, custom_iswitch( - name="selection", iter_list=input_list, field=field, dtype="BOOLEAN" + name="selection", + iter_list=input_list, + field=field, + dtype="BOOLEAN", ).name, ) @@ -124,7 +143,7 @@ def add_selection(group, sel_name, input_list, field="chain_id"): return sel_node -def get_output(group): +def get_output(group: bpy.types.GeometryNodeGroup) -> bpy.types.Node: return group.nodes[ bpy.app.translations.pgettext_data( "Group Output", @@ -132,7 +151,7 @@ def get_output(group): ] -def get_input(group): +def get_input(group: bpy.types.GeometryNodeGroup) -> bpy.types.Node: return group.nodes[ bpy.app.translations.pgettext_data( "Group Input", @@ -140,7 +159,9 @@ def get_input(group): ] -def get_mod(object, name="MolecularNodes"): +def get_mod( + object: bpy.types.Object, name: str = "MolecularNodes" +) -> bpy.types.Modifier: node_mod = object.modifiers.get(name) if not node_mod: node_mod = object.modifiers.new(name, "NODES") @@ -148,7 +169,7 @@ def get_mod(object, name="MolecularNodes"): return node_mod -def format_node_name(name): +def format_node_name(name: str) -> str: "Formats a node's name for nicer printing." return ( name.strip("MN_") @@ -159,19 +180,21 @@ def format_node_name(name): ) -def get_nodes_last_output(group): +def get_nodes_last_output( + group: bpy.types.GeometryNodeTree, +) -> Tuple[bpy.types.GeometryNode, bpy.types.GeometryNode]: output = get_output(group) last = output.inputs[0].links[0].from_node return last, output -def previous_node(node): +def previous_node(node: bpy.types.GeometryNode) -> bpy.types.GeometryNode: "Get the node which is the first connection to the first input of this node" prev = node.inputs[0].links[0].from_node return prev -def style_node(group): +def style_node(group: bpy.types.Object) -> bpy.types.GeometryNode: prev = previous_node(get_output(group)) is_style_node = "style" in prev.name while not is_style_node: @@ -181,13 +204,13 @@ def style_node(group): return prev -def get_style_node(object): +def get_style_node(object: bpy.types.Object) -> bpy.types.GeometryNode: "Walk back through the primary node connections until you find the first style node" group = object.modifiers["MolecularNodes"].node_group return style_node(group) -def star_node(group): +def star_node(group: bpy.types.GeometryNodeTree) -> bpy.types.GeometryNode: prev = previous_node(get_output(group)) is_star_node = "MN_starfile_instances" in prev.name while not is_star_node: @@ -196,13 +219,13 @@ def star_node(group): return prev -def get_star_node(object): +def get_star_node(object: bpy.types.Object) -> bpy.types.GeometryNode: "Walk back through the primary node connections until you find the first style node" group = object.modifiers["MolecularNodes"].node_group return star_node(group) -def get_color_node(object): +def get_color_node(object: bpy.types.Object) -> bpy.types.GeometryNode: "Walk back through the primary node connections until you find the first style node" group = object.modifiers["MolecularNodes"].node_group for node in group.nodes: @@ -210,7 +233,11 @@ def get_color_node(object): return node -def insert_last_node(group, node, link_input=True): +def insert_last_node( + group: bpy.types.GeometryNodeTree, + node: bpy.types.GeometryNodeGroup, + link_input: bool = True, +) -> None: last, output = get_nodes_last_output(group) link = group.links.new location = output.location @@ -221,13 +248,13 @@ def insert_last_node(group, node, link_input=True): link(node.outputs[0], output.inputs[0]) -def realize_instances(obj): +def realize_instances(obj: bpy.types.Object) -> None: group = obj.modifiers["MolecularNodes"].node_group realize = group.nodes.new("GeometryNodeRealizeInstances") insert_last_node(group, realize) -def append(node_name, link=False): +def append(node_name: str, link: bool = False) -> bpy.types.GeometryNodeTree: node = bpy.data.node_groups.get(node_name) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -245,7 +272,12 @@ def append(node_name, link=False): if not node or link: node_name_components = node_name.split("_") if node_name_components[0] == "MN": - data_file = MN_DATA_FILE[:-6] + "_" + node_name_components[1] + ".blend" + data_file = ( + MN_DATA_FILE[:-6] + + "_" + + node_name_components[1] + + ".blend" + ) bpy.ops.wm.append( "EXEC_DEFAULT", directory=os.path.join(data_file, "NodeTree"), @@ -256,7 +288,7 @@ def append(node_name, link=False): return bpy.data.node_groups[node_name] -def material_default(): +def material_default() -> bpy.types.Material: """ Append MN Default to the .blend file it it doesn't already exist, and return that material. @@ -276,7 +308,7 @@ def material_default(): return bpy.data.materials[mat_name] -def MN_micrograph_material(): +def MN_micrograph_material() -> bpy.types.Material: """ Append MN_micrograph_material to the .blend file it it doesn't already exist, and return that material. @@ -287,7 +319,9 @@ def MN_micrograph_material(): return bpy.data.materials[mat_name] -def new_group(name="Geometry Nodes", geometry=True, fallback=True): +def new_tree( + name: str = "Geometry Nodes", geometry: bool = True, fallback: bool = True +) -> bpy.types.GeometryNodeTree: group = bpy.data.node_groups.get(name) # if the group already exists, return it and don't create a new one if group and fallback: @@ -310,7 +344,9 @@ def new_group(name="Geometry Nodes", geometry=True, fallback=True): return group -def assign_material(node, material="default"): +def assign_material( + node: bpy.types.GeometryNode, material: str = "default" +) -> None: material_socket = node.inputs.get("Material") if material_socket: if not material: @@ -321,7 +357,12 @@ def assign_material(node, material="default"): material_socket.default_value = material -def add_node(node_name, label: str = "", show_options=False, material="default"): +def add_node( + node_name: str, + label: str = "", + show_options: bool = False, + material: str = "default", +) -> None: # intended to be called upon button press in the node tree prev_context = bpy.context.area.type @@ -349,15 +390,15 @@ def add_node(node_name, label: str = "", show_options=False, material="default") def add_custom( - group, - name, - location=[0, 0], - width=200, - material="default", - show_options=False, - link=False, -): - node = group.nodes.new("GeometryNodeGroup") + tree: bpy.types.GeometryNodeTree, + name: str, + location: List[float] = [0, 0], + width: int = 200, + material: str = "default", + show_options: bool = False, + link: bool = False, +) -> bpy.types.GeometryNode: + node = tree.nodes.new("GeometryNodeGroup") node.node_tree = append(name, link=link) # if there is an input socket called 'Material', assign it to the base MN material @@ -374,7 +415,7 @@ def add_custom( return node -def change_style_node(object, style): +def change_style_node(object: bpy.types.Object, style: str) -> None: # get the node group that we are working on, to change the specific style node group = get_mod(object).node_group link = group.links.new @@ -421,7 +462,9 @@ def change_style_node(object, style): pass -def create_starting_nodes_starfile(object, n_images=1): +def create_starting_nodes_starfile( + object: bpy.types.Object, n_images: int = 1 +) -> None: # ensure there is a geometry nodes modifier called 'MolecularNodes' that is created and applied to the object node_mod = get_mod(object) @@ -430,7 +473,7 @@ def create_starting_nodes_starfile(object, n_images=1): # Make sure the aotmic material is loaded material_default() # create a new GN node group, specific to this particular molecule - group = new_group(node_name) + group = new_tree(node_name) node_mod.node_group = group link = group.links.new @@ -447,13 +490,17 @@ def create_starting_nodes_starfile(object, n_images=1): node_mod["Input_3"] = 1 -def create_starting_nodes_density(object, threshold=0.8, style="density_surface"): +def create_starting_nodes_density( + object: bpy.types.Object, + threshold: float = 0.8, + style: str = "density_surface", +) -> None: # ensure there is a geometry nodes modifier called 'MolecularNodes' that is created and applied to the object mod = get_mod(object) node_name = f"MN_density_{object.name}" # create a new GN node group, specific to this particular molecule - group = new_group(node_name, fallback=False) + group = new_tree(node_name, fallback=False) link = group.links.new mod.node_group = group @@ -471,8 +518,12 @@ def create_starting_nodes_density(object, threshold=0.8, style="density_surface" def create_starting_node_tree( - object, coll_frames=None, style="spheres", name=None, set_color=True -): + object: bpy.types.Object, + coll_frames: Optional[bpy.types.Collection] = None, + style: str = "spheres", + name: Optional[str] = None, + set_color: bool = True, +) -> None: """ Create a starting node tree for the inputted object. @@ -501,7 +552,7 @@ def create_starting_node_tree( # create a new GN node group, specific to this particular molecule mod = get_mod(object) - group = new_group(name) + group = new_tree(name) link = group.links.new mod.node_group = group @@ -519,10 +570,15 @@ def create_starting_node_tree( if set_color: node_color_set = add_custom(group, "MN_color_set", [200, 0]) node_color_common = add_custom(group, "MN_color_common", [-50, -150]) - node_random_color = add_custom(group, "MN_color_attribute_random", [-300, -150]) + node_random_color = add_custom( + group, "MN_color_attribute_random", [-300, -150] + ) link(node_input.outputs["Geometry"], node_color_set.inputs[0]) - link(node_random_color.outputs["Color"], node_color_common.inputs["Carbon"]) + link( + node_random_color.outputs["Color"], + node_color_common.inputs["Carbon"], + ) link(node_color_common.outputs[0], node_color_set.inputs["Color"]) link(node_color_set.outputs[0], node_style.inputs[0]) to_animate = node_color_set @@ -538,17 +594,24 @@ def create_starting_node_tree( node_animate = add_custom(group, "MN_animate_value", [500, -300]) node_animate_frames.inputs["Frames"].default_value = coll_frames - node_animate.inputs["To Max"].default_value = len(coll_frames.objects) - 1 + node_animate.inputs["To Max"].default_value = ( + len(coll_frames.objects) - 1 + ) link(to_animate.outputs[0], node_animate_frames.inputs[0]) link(node_animate_frames.outputs[0], node_style.inputs[0]) link(node_animate.outputs[0], node_animate_frames.inputs["Frame"]) -def combine_join_geometry(group, node_list, output="Geometry", join_offset=300): - link = group.links.new +def combine_join_geometry( + tree: bpy.types.GeometryNodeTree, + node_list: List[bpy.types.GeometryNode], + output: str = "Geometry", + join_offset: int = 300, +) -> bpy.types.GeometryNode: + link = tree.links.new max_x = max([node.location[0] for node in node_list]) - node_to_instances = group.nodes.new("GeometryNodeJoinGeometry") + node_to_instances = tree.nodes.new("GeometryNodeJoinGeometry") node_to_instances.location = [int(max_x + join_offset), 0] for node in reversed(node_list): @@ -556,7 +619,11 @@ def combine_join_geometry(group, node_list, output="Geometry", join_offset=300): return node_to_instances -def split_geometry_to_instances(name, iter_list=("A", "B", "C"), attribute="chain_id"): +def split_geometry_to_instances( + name: str, + iter_list: List[str] = ["A", "B", "C"], + attribute: str = "chain_id", +) -> bpy.types.GeometryNodeTree: """Create a Node to Split Geometry by an Attribute into Instances Splits the inputted geometry into instances, based on an attribute field. By @@ -565,40 +632,42 @@ def split_geometry_to_instances(name, iter_list=("A", "B", "C"), attribute="chai define how many times to create the required nodes. """ - group = new_group(name) - node_input = get_input(group) - node_output = get_output(group) + tree = new_tree(name) + node_input = get_input(tree) + node_output = get_output(tree) - named_att = group.nodes.new("GeometryNodeInputNamedAttribute") + named_att = tree.nodes.new("GeometryNodeInputNamedAttribute") named_att.location = [-200, -200] named_att.data_type = "INT" named_att.inputs[0].default_value = attribute - link = group.links.new + link = tree.links.new list_sep = [] for i, chain in enumerate(iter_list): pos = [i % 10, math.floor(i / 10)] - node_split = add_custom(group, ".MN_utils_split_instance") + node_split = add_custom(tree, ".MN_utils_split_instance") node_split.location = [int(250 * pos[0]), int(-300 * pos[1])] node_split.inputs["Group ID"].default_value = i link(named_att.outputs["Attribute"], node_split.inputs["Field"]) link(node_input.outputs["Geometry"], node_split.inputs["Geometry"]) list_sep.append(node_split) - node_instance = combine_join_geometry(group, list_sep, "Instance") + node_instance = combine_join_geometry(tree, list_sep, "Instance") node_output.location = [int(10 * 250 + 400), 0] link(node_instance.outputs[0], node_output.inputs[0]) - return group + return tree -def assembly_initialise(mol: bpy.types.Object): +def assembly_initialise(mol: bpy.types.Object) -> bpy.types.GeometryNodeTree: """ Setup the required data object and nodes for building an assembly. """ - transforms = utils.array_quaternions_from_dict(mol["biological_assemblies"]) + transforms = utils.array_quaternions_from_dict( + mol["biological_assemblies"] + ) data_object = obj.create_data_object( array=transforms, name=f"data_assembly_{mol.name}" ) @@ -608,7 +677,7 @@ def assembly_initialise(mol: bpy.types.Object): return tree_assembly -def assembly_insert(mol: bpy.types.Object): +def assembly_insert(mol: bpy.types.Object) -> None: """ Given a molecule, setup the required assembly node and insert it into the node tree. """ @@ -619,23 +688,31 @@ def assembly_insert(mol: bpy.types.Object): insert_last_node(get_mod(mol).node_group, node) -def create_assembly_node_tree(name, iter_list, data_object): +def create_assembly_node_tree( + name: str, iter_list: List[str], data_object: bpy.types.Object +) -> bpy.types.GeometryNodeTree: node_group_name = f"MN_assembly_{name}" - group = new_group(name=node_group_name) - link = group.links.new + tree = new_tree(name=node_group_name) + link = tree.links.new - n_assemblies = len(np.unique(obj.get_attribute(data_object, "assembly_id"))) + n_assemblies = len( + np.unique(obj.get_attribute(data_object, "assembly_id")) + ) node_group_instances = split_geometry_to_instances( - name=f".MN_utils_split_{name}", iter_list=iter_list, attribute="chain_id" + name=f".MN_utils_split_{name}", + iter_list=iter_list, + attribute="chain_id", ) node_group_assembly_instance = append(".MN_assembly_instance_chains") - node_instances = add_custom(group, node_group_instances.name, [0, 0]) - node_assembly = add_custom(group, node_group_assembly_instance.name, [200, 0]) + node_instances = add_custom(tree, node_group_instances.name, [0, 0]) + node_assembly = add_custom( + tree, node_group_assembly_instance.name, [200, 0] + ) node_assembly.inputs["data_object"].default_value = data_object - out_sockets = outputs(group) + out_sockets = outputs(tree) out_sockets[list(out_sockets)[0]].name = "Instances" socket_info = ( @@ -663,26 +740,29 @@ def create_assembly_node_tree(name, iter_list, data_object): ) for info in socket_info: - socket = group.interface.items_tree.get(info["name"]) + socket = tree.interface.items_tree.get(info["name"]) if not socket: - socket = group.interface.new_socket( + socket = tree.interface.new_socket( info["name"], in_out="INPUT", socket_type=info["type"] ) socket.default_value = info["default"] socket.min_value = info["min"] socket.max_value = info["max"] - link(get_input(group).outputs[info["name"]], node_assembly.inputs[info["name"]]) + link( + get_input(tree).outputs[info["name"]], + node_assembly.inputs[info["name"]], + ) - get_output(group).location = [400, 0] - link(get_input(group).outputs[0], node_instances.inputs[0]) + get_output(tree).location = [400, 0] + link(get_input(tree).outputs[0], node_instances.inputs[0]) link(node_instances.outputs[0], node_assembly.inputs[0]) - link(node_assembly.outputs[0], get_output(group).inputs[0]) + link(node_assembly.outputs[0], get_output(tree).inputs[0]) - return group + return tree -def add_inverse_selection(group): +def add_inverse_selection(group: bpy.types.GeometryNodeTree) -> None: output = get_output(group) if "Inverted" not in output.inputs.keys(): group.interface.new_socket( @@ -701,14 +781,14 @@ def add_inverse_selection(group): def custom_iswitch( - name, - iter_list, - field="chain_id", - dtype="BOOLEAN", - default_values=None, - prefix="", - start=0, -): + name: str, + iter_list: List[str], + field: str = "chain_id", + dtype: str = "BOOLEAN", + default_values: Optional[List[Union[str, int, float, bool]]] = None, + prefix: str = "", + start: int = 0, +) -> bpy.types.GeometryNodeTree: """ Creates a named `Index Switch` node. @@ -736,7 +816,7 @@ def custom_iswitch( Returns ------- - group : bpy.types.NodeGroup + group : bpy.types.GeometryNodeGroup The created node group. Raises @@ -750,7 +830,7 @@ def custom_iswitch( return group socket_type = socket_types[dtype] - group = new_group(name, geometry=False, fallback=False) + group = new_tree(name, geometry=False, fallback=False) # try creating the node group, otherwise on fail cleanup the created group and # report the error @@ -785,7 +865,9 @@ def custom_iswitch( # is colors, then generate a random pastel color for each value default_lookup = None if default_values: - default_lookup = dict(zip(iter_list, itertools.cycle(default_values))) + default_lookup = dict( + zip(iter_list, itertools.cycle(default_values)) + ) elif dtype == "RGBA": default_lookup = dict( zip(iter_list, [color.random_rgb() for i in iter_list]) @@ -807,12 +889,18 @@ def custom_iswitch( if default_lookup: socket.default_value = default_lookup[item] - link(node_input.outputs[socket.identifier], node_iswitch.inputs[str(i)]) + link( + node_input.outputs[socket.identifier], + node_iswitch.inputs[str(i)], + ) socket_out = group.interface.new_socket( name="Color", in_out="OUTPUT", socket_type=socket_type ) - link(node_iswitch.outputs["Output"], node_output.inputs[socket_out.identifier]) + link( + node_iswitch.outputs["Output"], + node_output.inputs[socket_out.identifier], + ) return group @@ -825,7 +913,9 @@ def custom_iswitch( ) -def resid_multiple_selection(node_name, input_resid_string): +def resid_multiple_selection( + node_name: str, input_resid_string: str +) -> bpy.types.GeometryNodeTree: """ Returns a node group that takes an integer input and creates a boolean tick box for each item in the input list. Outputs are the selected @@ -856,14 +946,13 @@ def resid_multiple_selection(node_name, input_resid_string): # create the custom node group data block, where everything will go # also create the required group node input and position it - residue_id_group = bpy.data.node_groups.new(node_name, "GeometryNodeTree") - node_input = residue_id_group.nodes.new("NodeGroupInput") + tree = bpy.data.node_groups.new(node_name, "GeometryNodeTree") + node_input = tree.nodes.new("NodeGroupInput") node_input.location = [0, node_sep_dis * len(sub_list) / 2] - group_link = residue_id_group.links.new - new_node = residue_id_group.nodes.new + link = tree.links.new + new_node = tree.nodes.new - prev = None for residue_id_index, residue_id in enumerate(sub_list): # add an new node of Select Res ID or MN_sek_res_id_range current_node = new_node("GeometryNodeGroup") @@ -877,53 +966,64 @@ def resid_multiple_selection(node_name, input_resid_string): # set two new inputs current_node.node_tree = append("MN_select_res_id_range") [resid_start, resid_end] = residue_id.split("-")[:2] - socket_1 = residue_id_group.interface.new_socket( + socket_1 = tree.interface.new_socket( "res_id: Min", in_out="INPUT", socket_type="NodeSocketInt" ) socket_1.default_value = int(resid_start) - socket_2 = residue_id_group.interface.new_socket( + socket_2 = tree.interface.new_socket( "res_id: Max", in_out="INPUT", socket_type="NodeSocketInt" ) socket_2.default_value = int(resid_end) # a residue range - group_link(node_input.outputs[socket_1.identifier], current_node.inputs[0]) - group_link(node_input.outputs[socket_2.identifier], current_node.inputs[1]) + link( + node_input.outputs[socket_1.identifier], current_node.inputs[0] + ) + link( + node_input.outputs[socket_2.identifier], current_node.inputs[1] + ) else: # create a node current_node.node_tree = append("MN_select_res_id_single") - socket = residue_id_group.interface.new_socket( + socket = tree.interface.new_socket( "res_id", in_out="INPUT", socket_type="NodeSocketInt" ) socket.default_value = int(residue_id) - group_link(node_input.outputs[socket.identifier], current_node.inputs[0]) + link(node_input.outputs[socket.identifier], current_node.inputs[0]) # set the coordinates current_node.location = [200, (residue_id_index + 1) * node_sep_dis] - if not prev: + if residue_id_index == 0: # link the first residue selection to the first input of its OR block - group_link(current_node.outputs["Selection"], bool_math.inputs[0]) + link(current_node.outputs["Selection"], bool_math.inputs[0]) + prev = bool_math else: # if it is not the first residue selection, link the output to the previous or block - group_link(current_node.outputs["Selection"], prev.inputs[1]) + link(current_node.outputs["Selection"], prev.inputs[1]) # link the ouput of previous OR block to the current OR block - group_link(prev.outputs[0], bool_math.inputs[0]) - prev = bool_math + link(prev.outputs[0], bool_math.inputs[0]) + prev = bool_math # add a output block residue_id_group_out = new_node("NodeGroupOutput") - residue_id_group_out.location = [800, (residue_id_index + 1) / 2 * node_sep_dis] - residue_id_group.interface.new_socket( + residue_id_group_out.location = [ + 800, + (residue_id_index + 1) / 2 * node_sep_dis, + ] + tree.interface.new_socket( "Selection", in_out="OUTPUT", socket_type="NodeSocketBool" ) - residue_id_group.interface.new_socket( + tree.interface.new_socket( "Inverted", in_out="OUTPUT", socket_type="NodeSocketBool" ) - group_link(prev.outputs[0], residue_id_group_out.inputs["Selection"]) + link(prev.outputs[0], residue_id_group_out.inputs["Selection"]) invert_bool_math = new_node("FunctionNodeBooleanMath") - invert_bool_math.location = [600, (residue_id_index + 1) / 3 * 2 * node_sep_dis] + invert_bool_math.location = [ + 600, + (residue_id_index + 1) / 3 * 2 * node_sep_dis, + ] invert_bool_math.operation = "NOT" - group_link(prev.outputs[0], invert_bool_math.inputs[0]) - group_link(invert_bool_math.outputs[0], residue_id_group_out.inputs["Inverted"]) - return residue_id_group + link(prev.outputs[0], invert_bool_math.inputs[0]) + link(invert_bool_math.outputs[0], residue_id_group_out.inputs["Inverted"]) + return tree diff --git a/molecularnodes/blender/obj.py b/molecularnodes/blender/obj.py index 44c4cc71..c0f54a34 100644 --- a/molecularnodes/blender/obj.py +++ b/molecularnodes/blender/obj.py @@ -1,9 +1,11 @@ +from dataclasses import dataclass +from typing import Union, List, Optional +from typing import Optional, Type +from types import TracebackType import bpy import numpy as np -from . import coll -from . import nodes -from dataclasses import dataclass +from . import coll, nodes @dataclass @@ -28,17 +30,20 @@ class AttributeTypeInfo: class AttributeMismatchError(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: self.message = message super().__init__(self.message) -def centre(array: np.array): - return np.mean(array, axis=0) +def centre(array: np.array) -> np.ndarray: + return np.array(np.mean(array, axis=0)) -def centre_weighted(array: np.ndarray, weight: np.ndarray): - return np.sum(array * weight.reshape((len(array), 1)), axis=0) / np.sum(weight) +def centre_weighted(array: np.ndarray, weight: np.ndarray) -> np.ndarray: + return np.array( + np.sum(array * weight.reshape((len(array), 1)), axis=0) + / np.sum(weight) + ) class ObjectTracker: @@ -54,7 +59,7 @@ class ObjectTracker: Returns a list of new objects that were added to bpy.data.objects while in the context. """ - def __enter__(self): + def __enter__(self) -> "ObjectTracker": """ Store the current objects and their names when entering the context. @@ -66,10 +71,20 @@ def __enter__(self): self.objects = list(bpy.context.scene.objects) return self - def __exit__(self, type, value, traceback): - pass - - def new_objects(self): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + del self.objects + if exc_type is not None: + print(f"Exception detected: {exc_val}") + print(exc_tb) + return True + return False + + def new_objects(self) -> List[bpy.types.Object]: """ Find new objects that were added to bpy.data.objects while in the context. @@ -88,7 +103,7 @@ def new_objects(self): new_objects.append(bob) return new_objects - def latest(self): + def latest(self) -> bpy.types.Object: """ Get the most recently added object. @@ -103,11 +118,11 @@ def latest(self): def create_object( - vertices: np.ndarray = [], - edges: np.ndarray = [], - faces: np.ndarray = [], - name: str = "NewObject", - collection: bpy.types.Collection = None, + vertices: Union[np.ndarray, List[List[float]], None] = None, + edges: Union[np.ndarray, List[List[int]], None] = None, + faces: Union[np.ndarray, List[List[int]], None] = None, + name: Optional[str] = "NewObject", + collection: Optional[bpy.types.Collection] = None, ) -> bpy.types.Object: """ Create a new Blender object, initialised with locations for each vertex. @@ -134,8 +149,16 @@ def create_object( """ mesh = bpy.data.meshes.new(name) - mesh.from_pydata(vertices=vertices, edges=edges, faces=faces) + if edges is None: + edges = [] + if faces is None: + faces = [] + if vertices is None: + vertices = [[0, 0, 0]] + mesh.from_pydata(vertices=vertices, edges=edges, faces=faces) + if name is None: + name = "NewObject" object = bpy.data.objects.new(name, mesh) if not collection: @@ -153,8 +176,8 @@ def set_attribute( object: bpy.types.Object, name: str, data: np.ndarray, - type=None, - domain="POINT", + type: Optional[str] = None, + domain: str = "POINT", overwrite: bool = True, ) -> bpy.types.Attribute: """ @@ -210,12 +233,16 @@ def set_attribute( if len(data) != len(attribute.data): raise AttributeMismatchError( - f"Data length {len(data)}, dimensions {data.shape} does not equal the size of the target domain {domain}, len={len(attribute.data)=}" + f"Data length {len(data)}, dimensions {data.shape} " + f"does not equal the size of the target domain {domain}, " + f"len={len(attribute.data)=}" ) # the 'foreach_set' requires a 1D array, regardless of the shape of the attribute # it also requires the order to be 'c' or blender might crash!! - attribute.data.foreach_set(TYPES[type].dname, data.reshape(-1).copy(order="c")) + attribute.data.foreach_set( + TYPES[type].dname, data.reshape(-1).copy(order="C") + ) # The updating of data doesn't work 100% of the time (see: # https://projects.blender.org/blender/blender/issues/118507) so this resetting of a @@ -232,7 +259,7 @@ def set_attribute( def get_attribute( - object: bpy.types.Object, name="position", evaluate=False + object: bpy.types.Object, name: str = "position", evaluate: bool = False ) -> np.ndarray: """ Get the attribute data from the object. @@ -269,7 +296,7 @@ def get_attribute( # we have the initialise the array first with the appropriate length, then we can # fill it with the given data using the 'foreach_get' method which is super fast C++ # internal method - array = np.zeros(n_att * width, dtype=data_type.dtype) + array: np.ndarray = np.zeros(n_att * width, dtype=data_type.dtype) # it is currently not really consistent, but to get the values you need to use one of # the 'value', 'vector', 'color' etc from the types dict. This I could only figure # out through trial and error. I assume this might be changed / improved in the future @@ -282,7 +309,9 @@ def get_attribute( return array -def import_vdb(file: str, collection: bpy.types.Collection = None) -> bpy.types.Object: +def import_vdb( + file: str, collection: bpy.types.Collection = None +) -> bpy.types.Object: """ Imports a VDB file as a Blender volume object, in the MolecularNodes collection. @@ -312,13 +341,13 @@ def import_vdb(file: str, collection: bpy.types.Collection = None) -> bpy.types. return object -def evaluated(object): +def evaluated(object: bpy.types.Object) -> bpy.types.Object: "Return an object which has the modifiers evaluated." object.update_tag() return object.evaluated_get(bpy.context.evaluated_depsgraph_get()) -def evaluate_using_mesh(object): +def evaluate_using_mesh(object: bpy.types.Object) -> bpy.types.Object: """ Evaluate the object using a debug object. Some objects can't currently have their Geometry Node trees evaluated (such as volumes), so we source the geometry they create @@ -341,7 +370,7 @@ def evaluate_using_mesh(object): # object types can't be currently through the API debug = create_object() mod = nodes.get_mod(debug) - mod.node_group = nodes.create_debug_group() + mod.node_group = nodes.node_tree_debug() mod.node_group.nodes["Object Info"].inputs["Object"].default_value = object # need to use 'evaluate' otherwise the modifiers won't be taken into account @@ -349,8 +378,12 @@ def evaluate_using_mesh(object): def create_data_object( - array, collection=None, name="DataObject", world_scale=0.01, fallback=False -): + array: np.ndarray, + collection: Optional[bpy.types.Collection] = None, + name: str = "DataObject", + world_scale: float = 0.01, + fallback: bool = False, +) -> bpy.types.Object: # still requires a unique call TODO: figure out why # I think this has to do with the bcif instancing extraction array = np.unique(array) @@ -378,6 +411,8 @@ def create_data_object( if np.issubdtype(data.dtype, str): data = np.unique(data, return_inverse=True)[1] - set_attribute(object, name=column, data=data, type=type, domain="POINT") + set_attribute( + object, name=column, data=data, type=type, domain="POINT" + ) return object diff --git a/molecularnodes/color.py b/molecularnodes/color.py index 9c1f3a00..b8965eaa 100644 --- a/molecularnodes/color.py +++ b/molecularnodes/color.py @@ -1,27 +1,32 @@ import random import colorsys import numpy as np +from numpy.typing import NDArray +from typing import List, Optional, Any, Dict, Tuple -def random_rgb(seed=None): +def random_rgb(seed: int = 6) -> NDArray[np.float64]: """Random Pastel RGB values""" - if seed: - random.seed(seed) + random.seed(seed) r, g, b = colorsys.hls_to_rgb(random.random(), 0.6, 0.6) - return np.array((r, g, b, 1)) + return np.array((r, g, b, 1.0)) -def color_from_atomic_number(atomic_number: int): +def color_from_atomic_number(atomic_number: int) -> Tuple[int, int, int, int]: r, g, b = list(iupac_colors_rgb.values())[int(atomic_number - 1)] - return np.array((r, g, b, 1)) + return (r, g, b, 1) -def colors_from_elements(atomic_numbers): - colors = np.array(list(map(color_from_atomic_number, atomic_numbers))) +def colors_from_elements( + atomic_numbers: NDArray[np.int32], +) -> NDArray[np.float64]: + colors = np.array([color_from_atomic_number(x) for x in atomic_numbers]) return colors -def equidistant_colors(some_list): +def equidistant_colors( + some_list: NDArray[np.character], +) -> Dict[str, List[Tuple[int, int, int, int]]]: u = np.unique(some_list) num_colors = len(u) @@ -31,22 +36,29 @@ def equidistant_colors(some_list): colors = [colorsys.hls_to_rgb(hue, 0.6, 0.6) for hue in hues] # Convert RGB to 8-bit integer values - colors = [(int(r * 255), int(g * 255), int(b * 255), 1) for (r, g, b) in colors] + colors = [ + (int(r * 255), int(g * 255), int(b * 255), int(1)) # type: ignore + for (r, g, b) in colors + ] - return dict(zip(u, colors)) + return dict(zip(u, colors)) # type: ignore -def color_chains_equidistant(chain_ids): +def color_chains_equidistant( + chain_ids: NDArray[np.character], +) -> NDArray[np.float32]: color_dict = equidistant_colors(chain_ids) chain_colors = np.array([color_dict[x] for x in chain_ids]) return chain_colors / 255 -def color_chains(atomic_numbers, chain_ids): +def color_chains( + atomic_numbers: NDArray[np.int32], chain_ids: NDArray[np.character] +) -> NDArray[np.float32]: mask = atomic_numbers == 6 colors = colors_from_elements(atomic_numbers) chain_color_dict = equidistant_colors(chain_ids) - chain_colors = np.array(list(map(lambda x: chain_color_dict.get(x), chain_ids))) + chain_colors = np.array([chain_color_dict.get(x) for x in chain_ids]) colors[mask] = chain_colors[mask] diff --git a/molecularnodes/io/parse/__init__.py b/molecularnodes/io/parse/__init__.py index cdbfd5af..ec35f957 100644 --- a/molecularnodes/io/parse/__init__.py +++ b/molecularnodes/io/parse/__init__.py @@ -10,4 +10,13 @@ from .sdf import SDF from .star import StarFile -__all__ = [CIF, BCIF, PDB, CellPack, StarFile, SDF, MDAnalysisSession, MRC] +__all__ = [ + "CIF", + "BCIF", + "PDB", + "CellPack", + "StarFile", + "SDF", + "MDAnalysisSession", + "MRC", +] diff --git a/molecularnodes/io/parse/cellpack.py b/molecularnodes/io/parse/cellpack.py index 02ec0d57..f8bf0ee1 100644 --- a/molecularnodes/io/parse/cellpack.py +++ b/molecularnodes/io/parse/cellpack.py @@ -69,7 +69,11 @@ def _create_object_instances( colors = np.tile(color.random_rgb(i), (len(chain_atoms), 1)) bl.obj.set_attribute( - model, name="Color", data=colors, type="FLOAT_COLOR", overwrite=True + model, + name="Color", + data=colors, + type="FLOAT_COLOR", + overwrite=True, ) if node_setup: @@ -93,10 +97,12 @@ def _create_data_object(self, name="DataObject"): def _setup_node_tree(self, name="CellPack", fraction=1.0, as_points=False): mod = bl.nodes.get_mod(self.data_object) - group = bl.nodes.new_group(name=f"MN_ensemble_{name}", fallback=False) + group = bl.nodes.new_tree(name=f"MN_ensemble_{name}", fallback=False) mod.node_group = group - node_pack = bl.nodes.add_custom(group, "MN_pack_instances", location=[-100, 0]) + node_pack = bl.nodes.add_custom( + group, "MN_pack_instances", location=[-100, 0] + ) node_pack.inputs["Collection"].default_value = self.data_collection node_pack.inputs["Fraction"].default_value = fraction node_pack.inputs["As Points"].default_value = as_points diff --git a/molecularnodes/io/parse/molecule.py b/molecularnodes/io/parse/molecule.py index 40f3c1e5..2737b7ea 100644 --- a/molecularnodes/io/parse/molecule.py +++ b/molecularnodes/io/parse/molecule.py @@ -1,13 +1,16 @@ -import time -import warnings from abc import ABCMeta -from typing import Optional - -import bpy +from typing import Optional, Any, Tuple, Union, List +import biotite.structure +from numpy.typing import NDArray +from pathlib import Path +import warnings +import biotite +import time import numpy as np +import bpy from ... import blender as bl -from ... import color, data, utils +from ... import utils, data, color class Molecule(metaclass=ABCMeta): @@ -50,14 +53,15 @@ class Molecule(metaclass=ABCMeta): Get the biological assemblies of the molecule. """ - def __init__(self): - self.file_path: str = None - self.file: str = None + def __init__(self, file_path: str) -> None: + self.file_path: Optional[Union[Path, str]] = None + self.file = None self.object: Optional[bpy.types.Object] = None self.frames: Optional[bpy.types.Collection] = None self.array: Optional[np.ndarray] = None + self.entity_ids: Optional[List[str]] = None - def __len__(self): + def __len__(self) -> Union[int, None]: if hasattr(self, "object"): if self.object: return len(self.object.data.vertices) @@ -67,37 +71,38 @@ def __len__(self): return None @property - def n_models(self): + def n_models(self) -> int: import biotite.structure as struc if isinstance(self.array, struc.AtomArray): return 1 - else: - return self.array.shape[0] + elif isinstance(self.array, struc.AtomArrayStack): + return len(self.array) + + return 0 @property - def chain_ids(self) -> Optional[list]: + def chain_ids(self) -> Optional[Any]: if self.array: if hasattr(self.array, "chain_id"): return np.unique(self.array.chain_id).tolist() - return None @property def name(self) -> Optional[str]: if self.object is not None: - return self.object.name + return str(self.object.name) else: - return None + return "" def set_attribute( self, data: np.ndarray, - name="NewAttribute", - type=None, - domain="POINT", - overwrite=True, - ): + name: str = "NewAttribute", + type: Optional[str] = None, + domain: str = "POINT", + overwrite: bool = True, + ) -> None: """ Set an attribute for the molecule. @@ -119,16 +124,17 @@ def set_attribute( Whether to overwrite an existing attribute with the same name, or create a new attribute with always a unique name. Default is True. """ - if not self.object: - warnings.warn( - "No object yet created. Use `create_model()` to create a corresponding object." - ) - return None bl.obj.set_attribute( - self.object, name=name, data=data, domain=domain, overwrite=overwrite + self.object, + name=name, + data=data, + domain=domain, + overwrite=overwrite, ) - def get_attribute(self, name="position", evaluate=False) -> np.ndarray | None: + def get_attribute( + self, name: str = "position", evaluate: bool = False + ) -> np.ndarray: """ Get the value of an attribute for the associated object. @@ -146,14 +152,9 @@ def get_attribute(self, name="position", evaluate=False) -> np.ndarray | None: np.ndarray The value of the attribute. """ - if not self.object: - warnings.warn( - "No object yet created. Use `create_model()` to create a corresponding object." - ) - return None return bl.obj.get_attribute(self.object, name=name, evaluate=evaluate) - def list_attributes(self, evaluate=False) -> list | None: + def list_attributes(self, evaluate: bool = False) -> Optional[List[str]]: """ Returns a list of attribute names for the object. @@ -199,11 +200,11 @@ def create_model( self, name: str = "NewMolecule", style: str = "spheres", - selection: np.ndarray = None, - build_assembly=False, + selection: Optional[np.ndarray] = None, + build_assembly: bool = False, centre: str = "", del_solvent: bool = True, - collection=None, + collection: Optional[bpy.types.Collection] = None, verbose: bool = False, ) -> bpy.types.Object: """ @@ -289,7 +290,9 @@ def create_model( return model - def assemblies(self, as_array=False): + def assemblies( + self, as_array: bool = False + ) -> Dict[str, List[float]] | None: """ Get the biological assemblies of the molecule. @@ -322,15 +325,15 @@ def __repr__(self) -> str: def _create_model( - array, - name=None, - centre="", - del_solvent=False, - style="spherers", - collection=None, - world_scale=0.01, - verbose=False, -) -> (bpy.types.Object, bpy.types.Collection): + array: biotite.structure.AtomArray, + name: Optional[str] = None, + centre: str = "", + del_solvent: bool = False, + style: str = "spherers", + collection: bpy.types.Collection = None, + world_scale: float = 0.01, + verbose: bool = False, +) -> Tuple[bpy.types.Object, bpy.types.Collection]: import biotite.structure as struc frames = None @@ -355,7 +358,9 @@ def _create_model( except AttributeError: pass - def centre_array(atom_array, centre): + def centre_array( + atom_array: biotite.structure.AtomArray, centre: str + ) -> None: if centre == "centroid": atom_array.coord -= bl.obj.centre(atom_array.coord) elif centre == "mass": @@ -368,7 +373,7 @@ def centre_array(atom_array, centre): for atom_array in array: centre_array(atom_array, centre) else: - centre_array(atom_array, centre) + centre_array(array, centre) if is_stack: if array.stack_depth() > 1: @@ -413,22 +418,24 @@ def centre_array(atom_array, centre): # I still don't like this as an implementation, and welcome any cleaner approaches that # anybody might have. - def att_atomic_number(): + def att_atomic_number() -> NDArray[np.int32]: atomic_number = np.array( [ - data.elements.get(x, {"atomic_number": -1}).get("atomic_number") + data.elements.get(x, {"atomic_number": -1}).get( + "atomic_number" + ) for x in np.char.title(array.element) ] ) return atomic_number - def att_atom_id(): + def att_atom_id() -> NDArray[np.int32]: return array.atom_id - def att_res_id(): + def att_res_id() -> NDArray[np.int32]: return array.res_id - def att_res_name(): + def att_res_name() -> NDArray[np.int32]: other_res = [] counter = 0 id_counter = -1 @@ -437,7 +444,9 @@ def att_res_name(): res_nums = [] for name in res_names: - res_num = data.residues.get(name, {"res_name_num": -1}).get("res_name_num") + res_num = data.residues.get(name, {"res_name_num": -1}).get( + "res_name_num" + ) if res_num == 9999: if ( @@ -450,7 +459,10 @@ def att_res_name(): other_res.append(unique_res_name) num = ( - np.where(np.isin(np.unique(other_res), unique_res_name))[0][0] + 100 + np.where(np.isin(np.unique(other_res), unique_res_name))[ + 0 + ][0] + + 100 ) res_nums.append(num) else: @@ -472,64 +484,65 @@ def att_b_factor(): def att_occupancy(): return array.occupancy - def att_vdw_radii(): + def att_vdw_radii() -> NDArray[np.float64]: vdw_radii = np.array( list( map( # divide by 100 to convert from picometres to angstroms which is # what all of coordinates are in - lambda x: data.elements.get(x, {}).get("vdw_radii", 100.0) / 100, + lambda x: data.elements.get(x, {}).get("vdw_radii", 100.0) + / 100, np.char.title(array.element), ) ) ) return vdw_radii * world_scale - def att_mass(): + def att_mass() -> NDArray[np.float64]: return array.mass - def att_atom_name(): + def att_atom_name() -> NDArray[np.int32]: atom_name = np.array( - list(map(lambda x: data.atom_names.get(x, -1), array.atom_name)) + [data.atom_names.get(x, -1) for x in array.atom_name] ) return atom_name - def att_lipophobicity(): + def att_lipophobicity() -> NDArray[np.float64]: lipo = np.array( - list( - map( - lambda x, y: data.lipophobicity.get(x, {"0": 0}).get(y, 0), + [ + data.lipophobicity.get(res_name, {"0": 0}).get(atom_name, 0) + for (res_name, atom_name) in zip( array.res_name, array.atom_name, ) - ) + ] ) return lipo - def att_charge(): + def att_charge() -> NDArray[np.float64]: charge = np.array( - list( - map( - lambda x, y: data.atom_charge.get(x, {"0": 0}).get(y, 0), + [ + data.atom_charge.get(res_name, {"0": 0}).get(atom_name, 0) + for (res_name, atom_name) in zip( array.res_name, array.atom_name, ) - ) + ] ) return charge - def att_color(): + def att_color() -> NDArray[np.float64]: return color.color_chains(att_atomic_number(), att_chain_id()) - def att_is_alpha(): + def att_is_alpha() -> NDArray[np.bool_]: return np.isin(array.atom_name, "CA") - def att_is_solvent(): + def att_is_solvent() -> NDArray[np.bool_]: return struc.filter_solvent(array) - def att_is_backbone(): + def att_is_backbone() -> NDArray[np.bool_]: """ Get the atoms that appear in peptide backbone or nucleic acid phosphate backbones. Filter differs from the Biotite's `struc.filter_peptide_backbone()` in that this @@ -564,37 +577,52 @@ def att_is_backbone(): ) return is_backbone - def att_is_nucleic(): + def att_is_nucleic() -> NDArray[np.bool_]: return struc.filter_nucleotides(array) - def att_is_peptide(): + def att_is_peptide() -> NDArray[np.bool_]: aa = struc.filter_amino_acids(array) con_aa = struc.filter_canonical_amino_acids(array) return aa | con_aa - def att_is_hetero(): + def att_is_hetero() -> NDArray[np.bool_]: return array.hetero - def att_is_carb(): + def att_is_carb() -> NDArray[np.bool_]: return struc.filter_carbohydrates(array) - def att_sec_struct(): + def att_sec_struct() -> NDArray[np.int32]: return array.sec_struct # these are all of the attributes that will be added to the structure # TODO add capcity for selection of particular attributes to include / not include to potentially # boost performance, unsure if actually a good idea of not. Need to do some testing. attributes = ( - {"name": "res_id", "value": att_res_id, "type": "INT", "domain": "POINT"}, - {"name": "res_name", "value": att_res_name, "type": "INT", "domain": "POINT"}, + { + "name": "res_id", + "value": att_res_id, + "type": "INT", + "domain": "POINT", + }, + { + "name": "res_name", + "value": att_res_name, + "type": "INT", + "domain": "POINT", + }, { "name": "atomic_number", "value": att_atomic_number, "type": "INT", "domain": "POINT", }, - {"name": "b_factor", "value": att_b_factor, "type": "FLOAT", "domain": "POINT"}, + { + "name": "b_factor", + "value": att_b_factor, + "type": "FLOAT", + "domain": "POINT", + }, { "name": "occupancy", "value": att_occupancy, @@ -607,19 +635,54 @@ def att_sec_struct(): "type": "FLOAT", "domain": "POINT", }, - {"name": "mass", "value": att_mass, "type": "FLOAT", "domain": "POINT"}, - {"name": "chain_id", "value": att_chain_id, "type": "INT", "domain": "POINT"}, - {"name": "entity_id", "value": att_entity_id, "type": "INT", "domain": "POINT"}, - {"name": "atom_id", "value": att_atom_id, "type": "INT", "domain": "POINT"}, - {"name": "atom_name", "value": att_atom_name, "type": "INT", "domain": "POINT"}, + { + "name": "mass", + "value": att_mass, + "type": "FLOAT", + "domain": "POINT", + }, + { + "name": "chain_id", + "value": att_chain_id, + "type": "INT", + "domain": "POINT", + }, + { + "name": "entity_id", + "value": att_entity_id, + "type": "INT", + "domain": "POINT", + }, + { + "name": "atom_id", + "value": att_atom_id, + "type": "INT", + "domain": "POINT", + }, + { + "name": "atom_name", + "value": att_atom_name, + "type": "INT", + "domain": "POINT", + }, { "name": "lipophobicity", "value": att_lipophobicity, "type": "FLOAT", "domain": "POINT", }, - {"name": "charge", "value": att_charge, "type": "FLOAT", "domain": "POINT"}, - {"name": "Color", "value": att_color, "type": "FLOAT_COLOR", "domain": "POINT"}, + { + "name": "charge", + "value": att_charge, + "type": "FLOAT", + "domain": "POINT", + }, + { + "name": "Color", + "value": att_color, + "type": "FLOAT_COLOR", + "domain": "POINT", + }, { "name": "is_backbone", "value": att_is_backbone, @@ -656,7 +719,12 @@ def att_sec_struct(): "type": "BOOLEAN", "domain": "POINT", }, - {"name": "is_carb", "value": att_is_carb, "type": "BOOLEAN", "domain": "POINT"}, + { + "name": "is_carb", + "value": att_is_carb, + "type": "BOOLEAN", + "domain": "POINT", + }, { "name": "sec_struct", "value": att_sec_struct, @@ -678,12 +746,12 @@ def att_sec_struct(): domain=att["domain"], ) if verbose: - print(f'Added {att["name"]} after {time.process_time() - start} s') - except (AttributeError, TypeError, KeyError) as e: - if verbose: - warnings.warn( - f"Unable to add attribute: {att['name']}. Error: {str(e)}" + print( + f'Added {att["name"]} after {time.process_time() - start} s' ) + except: + if verbose: + warnings.warn(f"Unable to add attribute: {att['name']}") print( f'Failed adding {att["name"]} after {time.process_time() - start} s' ) diff --git a/molecularnodes/io/parse/pdb.py b/molecularnodes/io/parse/pdb.py index 69c6b5cd..82c8b12a 100644 --- a/molecularnodes/io/parse/pdb.py +++ b/molecularnodes/io/parse/pdb.py @@ -1,12 +1,15 @@ import numpy as np +from typing import List, Union, AnyStr +from pathlib import Path + from .assembly import AssemblyParser from .molecule import Molecule class PDB(Molecule): - def __init__(self, file_path): - super().__init__() + def __init__(self, file_path: Union[Path, AnyStr]): + super().__init__(file_path=file_path) self.file_path = file_path self.file = self.read() self.array = self._get_structure() @@ -48,7 +51,9 @@ def _get_sec_struct(file, array): lines_helix = lines[np.char.startswith(lines, "HELIX")] lines_sheet = lines[np.char.startswith(lines, "SHEET")] if len(lines_helix) == 0 and len(lines_sheet) == 0: - raise struc.BadStructureError("No secondary structure information detected.") + raise struc.BadStructureError( + "No secondary structure information detected." + ) sec_struct = np.zeros(array.array_length(), int) @@ -75,7 +80,9 @@ def _get_mask(line, start1, end1, start2, end2, chainid): # create a mask for the array based on these values mask = np.logical_and( - np.logical_and(array.chain_id == chain_id, array.res_id >= start_num), + np.logical_and( + array.chain_id == chain_id, array.res_id >= start_num + ), array.res_id <= end_num, ) @@ -88,7 +95,9 @@ def _get_mask(line, start1, end1, start2, end2, chainid): # assign remaining AA atoms to 3 (loop), while all other remaining # atoms will be 0 (not relevant) - mask = np.logical_and(sec_struct == 0, struc.filter_canonical_amino_acids(array)) + mask = np.logical_and( + sec_struct == 0, struc.filter_canonical_amino_acids(array) + ) sec_struct[mask] = 3 @@ -113,7 +122,9 @@ def _comp_secondary_structure(array): conv_sse_char_int = {"a": 1, "b": 2, "c": 3, "": 0} char_sse = annotate_sse(array) - int_sse = np.array([conv_sse_char_int[char] for char in char_sse], dtype=int) + int_sse = np.array( + [conv_sse_char_int[char] for char in char_sse], dtype=int + ) atom_sse = spread_residue_wise(array, int_sse) return atom_sse @@ -173,9 +184,9 @@ def get_transformations(self, assembly_id): affected_chain_ids = [] transform_start = None for j, line in enumerate(assembly_lines[start:stop]): - if line.startswith("APPLY THE FOLLOWING TO CHAINS:") or line.startswith( - " AND CHAINS:" - ): + if line.startswith( + "APPLY THE FOLLOWING TO CHAINS:" + ) or line.startswith(" AND CHAINS:"): affected_chain_ids += [ chain_id.strip() for chain_id in line[30:].split(",") ] @@ -190,7 +201,9 @@ def get_transformations(self, assembly_id): "No 'BIOMT' records found for chosen assembly" ) - matrices = _parse_transformations(assembly_lines[transform_start:stop]) + matrices = _parse_transformations( + assembly_lines[transform_start:stop] + ) for matrix in matrices: transformations.append((affected_chain_ids, matrix.tolist())) @@ -215,7 +228,9 @@ def _parse_transformations(lines): # Each transformation requires 3 lines for the (x,y,z) components if len(lines) % 3 != 0: - raise biotite.InvalidFileError("Invalid number of transformation vectors") + raise biotite.InvalidFileError( + "Invalid number of transformation vectors" + ) n_transformations = len(lines) // 3 matrices = np.tile(np.identity(4), (n_transformations, 1, 1)) diff --git a/molecularnodes/io/parse/pdbx.py b/molecularnodes/io/parse/pdbx.py index 94a64b31..87ada18f 100644 --- a/molecularnodes/io/parse/pdbx.py +++ b/molecularnodes/io/parse/pdbx.py @@ -1,6 +1,7 @@ import numpy as np import warnings import itertools +from typing import List from .molecule import Molecule @@ -12,7 +13,12 @@ def __init__(self, file_path): @property def entity_ids(self): - return self.file.block.get("entity").get("pdbx_description").as_array().tolist() + return ( + self.file.block.get("entity") + .get("pdbx_description") + .as_array() + .tolist() + ) def _get_entity_id(self, array, file): chain_ids = file.block["entity_poly"]["pdbx_strand_id"].as_array() @@ -43,12 +49,15 @@ def get_structure( array = pdbx.get_structure(self.file, extra_fields=extra_fields) try: array.set_annotation( - "sec_struct", self._get_secondary_structure(array=array, file=self.file) + "sec_struct", + self._get_secondary_structure(array=array, file=self.file), ) except KeyError: warnings.warn("No secondary structure information.") try: - array.set_annotation("entity_id", self._get_entity_id(array, self.file)) + array.set_annotation( + "entity_id", self._get_entity_id(array, self.file) + ) except KeyError: warnings.warn("No entity ID information") @@ -62,47 +71,6 @@ def get_structure( def _assemblies(self): return CIFAssemblyParser(self.file).get_assemblies() - # # in the cif / BCIF file 3x4 transformation matrices are stored in individual - # # columns, this extracts them and returns them with additional row for scaling, - # # meaning an (n, 4, 4) array is returned, where n is the number of transformations - # # and each is a 4x4 transformaiton matrix - # cat_matrix = self.file.block['pdbx_struct_oper_list'] - # matrices = self._extract_matrices(cat_matrix) - - # # sometimes there will be missing opers / matrices. For example in the - # # 'square.bcif' file, the matrix IDs go all the way up to 18024, but only - # # 18023 matrices are defined. That is becuase matrix 12 is never referenced, so - # # isn't included in teh file. To get around this we have to just get the specific - # # IDs that are defined for the matrices and use that to lookup the correct index - # # in the matrices array. - # mat_ids = cat_matrix.get('id').as_array(int) - # mat_lookup = dict(zip(mat_ids, range(len(mat_ids)))) - - # category = self.file.block['pdbx_struct_assembly_gen'] - # ids = category['assembly_id'].as_array(int) - # opers = category['oper_expression'].as_array(str) - # asyms = category['asym_id_list'].as_array() - - # # constructs a dictionary of - # # { - # # '1': ((['A', 'B', C'], [4x4 matrix]), (['A', 'B'], [4x4 matrix])), - # # '2': ((['A', 'B', C'], [4x4 matrix])) - # # } - # # where each entry in the dictionary is a biological assembly, and each dictionary - # # value contains a list of tranasformations which need to be applied. Each entry in - # # the list of transformations is - # # ([chains to be affected], [4x4 transformation matrix]) - # assembly_dic = {} - # for idx, oper, asym in zip(ids, opers, asyms): - # trans = list() - # asym = asym.split(',') - # for op in _parse_opers(oper): - # i = int(op) - # trans.append((asym, matrices[mat_lookup[i]].tolist())) - # assembly_dic[str(idx)] = trans - - # return assembly_dic - def _extract_matrices(self, category): matrix_columns = [ "matrix[1][1]", @@ -119,7 +87,9 @@ def _extract_matrices(self, category): "vector[3]", ] - columns = [category[name].as_array().astype(float) for name in matrix_columns] + columns = [ + category[name].as_array().astype(float) for name in matrix_columns + ] matrices = np.empty((len(columns[0]), 4, 4), float) col_mask = np.tile((0, 1, 2, 3), 3) @@ -178,9 +148,15 @@ def _get_secondary_structure(self, file, array): # as normalquit sheet = file.block.get("struct_sheet_range") if sheet: - starts = np.append(starts, sheet["beg_auth_seq_id"].as_array().astype(int)) - ends = np.append(ends, sheet["end_auth_seq_id"].as_array().astype(int)) - chains = np.append(chains, sheet["end_auth_asym_id"].as_array().astype(str)) + starts = np.append( + starts, sheet["beg_auth_seq_id"].as_array().astype(int) + ) + ends = np.append( + ends, sheet["end_auth_seq_id"].as_array().astype(int) + ) + chains = np.append( + chains, sheet["end_auth_asym_id"].as_array().astype(str) + ) id_label = np.append(id_label, np.repeat("STRN", len(sheet["id"]))) if not conf and not sheet: @@ -225,28 +201,7 @@ def _get_secondary_structure(self, file, array): return secondary_structure -def _parse_opers(oper): - # we want the example '1,3,(5-8)' to expand to (1, 3, 5, 6, 7, 8). - op_ids = list() - - for group in oper.strip(")").split("("): - if "," in group: - for i in group.split(","): - op_ids.append() - - for group in oper.split(","): - if "-" not in group: - op_ids.append(str(group)) - continue - - start, stop = [int(x) for x in group.strip("()").split("-")] - for i in range(start, stop + 1): - op_ids.append(str(i)) - - return op_ids - - -def _ss_label_to_int(label): +def _ss_label_to_int(label: str) -> int: if "HELX" in label: return 1 elif "STRN" in label: @@ -297,7 +252,9 @@ def get_transformations(self, assembly_id): struct_oper_category = self._file.block["pdbx_struct_oper_list"] - if assembly_id not in assembly_gen_category["assembly_id"].as_array(str): + if assembly_id not in assembly_gen_category["assembly_id"].as_array( + str + ): raise KeyError(f"File has no Assembly ID '{assembly_id}'") # Extract all possible transformations indexed by operation ID @@ -351,7 +308,9 @@ def _extract_matrices(category, scale=True): "vector[3]", ] - columns = [category[name].as_array().astype(float) for name in matrix_columns] + columns = [ + category[name].as_array().astype(float) for name in matrix_columns + ] n = 4 if scale else 3 matrices = np.empty((len(columns[0]), n, 4), float) @@ -396,7 +355,10 @@ def _get_transformations(struct_oper): for index, id in enumerate(struct_oper["id"].as_array()): rotation_matrix = np.array( [ - [float(struct_oper[f"matrix[{i}][{j}]"][index]) for j in (1, 2, 3)] + [ + float(struct_oper[f"matrix[{i}][{j}]"][index]) + for j in (1, 2, 3) + ] for i in (1, 2, 3) ] ) @@ -429,14 +391,19 @@ def _parse_operation_expression(expression): if "-" in gexpr: first, last = gexpr.split("-") operations.append( - [str(id) for id in range(int(first), int(last) + 1)] + [ + str(id) + for id in range(int(first), int(last) + 1) + ] ) else: operations.append([gexpr]) else: # Range of operation IDs, they must be integers first, last = expr.split("-") - operations.append([str(id) for id in range(int(first), int(last) + 1)]) + operations.append( + [str(id) for id in range(int(first), int(last) + 1)] + ) elif "," in expr: # List of operation IDs operations.append(expr.split(",")) diff --git a/molecularnodes/io/retrieve.py b/molecularnodes/io/retrieve.py index 761d8a47..3c99cd2a 100644 --- a/molecularnodes/io/retrieve.py +++ b/molecularnodes/io/retrieve.py @@ -2,6 +2,9 @@ import requests import io +from typing import Union, Optional +from pathlib import Path + class FileDownloadPDBError(Exception): """ @@ -13,13 +16,18 @@ class FileDownloadPDBError(Exception): def __init__( self, - message="There was an error downloading the file from the Protein Data Bank. PDB or format for PDB code may not be available.", - ): + message: str = "There was an error downloading the file from the Protein Data Bank. PDB or format for PDB code may not be available.", + ) -> None: self.message = message super().__init__(self.message) -def download(code, format="cif", cache=None, database="rcsb"): +def download( + code: str, + format: str = "cif", + cache: Optional[Union[Path, str]] = None, + database: str = "rcsb", +) -> Union[Path, str, io.StringIO, io.BytesIO]: """ Downloads a structure from the specified protein data bank in the given format. @@ -47,18 +55,13 @@ def download(code, format="cif", cache=None, database="rcsb"): """ supported_formats = ["cif", "pdb", "bcif"] if format not in supported_formats: - raise ValueError(f"File format '{format}' not in: {supported_formats=}") + raise ValueError( + f"File format '{format}' not in: {supported_formats=}" + ) _is_binary = format in ["bcif"] filename = f"{code}.{format}" # create the cache location - if cache: - if not os.path.isdir(cache): - os.makedirs(cache) - - file = os.path.join(cache, filename) - else: - file = None # get the contents of the url try: @@ -69,22 +72,26 @@ def download(code, format="cif", cache=None, database="rcsb"): if _is_binary: content = r.content else: - content = r.text + content = r.text # type: ignore + + if cache: + if not os.path.isdir(cache): + os.makedirs(cache) - if file: + file = os.path.join(cache, filename) mode = "wb+" if _is_binary else "w+" with open(file, mode) as f: f.write(content) else: if _is_binary: - file = io.BytesIO(content) + file = io.BytesIO(content) # type: ignore else: - file = io.StringIO(content) + file = io.StringIO(content) # type: ignore return file -def _url(code, format, database="rcsb"): +def _url(code: str, format: str, database: str = "rcsb") -> str: "Get the URL for downloading the given file form a particular database." if database == "rcsb": @@ -100,12 +107,14 @@ def _url(code, format, database="rcsb"): # if database == "pdbe": # return f"https://www.ebi.ac.uk/pdbe/entry-files/download/{filename}" else: - ValueError(f"Database {database} not currently supported.") + raise ValueError(f"Database {database} not currently supported.") -def get_alphafold_url(code, format): +def get_alphafold_url(code: str, format: str) -> str: if format not in ["pdb", "cif", "bcif"]: - ValueError(f"Format {format} not currently supported from AlphaFold databse.") + ValueError( + f"Format {format} not currently supported from AlphaFold databse." + ) # we have to first query the database, then they'll return some JSON with a list # of metadata, some items of which will be the URLs for the computed models @@ -115,5 +124,4 @@ def get_alphafold_url(code, format): response = requests.get(url) print(f"{response=}") data = response.json()[0] - # return data[f'{format}Url'] - return data[f"{format}Url"] + return str(data[f"{format}Url"]) diff --git a/molecularnodes/io/wwpdb.py b/molecularnodes/io/wwpdb.py index 1f2c3afb..8a5e4cfb 100644 --- a/molecularnodes/io/wwpdb.py +++ b/molecularnodes/io/wwpdb.py @@ -1,27 +1,34 @@ from pathlib import Path +from typing import Optional, Union, Set +from pathlib import Path import bpy -from . import parse +from .parse import PDB, CIF, BCIF +from .parse.molecule import Molecule from .retrieve import FileDownloadPDBError, download def fetch( - pdb_code, - style="spheres", - centre="", - del_solvent=True, - cache_dir=None, - build_assembly=False, - format="bcif", -): + pdb_code: str, + style: Optional[str] = "spheres", + centre: str = "", + del_solvent: bool = True, + cache_dir: Optional[Union[Path, str]] = None, + build_assembly: bool = False, + format: str = "bcif", +) -> Molecule: if build_assembly: centre = "" file_path = download(code=pdb_code, format=format, cache=cache_dir) - parsers = {"pdb": parse.PDB, "cif": parse.CIF, "bcif": parse.BCIF} - molecule = parsers[format](file_path=file_path) + if format == "pdb": + molecule = PDB(file_path) + elif format == "cif": + molecule = CIF(file_path) + elif format == "bcif": + molecule = BCIF(file_path) model = molecule.create_model( name=pdb_code, @@ -62,7 +69,11 @@ def fetch( name="Format", description="Format to download as from the PDB", items=( - ("bcif", ".bcif", "Binary compressed .cif file, fastest for downloading"), + ( + "bcif", + ".bcif", + "Binary compressed .cif file, fastest for downloading", + ), ("cif", ".cif", "The new standard of .cif / .mmcif"), ("pdb", ".pdb", "The classic (and depcrecated) PDB format"), ), @@ -78,7 +89,7 @@ class MN_OT_Import_wwPDB(bpy.types.Operator): bl_description = "Download and open a structure from the Protein Data Bank" bl_options = {"REGISTER", "UNDO"} - def execute(self, context): + def execute(self, context: bpy.types.Context) -> Set[str]: scene = context.scene pdb_code = scene.MN_pdb_code cache_dir = scene.MN_cache_dir @@ -115,7 +126,7 @@ def execute(self, context): return {"CANCELLED"} bpy.context.view_layer.objects.active = mol.object - self.report({"INFO"}, message=f"Imported '{pdb_code}' as {mol.object.name}") + self.report({"INFO"}, message=f"Imported '{pdb_code}' as {mol.name}") return {"FINISHED"} @@ -123,7 +134,9 @@ def execute(self, context): # the UI for the panel, which will display the operator and the properties -def panel(layout, scene): +def panel( + layout: bpy.types.UILayout, scene: bpy.types.Scene +) -> bpy.types.UILayout: layout.label(text="Download from PDB", icon="IMPORT") layout.separator() diff --git a/molecularnodes/props.py b/molecularnodes/props.py index 316c6fa6..f3a267a5 100644 --- a/molecularnodes/props.py +++ b/molecularnodes/props.py @@ -1,13 +1,13 @@ import bpy +from bpy.props import BoolProperty, StringProperty, EnumProperty, IntProperty - -bpy.types.Scene.MN_import_centre = bpy.props.BoolProperty( +bpy.types.Scene.MN_import_centre = BoolProperty( name="Centre Structure", description="Move the imported Molecule on the World Origin", default=False, ) -bpy.types.Scene.MN_centre_type = bpy.props.EnumProperty( +bpy.types.Scene.MN_centre_type = EnumProperty( name="Method", default="mass", items=( @@ -26,21 +26,21 @@ ), ) -bpy.types.Scene.MN_import_del_solvent = bpy.props.BoolProperty( +bpy.types.Scene.MN_import_del_solvent = BoolProperty( name="Remove Solvent", description="Delete the solvent from the structure on import", default=True, ) -bpy.types.Scene.MN_import_panel_selection = bpy.props.IntProperty( +bpy.types.Scene.MN_import_panel_selection = IntProperty( name="MN_import_panel_selection", description="Import Panel Selection", subtype="NONE", default=0, ) -bpy.types.Scene.MN_import_build_assembly = bpy.props.BoolProperty( +bpy.types.Scene.MN_import_build_assembly = BoolProperty( name="Build Assembly", default=False ) -bpy.types.Scene.MN_import_node_setup = bpy.props.BoolProperty( +bpy.types.Scene.MN_import_node_setup = BoolProperty( name="Setup Nodes", default=True, description="Create and set up a Geometry Nodes tree on import", @@ -48,19 +48,19 @@ class MolecularNodesObjectProperties(bpy.types.PropertyGroup): - subframes: bpy.props.IntProperty( + subframes: IntProperty( # type: ignore name="Subframes", description="Number of subframes to interpolate for MD trajectories", default=0, - ) # type: ignore - molecule_type: bpy.props.StringProperty( + ) + molecule_type: StringProperty( # type: ignore name="Molecular Type", description="How the file was imported, dictating how MN interacts with it", default="", - ) # type: ignore - pdb_code: bpy.props.StringProperty( + ) + pdb_code: StringProperty( # type: ignore name="PDB", description="PDB code used to download this structure", maxlen=4, options={"HIDDEN"}, - ) # type: ignore + ) diff --git a/molecularnodes/utils.py b/molecularnodes/utils.py index b78c42c3..76af938c 100644 --- a/molecularnodes/utils.py +++ b/molecularnodes/utils.py @@ -142,7 +142,9 @@ def _install_template(filepath, subfolder="", overwrite=True): traceback.print_exc() return {"CANCELLED"} except zipfile.BadZipFile: - print("Bad zip file: The file is not a zip file or it is corrupted.") + print( + "Bad zip file: The file is not a zip file or it is corrupted." + ) traceback.print_exc() return {"CANCELLED"} @@ -152,7 +154,9 @@ def _install_template(filepath, subfolder="", overwrite=True): _module_filesystem_remove(path_app_templates, f) else: for f in file_to_extract_root: - path_dest = os.path.join(path_app_templates, os.path.basename(f)) + path_dest = os.path.join( + path_app_templates, os.path.basename(f) + ) if os.path.exists(path_dest): # self.report({'WARNING'}, tip_("File already installed to %r\n") % path_dest) return {"CANCELLED"} @@ -199,7 +203,7 @@ def _install_template(filepath, subfolder="", overwrite=True): ] -def array_quaternions_from_dict(transforms_dict): +def array_quaternions_from_dict(transforms_dict: dict) -> np.ndarray: n_transforms = 0 for assembly in transforms_dict.values(): for transform in assembly: diff --git a/pyproject.toml b/pyproject.toml index d16126f2..73235a08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ documentation = "https://bradyajohnston.github.io/MolecularNodes" [tool.poetry.dependencies] python = "~=3.11.0" -bpy = "~=4.1" +# bpy = "~=4.1" MDAnalysis = "~=2.7.0" biotite = "==0.40.0" mrcfile = "==1.4.3" @@ -26,6 +26,7 @@ pytest-cov = "*" syrupy = "*" quartodoc = "*" scipy = "*" +mypy = "*" [build-system] @@ -33,3 +34,42 @@ requires = ["poetry-core>=1.1.0"] build-backend = "poetry.core.masonry.api" [tool.setuptools_scm] + +[tool.mypy] +strict = true +ignore_missing_imports = true + +[tool.ruff] +# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. +select = ["E", "F"] +ignore = [] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] +unfixable = [] + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv"] + +line-length = 79 + diff --git a/tests/python.py b/tests/python.py index 4e428813..ee62c8cf 100644 --- a/tests/python.py +++ b/tests/python.py @@ -6,7 +6,7 @@ argv = argv[argv.index("--") + 1 :] -def main(): +def main() -> None: python = os.path.realpath(sys.executable) subprocess.run([python] + argv) diff --git a/tests/run.py b/tests/run.py index 2279ad82..cc184bde 100644 --- a/tests/run.py +++ b/tests/run.py @@ -10,15 +10,15 @@ # /Applications/Blender.app/Contents/MacOS/Blender -b -P tests/run.py -- . -k test_color_lookup_supplied -def main(): +def main() -> None: # run the test suite, and we have to manually return the result value if non-zero # value is returned for a failing test if len(argv) == 0: result = pytest.main() else: result = pytest.main(argv) - if result.value != 0: - sys.exit(result.value) + if result != 0: + sys.exit(result) if __name__ == "__main__": diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 8f1a71a1..6353267f 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -15,9 +15,12 @@ def test_node_name_format(): - assert mn.blender.nodes.format_node_name("MN_style_cartoon") == "Style Cartoon" assert ( - mn.blender.nodes.format_node_name("MN_dna_double_helix") == "DNA Double Helix" + mn.blender.nodes.format_node_name("MN_style_cartoon") == "Style Cartoon" + ) + assert ( + mn.blender.nodes.format_node_name("MN_dna_double_helix") + == "DNA Double Helix" ) assert ( mn.blender.nodes.format_node_name("MN_topo_vector_angle") @@ -29,12 +32,16 @@ def test_get_nodes(): bob = mn.io.fetch("4ozs", style="spheres").object assert ( - nodes.get_nodes_last_output(bob.modifiers["MolecularNodes"].node_group)[0].name + nodes.get_nodes_last_output(bob.modifiers["MolecularNodes"].node_group)[ + 0 + ].name == "MN_style_spheres" ) nodes.realize_instances(bob) assert ( - nodes.get_nodes_last_output(bob.modifiers["MolecularNodes"].node_group)[0].name + nodes.get_nodes_last_output(bob.modifiers["MolecularNodes"].node_group)[ + 0 + ].name == "Realize Instances" ) assert nodes.get_style_node(bob).name == "MN_style_spheres" @@ -42,7 +49,9 @@ def test_get_nodes(): bob2 = mn.io.fetch("1cd3", style="cartoon", build_assembly=True).object assert ( - nodes.get_nodes_last_output(bob2.modifiers["MolecularNodes"].node_group)[0].name + nodes.get_nodes_last_output( + bob2.modifiers["MolecularNodes"].node_group + )[0].name == "MN_assembly_1cd3" ) assert nodes.get_style_node(bob2).name == "MN_style_cartoon" @@ -62,10 +71,14 @@ def test_selection(): @pytest.mark.parametrize("code", codes) @pytest.mark.parametrize("attribute", ["chain_id", "entity_id"]) -def test_selection_working(snapshot_custom: NumpySnapshotExtension, attribute, code): +def test_selection_working( + snapshot_custom: NumpySnapshotExtension, attribute, code +): mol = mn.io.fetch(code, style="ribbon", cache_dir=data_dir).object group = mol.modifiers["MolecularNodes"].node_group - node_sel = nodes.add_selection(group, mol.name, mol[f"{attribute}s"], attribute) + node_sel = nodes.add_selection( + group, mol.name, mol[f"{attribute}s"], attribute + ) n = len(node_sel.inputs) @@ -90,13 +103,17 @@ def test_color_custom(snapshot_custom: NumpySnapshotExtension, code, attribute): ) group = mol.modifiers["MolecularNodes"].node_group node_col = mn.blender.nodes.add_custom(group, group_col.name, [0, -200]) - group.links.new(node_col.outputs[0], group.nodes["MN_color_set"].inputs["Color"]) + group.links.new( + node_col.outputs[0], group.nodes["MN_color_set"].inputs["Color"] + ) assert snapshot_custom == sample_attribute(mol, "Color", n=50) def test_custom_resid_selection(): - node = mn.blender.nodes.resid_multiple_selection("new_node", "1, 5, 10-20, 40-100") + node = mn.blender.nodes.resid_multiple_selection( + "new_node", "1, 5, 10-20, 40-100" + ) numbers = [1, 5, 10, 20, 40, 100] assert len(nodes.outputs(node)) == 2 counter = 0 @@ -110,7 +127,9 @@ def test_op_custom_color(): mol = mn.io.load(data_dir / "1cd3.cif").object mol.select_set(True) group = mn.blender.nodes.custom_iswitch( - name=f"MN_color_chain_{mol.name}", iter_list=mol["chain_ids"], dtype="RGBA" + name=f"MN_color_chain_{mol.name}", + iter_list=mol["chain_ids"], + dtype="RGBA", ) assert group @@ -143,11 +162,15 @@ def test_color_lookup_supplied(): def test_color_chain(snapshot_custom: NumpySnapshotExtension): mol = mn.io.load(data_dir / "1cd3.cif", style="cartoon").object group_col = mn.blender.nodes.custom_iswitch( - name=f"MN_color_chain_{mol.name}", iter_list=mol["chain_ids"], dtype="RGBA" + name=f"MN_color_chain_{mol.name}", + iter_list=mol["chain_ids"], + dtype="RGBA", ) group = mol.modifiers["MolecularNodes"].node_group node_col = mn.blender.nodes.add_custom(group, group_col.name, [0, -200]) - group.links.new(node_col.outputs[0], group.nodes["MN_color_set"].inputs["Color"]) + group.links.new( + node_col.outputs[0], group.nodes["MN_color_set"].inputs["Color"] + ) assert snapshot_custom == sample_attribute(mol, "Color") @@ -162,7 +185,9 @@ def test_color_entity(snapshot_custom: NumpySnapshotExtension): ) group = mol.modifiers["MolecularNodes"].node_group node_col = mn.blender.nodes.add_custom(group, group_col.name, [0, -200]) - group.links.new(node_col.outputs[0], group.nodes["MN_color_set"].inputs["Color"]) + group.links.new( + node_col.outputs[0], group.nodes["MN_color_set"].inputs["Color"] + ) assert snapshot_custom == sample_attribute(mol, "Color") @@ -183,14 +208,18 @@ def test_change_style(): for style in ["ribbon", "cartoon", "presets", "ball_and_stick", "surface"]: style_node_1 = nodes.get_style_node(model) - links_in_1 = [link.from_socket.name for link in get_links(style_node_1.inputs)] + links_in_1 = [ + link.from_socket.name for link in get_links(style_node_1.inputs) + ] links_out_1 = [ link.from_socket.name for link in get_links(style_node_1.outputs) ] nodes.change_style_node(model, style) style_node_2 = nodes.get_style_node(model) - links_in_2 = [link.from_socket.name for link in get_links(style_node_2.inputs)] + links_in_2 = [ + link.from_socket.name for link in get_links(style_node_2.inputs) + ] links_out_2 = [ link.from_socket.name for link in get_links(style_node_2.outputs) ] @@ -205,7 +234,8 @@ def test_node_topology(snapshot_custom: NumpySnapshotExtension): group = nodes.get_mod(mol).node_group group.links.new( - group.nodes["Group Input"].outputs[0], group.nodes["Group Output"].inputs[0] + group.nodes["Group Input"].outputs[0], + group.nodes["Group Output"].inputs[0], ) node_att = group.nodes.new("GeometryNodeStoreNamedAttribute") node_att.inputs[2].default_value = "test_attribute" @@ -256,7 +286,8 @@ def test_compute_backbone(snapshot_custom: NumpySnapshotExtension): group = nodes.get_mod(mol).node_group group.links.new( - group.nodes["Group Input"].outputs[0], group.nodes["Group Output"].inputs[0] + group.nodes["Group Input"].outputs[0], + group.nodes["Group Output"].inputs[0], ) node_att = group.nodes.new("GeometryNodeStoreNamedAttribute") node_att.inputs[2].default_value = "test_attribute" @@ -311,7 +342,7 @@ def test_compute_backbone(snapshot_custom: NumpySnapshotExtension): def test_topo_bonds(): mol = mn.io.fetch("1BNA", del_solvent=True, style=None).object - group = nodes.get_mod(mol).node_group = nodes.new_group() + group = nodes.get_mod(mol).node_group = nodes.new_tree() # add the node that will break bonds, set the cutoff to 0 node_break = nodes.add_custom(group, "MN_topo_bonds_break") diff --git a/tests/test_select.py b/tests/test_select.py index 29c05975..4114aff9 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -5,13 +5,6 @@ import pytest -def create_debug_group(name="MolecularNodesDebugGroup"): - group = nodes.new_group(name=name, fallback=False) - info = group.nodes.new("GeometryNodeObjectInfo") - group.links.new(info.outputs["Geometry"], group.nodes["Group Output"].inputs[0]) - return group - - def evaluate(object): object.update_tag() dg = bpy.context.evaluated_depsgraph_get() @@ -32,7 +25,7 @@ def test_select_multiple_residues(selection): mn.blender.obj.set_attribute(object, "res_id", np.arange(n_atoms) + 1) mod = nodes.get_mod(object) - group = nodes.new_group(fallback=False) + group = nodes.new_tree(fallback=False) mod.node_group = group sep = group.nodes.new("GeometryNodeSeparateGeometry") nodes.insert_last_node(group, sep)