diff --git a/TCutility/results/result.py b/TCutility/results/result.py index 14bcbf0f..83b81167 100644 --- a/TCutility/results/result.py +++ b/TCutility/results/result.py @@ -4,7 +4,29 @@ class Result(dict): '''Class used for storing results from AMS calculations. The class is functionally a dictionary, but allows dot notation to access variables in the dictionary. The class works case-insensitively, but will retain the case of the key when it was first set.''' + + def __call__(self): + '''Calling of a dictionary subclass should not be possible, instead we raise an error with information about the key and method that were attempted to be called.''' + head, method = '.'.join(self.get_parent_tree().split('.')[:-1]), self.get_parent_tree().split('.')[-1] + raise AttributeError(f'Tried to call method "{method}" from {head}, but {head} is empty') + + def __str__(self): + '''Override str method to prevent printing of hidden keys. You can still print them if you call repr instead of str.''' + return '{' + ', '.join([f'{key}: {str(val)}' for key, val in self.items()]) + '}' + + def items(self): + '''We override the items method from dict in order to skip certain keys. We want to hide keys starting and ending + with dunders, as they should not be exposed to the user. + ''' + return [(key, self[key]) for key in self.keys()] + + def keys(self): + original_keys = super().keys() + return [key for key in original_keys if not (key.startswith('__') and key.endswith('__'))] + def __getitem__(self, key): + if key.startswith('__') and key.endswith('__'): + return None self.__set_empty(key) val = super().__getitem__(self.__get_case(key)) return val @@ -22,14 +44,37 @@ def __setattr__(self, key, val): self.__setitem__(key, val) def __contains__(self, key): - # Custom method to check if the key is defined in this object, case-insensitive. - return key.lower() in [key_.lower() for key_ in self.keys()] + # Custom method to check if the key is defined in this object and is also non-empty, case-insensitive. + return key.lower() in [key_.lower() for key_ in self.keys()] and self[key] + + def __hash__(self): + '''Hashing of a dictionary subclass should not be possible, instead we should raise an error to let the user know + that they made a mistake. Also give information of which key was being read. + ''' + raise KeyError(f'Tried to hash {self.get_parent_tree()}, but it is empty') + + def __bool__(self): + '''Make sure that keys starting and ending in "__" are skipped''' + return len([key for key in self.keys() if not (key.startswith('__') and key.endswith('__'))]) > 0 + + def get_parent_tree(self): + '''Method to get the path from this object to the parent object. The result is presented in a formatted string''' + # every parent except the top-most parent has defined a __parent__ attribute + if '__parent__' not in self: + return 'Head' + # iteratively build the tree using the __name__ attribute. + parent_names = self.__parent__.get_parent_tree() + parent_names += '.' + self.__name__ + return parent_names def __set_empty(self, key): # This function checks if the key has been set. # If it has not, we create a new Result object and set it at the desired key - if key not in self: + if self.__get_case(key) not in self.keys(): val = Result() + # we also keep track of the parent of this object and also the name it was assigned to for later bookkeeping + val.__parent__ = self + val.__name__ = key self.__setitem__(key, val) def __get_case(self, key): @@ -40,12 +85,32 @@ def __get_case(self, key): return key_ return key + def prune(self): + '''Remove empty paths of this object. + ''' + items = list(self.items()) + for key, value in items: + try: + value.prune() + except AttributeError: + pass + + if not value: + del self[key] + if __name__ == '__main__': ret = Result() - ret.aDf.x = {'a': 1, 'b': 2} - print(ret.adf.x.a) + # print(ret.adf) + # print(dict(ret.adf)) + # print(bool(ret.adf)) + ret.adf.x = {'a': 1, 'b': 2} + # ret.adf.system.atoms = [] + # ret.adf.system.atoms.append('test 1 2 3') - ret.ADF.y = 2345 - ret.dftb.z = 'hello' - print(ret) + # test_dict[ret.adf.y] = 20 + # ret.adf.y.join() + # {ret.test: 123} + # ret.__name__ = 'testname' + # print(ret.__name__) + print(repr(ret)) diff --git a/test/test_result_class.py b/test/test_result_class.py new file mode 100644 index 00000000..4afb6808 --- /dev/null +++ b/test/test_result_class.py @@ -0,0 +1,70 @@ +from TCutility import results + +def test_init(): + res = results.result.Result() # noqa F841 + + +def test_assign(): + res = results.result.Result() + res.a = 10 + assert res.a == 10 + + +def test_assign2(): + res = results.result.Result() + res.a.b = 10 + assert res.a.b == 10 + + +def test_assign3(): + res = results.result.Result() + res.a.b = 10 + res.a.c = 20 + assert res.a.b == 10 and res.a.c == 20 + + +def test_assign4(): + res = results.result.Result() + res.a.lst = [] + res.a.lst.append(10) + assert len(res.a.lst) == 1 + + +def test_contains(): + res = results.result.Result() + res.a = 10 + assert 'a' in res + + +def test_contains2(): + res = results.result.Result() + res.a.b = 10 + assert 'a' in res + + +def test_contains3(): + res = results.result.Result() + res.a.b = 10 + assert 'b' in res.a + + +def test_contains_neg(): + res = results.result.Result() + assert 'a' not in res + + +def test_contains_neg2(): + res = results.result.Result() + assert 'b' not in res.a + + +def test_contains_neg3(): + res = results.result.Result() + res.a.b = 10 + assert 'b' not in res + + +def test_contains_neg4(): + res = results.result.Result() + res.a + assert 'a' not in res diff --git a/test/test_results.py b/test/test_results.py index 4f375ab3..a90550ef 100644 --- a/test/test_results.py +++ b/test/test_results.py @@ -35,7 +35,6 @@ def test_LOT4() -> None: assert res.level.summary == 'M06-2X/QZ4P' - def test_sections() -> None: res = results.read(j('test', 'fixtures', 'DFT_EDA')) - assert all(section in res for section in ['adf', 'engine', 'ams_version', 'history', 'is_multijob', 'molecule', 'status', 'timing']) + assert all(section in res for section in ['adf', 'engine', 'ams_version', 'is_multijob', 'molecule', 'status', 'timing'])