-
Notifications
You must be signed in to change notification settings - Fork 7
/
config.py
70 lines (59 loc) · 2.44 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration of a LLMSR experiments
."""
from __future__ import annotations
import dataclasses
from typing import Type
import os
from llmsr import sampler
from llmsr import evaluator
@dataclasses.dataclass(frozen=True)
class ExperienceBufferConfig:
    """Settings for the multi-island experience buffer.

    Attributes:
        functions_per_prompt: How many previously sampled hypotheses are
            included in each prompt shown to the LLM.
        num_islands: Number of independent islands kept in the buffer to
            maintain population diversity.
        reset_period: Interval, in seconds, between resets of the weakest
            islands.
        cluster_sampling_temperature_init: Initial temperature of the softmax
            used when sampling clusters.
        cluster_sampling_temperature_period: Number of steps over which the
            cluster-sampling temperature decays.
    """

    functions_per_prompt: int = 2
    num_islands: int = 10
    # Weakest-island resets happen every four hours.
    reset_period: int = 4 * 60 * 60
    cluster_sampling_temperature_init: float = 0.1
    cluster_sampling_temperature_period: int = 30_000
@dataclasses.dataclass(frozen=True)
class Config:
    """Configuration for LLMSR experiments.

    Args:
        experience_buffer: Evolution multi-population (island) settings.
        num_samplers (int): Number of parallel samplers.
        num_evaluators (int): Number of parallel evaluators.
        samples_per_prompt (int): Number of hypotheses generated per prompt.
        evaluate_timeout_seconds (int): Timeout, in seconds, for evaluating a
            single hypothesis.
        use_api (bool): Whether to use a remote LLM API for sampling.
        api_model (str): Model identifier used when ``use_api`` is True.
    """

    # default_factory is required: a dataclass instance is not a valid
    # class-level default for a dataclass field.
    experience_buffer: ExperienceBufferConfig = dataclasses.field(
        default_factory=ExperienceBufferConfig
    )
    num_samplers: int = 1
    num_evaluators: int = 1
    samples_per_prompt: int = 4
    evaluate_timeout_seconds: int = 30
    use_api: bool = False
    api_model: str = "gpt-3.5-turbo"
@dataclasses.dataclass()
class ClassConfig:
    """Concrete implementation classes to be instantiated by the pipeline.

    Attributes:
        llm_class: Class object (subclass of ``sampler.LLM``) providing the
            language-model sampling implementation.
        sandbox_class: Class object (subclass of ``evaluator.Sandbox``)
            providing the hypothesis-evaluation sandbox implementation.
    """

    llm_class: Type[sampler.LLM]
    sandbox_class: Type[evaluator.Sandbox]