Skip to content

Commit

Permalink
Updated type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Sep 3, 2024
1 parent ed75c0a commit 839b431
Showing 1 changed file with 63 additions and 62 deletions.
125 changes: 63 additions & 62 deletions pybrush/EstimatorInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
from pandas.api.types import is_float_dtype, is_bool_dtype, is_integer_dtype
from pybrush import Parameters, Dataset
from typing import Union, List, Dict

class EstimatorInterface():
"""
Expand Down Expand Up @@ -133,70 +134,70 @@ class EstimatorInterface():
"""

def __init__(self,
mode='classification',
pop_size=100,
max_gens=100,
max_time=-1,
max_stall=0,
verbosity=0,
max_depth=3,
max_size=20,
num_islands=5,
n_jobs=1,
mig_prob=0.05,
cx_prob= 1/7,
mutation_probs = {"point":1/6, "insert":1/6, "delete":1/6, "subtree":1/6,
mode: str = 'classification',
pop_size: int = 100,
max_gens: int = 100,
max_time: int = -1,
max_stall: int = 0,
verbosity: int = 0,
max_depth: int = 3,
max_size: int = 20,
num_islands: int = 5,
n_jobs: int = 1,
mig_prob: float = 0.05,
cx_prob: float = 1/7,
mutation_probs: Dict[str, float] = {"point":1/6, "insert":1/6, "delete":1/6, "subtree":1/6,
"toggle_weight_on":1/6, "toggle_weight_off":1/6},
functions: list[str]|dict[str,float] = {},
initialization="uniform",
algorithm="nsga2",
objectives=["error", "size"],
random_state=None,
logfile="",
save_population="",
load_population="",
shuffle_split=True,
bandit='dynamic_thompson',
weights_init=True,
val_from_arch=True,
use_arch=False,
scorer=None,
sel = "lexicase",
surv = "nsga2",
functions: Union[List[str], Dict[str, float]] = {},
initialization: str = "uniform",
objectives: List[str] = ["error", "size"],
scorer: str = None,
algorithm: str = "nsga2",
weights_init: bool = True,
validation_size: float = 0.0,
batch_size: float = 1.0
):
self.pop_size=pop_size
self.max_gens=max_gens
self.max_stall=max_stall
self.max_time=max_time
self.verbosity=verbosity
self.algorithm=algorithm
self.mode=mode
self.max_depth=max_depth
self.max_size=max_size
self.num_islands=num_islands
self.mig_prob=mig_prob
self.n_jobs=n_jobs
self.cx_prob=cx_prob
self.bandit=bandit
self.logfile=logfile
self.save_population=save_population
self.load_population=load_population
self.mutation_probs=mutation_probs
self.val_from_arch=val_from_arch # TODO: val from arch implementation (in cpp side)
self.use_arch=use_arch
self.functions=functions
self.objectives=objectives
self.scorer=scorer
self.shuffle_split=shuffle_split
self.initialization=initialization
self.random_state=random_state
self.batch_size=batch_size
self.sel=sel
self.surv=surv
self.weights_init=weights_init
self.validation_size=validation_size
use_arch: bool = False,
val_from_arch: bool = True,
batch_size: float = 1.0,
sel: str = "lexicase",
surv: str = "nsga2",
save_population: str = "",
load_population: str = "",
bandit: str = 'dynamic_thompson',
shuffle_split: bool = True,
logfile: str = "",
random_state: int = None
) -> None:
self.pop_size = pop_size
self.max_gens = max_gens
self.max_stall = max_stall
self.max_time = max_time
self.verbosity = verbosity
self.algorithm = algorithm
self.mode = mode
self.max_depth = max_depth
self.max_size = max_size
self.num_islands = num_islands
self.mig_prob = mig_prob
self.n_jobs = n_jobs
self.cx_prob = cx_prob
self.bandit = bandit
self.logfile = logfile
self.save_population = save_population
self.load_population = load_population
self.mutation_probs = mutation_probs
self.val_from_arch = val_from_arch
self.use_arch = use_arch
self.functions = functions
self.objectives = objectives
self.scorer = scorer
self.shuffle_split = shuffle_split
self.initialization = initialization
self.random_state = random_state
self.batch_size = batch_size
self.sel = sel
self.surv = surv
self.weights_init = weights_init
self.validation_size = validation_size

def _wrap_parameters(self, **extra_kwargs):
"""
Expand Down

0 comments on commit 839b431

Please sign in to comment.