diff --git a/jaqpotpy/datasets/molecular_datasets.py b/jaqpotpy/datasets/molecular_datasets.py index 1f7afc40..49825c87 100644 --- a/jaqpotpy/datasets/molecular_datasets.py +++ b/jaqpotpy/datasets/molecular_datasets.py @@ -34,7 +34,10 @@ def __init__(self, df: pd.DataFrame = None, path: Optional[str] = None, featurizer: Optional[MolecularFeaturizer] = None, task:str = None ) -> None: - + + if not(isinstance(y_cols, str) or (isinstance(y_cols, list) and isinstance(y_cols[1], str))): + raise ValueError("y_cols must be provided and should be either a string or a list of strings") + super().__init__(df=df, path=path, y_cols=y_cols, x_cols=x_cols, task = task) if isinstance(smiles_cols, str): @@ -47,7 +50,7 @@ def __init__(self, df: pd.DataFrame = None, path: Optional[str] = None, self.smiles_cols = None self.smiles_cols_len = 0 else: - raise TypeError("smiles_cols must either be a string, a list of strings or a None.") + raise TypeError("smiles_cols must either be a string, a list of strings or None.") self.featurizer = featurizer self._featurizer_name = None @@ -74,16 +77,22 @@ def x_cols_all(self, value): def create(self): if (isinstance(self.smiles_cols,list) and len(self.smiles_cols) == 1): + # The method featurize_dataframe needs self.smiles to be pd.Series self.smiles = self._df[self.smiles_cols[0]] descriptors = self.featurizer.featurize_dataframe(self.smiles) elif isinstance(self.smiles_cols,str): self.smiles = self._df[self.smiles_cols] descriptors = self.featurizer.featurize_dataframe(self.smiles) - else: + elif: featurized_dfs = [self.featurizer.featurize_dataframe(self._df[[col]]) for col in self.smiles_cols] descriptors = pd.concat(featurized_dfs, axis=1) - + else: + #Case where no smiles were provided + self.smiles = None + descriptors = [] + self._y = self._df[self.y_cols] + if self.x_cols is None: # Estimate x_cols by excluding y_cols and smiles_col x_ext = pd.concat(self._df.drop(columns=self.y_cols + [self.smiles_cols]), descriptors) @@ -97,25 +106,25 @@ def create(self): self._df = pd.concat([self._x, self._y], axis = 1) def __get_x__(self): - return self.df[self._x].to_numpy() + return self._df[self._x].to_numpy() def __get_y__(self): - return self.df[self._y].to_numpy() + return self._df[self._y].to_numpy() def __get__(self, instance, owner): if instance is None: return self - return instance.__dict__[self.df] + return instance.__dict__[self._df] def __getitem__(self, idx): # print(self.df[self.X].iloc[idx].values) # print(type(self.df[self.X].iloc[idx].values)) - selected_x = self.df[self._x].iloc[idx].values - selected_y = self.df[self._y].iloc[idx].to_numpy() + selected_x = self._df[self._x].iloc[idx].values + selected_y = self._df[self._y].iloc[idx].to_numpy() return selected_x, selected_y def __len__(self): - return len(self.df) + return len(self._df) def __repr__(self) -> str: return (f"{self.__class__.__name__}"