Skip to content

Commit

Permalink
feat: Making sure that the user provided y_cols correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
periklis91 committed Jun 13, 2024
1 parent 4d9bcdf commit ddcfd0a
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions jaqpotpy/datasets/molecular_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def __init__(self, df: pd.DataFrame = None, path: Optional[str] = None,
featurizer: Optional[MolecularFeaturizer] = None,
task:str = None
) -> None:


if not(isinstance(y_cols, str) or (isinstance(y_cols, list) and isinstance(y_cols[1], str))):
raise ValueError("y_cols must be provided and should be either a string or a list of strings")

super().__init__(df=df, path=path, y_cols=y_cols, x_cols=x_cols, task = task)

if isinstance(smiles_cols, str):
Expand All @@ -47,7 +50,7 @@ def __init__(self, df: pd.DataFrame = None, path: Optional[str] = None,
self.smiles_cols = None
self.smiles_cols_len = 0
else:
raise TypeError("smiles_cols must either be a string, a list of strings or a None.")
raise TypeError("smiles_cols must either be a string, a list of strings or None.")

self.featurizer = featurizer
self._featurizer_name = None
Expand All @@ -74,16 +77,22 @@ def x_cols_all(self, value):
def create(self):

if (isinstance(self.smiles_cols,list) and len(self.smiles_cols) == 1):
# The method featurize_dataframe needs self.smiles to be pd.Series
self.smiles = self._df[self.smiles_cols[0]]
descriptors = self.featurizer.featurize_dataframe(self.smiles)
elif isinstance(self.smiles_cols,str):
self.smiles = self._df[self.smiles_cols]
descriptors = self.featurizer.featurize_dataframe(self.smiles)
else:
elif:
featurized_dfs = [self.featurizer.featurize_dataframe(self._df[[col]]) for col in self.smiles_cols]
descriptors = pd.concat(featurized_dfs, axis=1)

else:
#Case where no smiles were provided
self.smiles = None
descriptors = []

self._y = self._df[self.y_cols]

if self.x_cols is None:
# Estimate x_cols by excluding y_cols and smiles_col
x_ext = pd.concat(self._df.drop(columns=self.y_cols + [self.smiles_cols]), descriptors)
Expand All @@ -97,25 +106,25 @@ def create(self):
self._df = pd.concat([self._x, self._y], axis = 1)

def __get_x__(self):
    """Return ``self._df[self._x]`` converted to a NumPy array.

    Reads the private ``_df`` attribute (the one ``create`` assigns);
    the earlier ``self.df`` lookup has been dropped so there is a single,
    reachable return.
    """
    # NOTE(review): this presumes ``self._x`` is a column selector (labels).
    # ``create`` appears to concatenate ``self._x`` as data, so confirm what
    # ``_x`` actually holds before relying on this accessor.
    return self._df[self._x].to_numpy()

def __get_y__(self):
    """Return ``self._df[self._y]`` converted to a NumPy array.

    Reads the private ``_df`` attribute (the one ``create`` assigns);
    the earlier ``self.df`` lookup has been dropped so there is a single,
    reachable return.
    """
    # NOTE(review): this presumes ``self._y`` is a column selector (labels).
    # ``create`` assigns ``self._y = self._df[self.y_cols]`` (data, not
    # labels), so confirm what ``_y`` holds when this accessor is called.
    return self._df[self._y].to_numpy()

def __get__(self, instance, owner):
    """Descriptor-style accessor.

    Returns ``self`` when ``instance`` is ``None``; otherwise looks up
    ``self._df`` as a key in ``instance.__dict__``.
    """
    if instance is None:
        return self
    # NOTE(review): ``self._df`` is used as the dictionary key here. If
    # ``_df`` is a DataFrame this lookup cannot succeed (DataFrames are
    # unhashable) — confirm whether an attribute *name* was intended.
    return instance.__dict__[self._df]

def __getitem__(self, idx):
    """Return the ``(x, y)`` pair at positional index ``idx``.

    Both parts are taken from ``self._df`` with ``.iloc[idx]``: the first
    from the ``self._x`` selection, the second from the ``self._y``
    selection. The stale ``self.df`` lookups and the commented-out debug
    prints have been removed.

    Returns:
        A tuple ``(selected_x, selected_y)`` of array-like values.
    """
    # NOTE(review): presumes ``_x`` / ``_y`` are column selectors; confirm
    # against what ``create`` stores in them.
    selected_x = self._df[self._x].iloc[idx].values
    selected_y = self._df[self._y].iloc[idx].to_numpy()
    return selected_x, selected_y

def __len__(self):
    """Return ``len(self._df)``.

    Uses the private ``_df`` attribute; the earlier ``self.df`` lookup has
    been dropped so there is a single, reachable return.
    """
    return len(self._df)

def __repr__(self) -> str:
return (f"{self.__class__.__name__}"
Expand Down

0 comments on commit ddcfd0a

Please sign in to comment.