diff --git a/python-src/fastpdb/__init__.py b/python-src/fastpdb/__init__.py index da190b7..900640f 100644 --- a/python-src/fastpdb/__init__.py +++ b/python-src/fastpdb/__init__.py @@ -131,7 +131,7 @@ def get_structure(self, model=None, altloc="first", extra_fields=None, include_b # Interpret uint32 arrays as unicode arrays chain_id = np.frombuffer(chain_id, dtype="U4") ins_code = np.frombuffer(ins_code, dtype="U1") - res_name = np.frombuffer(res_name, dtype="U3") + res_name = np.frombuffer(res_name, dtype="U5") atom_name = np.frombuffer(atom_name, dtype="U6") element = np.frombuffer(element, dtype="U2") altloc_id = np.frombuffer(altloc_id, dtype="U1") @@ -254,11 +254,11 @@ def set_structure(self, atoms): # Write 'ATOM' and 'MODEL' records # Convert Unicode arrays into uint32 arrays for usage in Rust - chain_id = np.frombuffer(atoms.chain_id, dtype=np.uint32).reshape(-1, 4) - ins_code = np.frombuffer(atoms.ins_code, dtype=np.uint32).reshape(-1, 1) - res_name = np.frombuffer(atoms.res_name, dtype=np.uint32).reshape(-1, 3) - atom_name = np.frombuffer(atoms.atom_name, dtype=np.uint32).reshape(-1, 6) - element = np.frombuffer(atoms.element, dtype=np.uint32).reshape(-1, 2) + chain_id = _convert_unicode_to_uint32(atoms.chain_id) + ins_code = _convert_unicode_to_uint32(atoms.ins_code) + res_name = _convert_unicode_to_uint32(atoms.res_name) + atom_name = _convert_unicode_to_uint32(atoms.atom_name) + element = _convert_unicode_to_uint32(atoms.element) categories = atoms.get_annotation_categories() atom_id = atoms.atom_id if "atom_id" in categories else None @@ -320,3 +320,18 @@ def _index_models_and_atoms(self): self._pdb_file.index_models_and_atoms() self._model_start_i = self._pdb_file.model_start_i self._atom_line_i = self._pdb_file.atom_line_i + + +def _convert_unicode_to_uint32(array): + """ + Convert a unicode string array into a 2D uint32 array. + + The second dimension corresponds to the character position within a + string. + """ + dtype = array.dtype + if not np.issubdtype(dtype, np.str_): + raise TypeError("Expected unicode string array") + length = array.shape[0] + n_char = dtype.itemsize // 4 + return np.frombuffer(array, dtype=np.uint32).reshape(length, n_char) \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 1a8f58f..ed157a7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,7 +43,7 @@ struct PDBFile { #[pymethods] impl PDBFile { - + /// Create an new [`PDBFile`]. /// The lines of text are given to `lines`. /// An empty `Vec` represents and empty PDB file. @@ -99,7 +99,7 @@ impl PDBFile { }) } - + /// Get the number of models contained in the file. fn get_model_count(&self) -> usize { self.model_start_i.len() @@ -180,7 +180,7 @@ impl PDBFile { /// Unicode *NumPy* arrays are represented by 2D *NumPy* arrays with /// `uint32` *dtype*. /// The returned tuple contains the following annotations in the given order: - /// + /// /// - `chain_id` /// - `res_id` /// - `ins_code` @@ -211,16 +211,16 @@ impl PDBFile { Option>>, Option>>)> { let atom_line_i: Vec = self.get_atom_indices(model)?; - + let mut chain_id: Array = Array::zeros((atom_line_i.len(), 4)); let mut res_id: Array = Array::zeros(atom_line_i.len()); let mut ins_code: Array = Array::zeros((atom_line_i.len(), 1)); - let mut res_name: Array = Array::zeros((atom_line_i.len(), 3)); + let mut res_name: Array = Array::zeros((atom_line_i.len(), 5)); let mut hetero: Array = Array::default(atom_line_i.len()); let mut atom_name: Array = Array::zeros((atom_line_i.len(), 6)); let mut element: Array = Array::zeros((atom_line_i.len(), 2)); let mut altloc_id: Array = Array::zeros((atom_line_i.len(), 1)); - + let mut atom_id: Array; if include_atom_id { atom_id = Array::zeros(atom_line_i.len()); @@ -250,7 +250,7 @@ impl PDBFile { else { charge = Array::zeros(0); } - + // Iterate over ATOM and HETATM records to write annotation arrays for (atom_i, line_i) in atom_line_i.iter().enumerate() { let line = &self.lines[*line_i]; @@ -293,7 +293,7 @@ impl PDBFile { } } } - + Python::with_gil(|py| { Ok(( PyArray::from_array(py, &chain_id ).to_owned(), @@ -326,7 +326,7 @@ impl PDBFile { atom_id_to_index.insert(*id, i as u32); } }); - + // Cannot preemptively determine number of bonds // -> Memory allocation for all bonds is not possible // -> No benefit in finding 'CONECT' record lines prior to iteration @@ -409,11 +409,11 @@ impl PDBFile { // This procedure aims to increase the performance is repetitive formatting is omitted let mut prefix: Vec = Vec::new(); let mut suffix: Vec = Vec::new(); - + for i in 0..coord.shape()[1] { let element_i = parse_string_from_array(&element, i)?; let atom_name_i = parse_string_from_array(&atom_name, i)?; - + prefix.push(format!( "{:6}{:>5} {:4} {:>3} {:1}{:>4}{:1} ", if hetero[i] { "HETATM" } else { "ATOM" }, @@ -449,7 +449,7 @@ impl PDBFile { if is_multi_model { self.lines.push(format!("MODEL {:>8}", model_i+1)); } - for atom_i in 0..coord.shape()[1] { + for atom_i in 0..coord.shape()[1] { let coord_string = format!( "{:>8.3}{:>8.3}{:>8.3}", coord[[model_i, atom_i, 0]], @@ -471,13 +471,13 @@ impl PDBFile { /// array containing indices pointing to bonded atoms in the `AtomArray`. /// The `atom_id` annotation array is required to map the atom IDs in `CONECT` records /// to atom indices. - fn write_bonds(&mut self, + fn write_bonds(&mut self, bonds: Py>, atom_id: Py>) -> PyResult<()> { Python::with_gil(|py| { let bonds = bonds.as_ref(py).to_owned_array(); let atom_id = atom_id.as_ref(py).to_owned_array(); - + for (center_i, bonded_indices) in bonds.outer_iter().enumerate() { let mut n_added: usize = 0; let mut line: String = String::new(); @@ -552,7 +552,7 @@ impl PDBFile { } Ok(CoordArray::Single(coord)) }, - + None => { let length = self.get_model_length()?; let mut coord = Array::zeros((self.atom_line_i.len(), 3)); @@ -591,7 +591,7 @@ impl PDBFile { self.model_start_i.len(), model ))); } - + // Get the start and stop line index for this model index let (model_start, model_stop) = match model_i.cmp(&(self.model_start_i.len() as isize - 1)){ Ordering::Less => ( @@ -600,13 +600,13 @@ impl PDBFile { ), // Last model -> Model reaches to end of file Ordering::Equal => ( - self.model_start_i[model_i as usize], + self.model_start_i[model_i as usize], self.lines.len() ), // This case was excluded above _ => panic!("This branch should not be reached") }; - + // Get the atom records within these line boundaries Ok( self.atom_line_i.iter().copied() @@ -614,7 +614,7 @@ impl PDBFile { .collect() ) } - + /// Get the number of atoms in each model of the PDB file. /// A `PyErr` is returned if the number of atoms per model differ from each other. @@ -635,7 +635,7 @@ impl PDBFile { )); } }; } - + match length { None => panic!("Length cannot be 'None'"), Some(l) => Ok(l)