diff --git a/pyscses/grid.py b/pyscses/grid.py index ddd657a..c190367 100644 --- a/pyscses/grid.py +++ b/pyscses/grid.py @@ -24,6 +24,7 @@ def phi_at_x(phi: np.ndarray, Returns: float: The electrostatic potential at the x coordinate with position [index]. + """ index = index_of_grid_at_x(coordinates, x) return phi[index] @@ -41,30 +42,32 @@ def energy_at_x(energy: np.ndarray, Returns: energy[index] (float): The segregation energy at the x coordinate with position [index]. - """ + """ index = index_of_grid_at_x(coordinates, x) return energy[index] def index_of_grid_at_x(coordinates: np.ndarray, x: float) -> int: - """ - Assigns each site x coordinate to a position on a regularly or irregularly spaced grid. - Returns the index of the grid point clostest to the value x + """Assigns each site x coordinate to a position on a regularly or irregularly spaced grid. + + Returns the index of the grid point closest to the value x Args: - coordinates (np.array): 1D grid of ordered numbers over a region. - x (float): Site x coordinate + coordinates (np.array): arraylike ordered list of x coordinates. + x (float): x coordinate. Returns: - int: Index of grid position closest to the site x coordinate. + int: Index of the coordinates array at the position closest to the input x coordinate. + """ return closest_index(coordinates, x) def closest_index(myList: Union[list[float], np.ndarray], myNumber: float) -> int: - """ - Assumes myList is sorted. Returns index of closest value to myNumber. + """Returns index of closest value to myNumber. + + Assumes myList is sorted. If two numbers are equally close, return the index of the smallest number. Args: @@ -73,6 +76,7 @@ def closest_index(myList: Union[list[float], np.ndarray], Returns: pos (int): Index of position of number in myList which is closest to myNumber. + """ myList = list(myList) pos = bisect_left(myList, myNumber) diff --git a/tests/test_grid.py b/tests/test_grid.py index aca9907..a868b87 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,5 +1,10 @@ import unittest from pyscses.grid import Grid, delta_x_from_grid +from pyscses.grid import (closest_index, + index_of_grid_at_x, + energy_at_x, + phi_at_x, + delta_x_from_grid) from pyscses.grid_point import GridPoint from pyscses.set_of_sites import SetOfSites from pyscses.site import Site @@ -7,7 +12,67 @@ from unittest.mock import Mock, MagicMock, patch, call import numpy as np -class TestGrid( unittest.TestCase ): +class TestGridFunctions(unittest.TestCase): + + def test_closest_index(self): + a = [1,3,5,7,9] + self.assertEqual(closest_index(a, 3.1), 1) + self.assertEqual(closest_index(a, 4.1), 2) + self.assertEqual(closest_index(a, 4.0), 1) + self.assertEqual(closest_index(a, 0.1), 0) + self.assertEqual(closest_index(a, 9.5), 4) + + def test_index_of_grid_at_x(self): + coordinates = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) + with patch('pyscses.grid.closest_index') as mock_closest_index: + mock_closest_index.side_effect = [0, 2, 3] + self.assertEqual(index_of_grid_at_x(coordinates=coordinates, + x=-1.5), 0) + self.assertEqual(index_of_grid_at_x(coordinates=coordinates, + x=0.1), 2) + self.assertEqual(index_of_grid_at_x(coordinates=coordinates, + x=1.5), 3) + + + def test_energy_at_x(self): + energy = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + coordinates = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) + with patch('pyscses.grid.index_of_grid_at_x') as mock_index_of_grid_at_x: + mock_index_of_grid_at_x.side_effect = [0, 2, 3] + self.assertEqual(energy_at_x(energy=energy, + coordinates=coordinates, + x=-1.5), 0.1) + self.assertEqual(energy_at_x(energy=energy, + coordinates=coordinates, + x=-0.1), 0.3) + self.assertEqual(energy_at_x(energy=energy, + coordinates=coordinates, + x=0.6), 0.4) + + def test_phi_at_x(self): + energy = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + coordinates = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) + with patch('pyscses.grid.index_of_grid_at_x') as mock_index_of_grid_at_x: + mock_index_of_grid_at_x.side_effect = [0, 2, 3] + self.assertEqual(phi_at_x(phi=energy, + coordinates=coordinates, + x=-1.5), 0.1) + self.assertEqual(phi_at_x(phi=energy, + coordinates=coordinates, + x=-0.1), 0.3) + self.assertEqual(phi_at_x(phi=energy, + coordinates=coordinates, + x=0.6), 0.4) + + def test_delta_x_from_grid(self): + coordinates = np.array([0.0, 1.0, 3.0, 6.0, 10.0]) + limits = (1.0, 4.0) + expected_delta_x = np.array([1.0, 1.5, 2.5, 3.5, 4.0]) + np.testing.assert_array_equal(delta_x_from_grid(coordinates=coordinates, + limits=limits), expected_delta_x) + + +class TestGrid(unittest.TestCase): @patch('pyscses.grid.index_of_grid_at_x') @patch('pyscses.grid.GridPoint') def test_grid_instance_is_initialised( self, mock_GridPoint, mock_index ):