Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support human-in-the-loop via trial.user_attrs #411

Merged
merged 24 commits into from
Apr 13, 2023

Conversation

knshnb
Copy link
Member

@knshnb knshnb commented Mar 6, 2023

Contributor License Agreement

This repository (optuna-dashboard) and Goptuna share common code.
This pull request may therefore be ported to Goptuna.
Make sure that you understand the consequences concerning licenses and check the box below if you accept the term before creating this pull request.

  • I agree this patch may be ported to Goptuna by other Goptuna contributors.

Reference Issues/PRs

register_objective_form_widgets was supported by #370. This PR adds register_user_attr_form_widgets, which registers the human evaluations to trial.user_attrs. With this, we can specify an objective function that combines multiple human evaluations (and suggested parameters).

What does this implement/fix? Explain your changes.

Example:

import os
import textwrap
import threading
import time
from wsgiref.simple_server import make_server

import optuna
from optuna_dashboard import ChoiceWidget
from optuna_dashboard import register_user_attr_form_widgets
from optuna_dashboard import save_note
from optuna_dashboard import SliderWidget
from optuna_dashboard import wsgi
from optuna_dashboard.artifact import upload_artifact
from optuna_dashboard.artifact.file_system import FileSystemBackend
from PIL import Image


base_path = os.path.join(os.path.dirname(__file__), "artifact")
artifact_backend = FileSystemBackend(base_path=base_path)


def wait_value(trial: optuna.Trial, key: str) -> float:
    while True:
        # Workaround for _CachedStorage and _cached_frozen_trial
        storage = trial.study._storage
        if isinstance(storage, optuna.storages._CachedStorage):
            storage = storage._backend
        user_attrs = storage.get_trial_user_attrs(trial._trial_id)
        if key in user_attrs:
            return user_attrs[key]
        time.sleep(5)


def objective(trial: optuna.Trial) -> float:
    # Ask new parameters
    r = trial.suggest_int("r", 0, 255)
    g = trial.suggest_int("g", 0, 255)
    b = trial.suggest_int("b", 0, 255)

    # Generate image
    image_path = f"tmp/sample-{trial.number}.png"
    image = Image.new("RGB", (320, 240), color=(r, g, b))
    image.save(image_path)

    # Upload Artifact
    artifact_id = upload_artifact(artifact_backend, trial, image_path)

    # Save Note
    note = textwrap.dedent(
        f"""\
    ## Trial {trial.number}
    
    ![generated-image](/artifacts/{trial.study._study_id}/{trial._trial_id}/{artifact_id})
    """
    )
    save_note(trial, note)

    # Wait for hitl values and calculate objective score
    choice_val = wait_value(trial, "hitl/choice")
    slider_val = wait_value(trial, "hitl/slider")
    return slider_val if choice_val == -1 else -1


def start_preferential_optimization(study: optuna.Study) -> None:
    register_user_attr_form_widgets(
        study,
        widgets=[
            ChoiceWidget(
                choices=["Good 👍", "Bad 👎"],
                values=[-1, 1],
                description="Please input your score!",
                user_attr_key="hitl/choice",
            ),
            SliderWidget(
                min=1,
                max=10,
                step=1,
                description="Higher is better.",
                user_attr_key="hitl/slider",
            ),
        ],
    )
    study.optimize(objective, n_jobs=3)


def main() -> None:
    if not os.path.exists(base_path):
        os.mkdir(base_path)

    storage = optuna.storages.RDBStorage("sqlite:///db.sqlite3")
    study = optuna.create_study(
        study_name="Human-in-the-loop Optimization user_attrs",
        storage=storage,
        load_if_exists=True,
        direction="maximize",
    )

    # Start Dashboard server on background
    app = wsgi(storage, artifact_backend=artifact_backend)
    httpd = make_server("127.0.0.1", 8080, app)
    thread = threading.Thread(target=httpd.serve_forever)
    thread.start()

    # Run optimize loop
    try:
        start_preferential_optimization(study)
    except KeyboardInterrupt:
        httpd.shutdown()
        httpd.server_close()
        thread.join()


if __name__ == "__main__":
    main()

@c-bata c-bata self-assigned this Mar 17, 2023
@knshnb knshnb marked this pull request as ready for review March 20, 2023 06:52
@knshnb
Copy link
Member Author

knshnb commented Mar 20, 2023

Let me work on refactoring (removing "objective" from the internal variable names, etc.) by follow-up PR.

optuna_dashboard/_app.py Outdated Show resolved Hide resolved
Copy link
Member

@c-bata c-bata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@knshnb Thank you for the update! Looks almost good to me. I left one suggestion though.

ObjectiveSliderWidget = SliderWidget
ObjectiveTextInputWidget = TextInputWidget
FORM_WIDGETS_KEY = "dashboard:form_widgets:v2"
FORM_WIDGETS_OUTPUT_TYPE_KEY = "dashboard:form_widgets_output_type:v2"
Copy link
Member

@c-bata c-bata Apr 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if FORM_WIDGETS_OUTPUT_TYPE_KEY could be included in FORM_WIDGETS_KEY as follows.

FormWidgetJSON = TypedDict(
    "FormWidgetJSON",
    {
        "output_type": Literal["objective", "user_attr"],
        "widgets": list[Union[ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON]]
    }
)

form_widgets: FormWidgetJSON = {
   "output_type": "user_attr",
    "widgets": [w.to_dict() for w in widgets],
}
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets)

Please see d0cb33e for details.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I refactored form_widgets_output_type according to your comment.

@knshnb
Copy link
Member Author

knshnb commented Apr 12, 2023

@c-bata Thanks for the review! I addressed your comment.
I realized that the values in ReadonlyObjectiveForm for register_user_attr_form_widgets are not displayed correctly, but would it be fine to discuss how to fix it as a follow-up PR?

c-bata added a commit to c-bata/optuna-dashboard that referenced this pull request Apr 13, 2023
Copy link
Member

@c-bata c-bata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@knshnb Thanks for the update!

I realized that the values in ReadonlyObjectiveForm for register_user_attr_form_widgets are not displayed correctly, but would it be fine to discuss how to fix it as a follow-up PR?

Sure! I think it's acceptable if the problem does not affect users who are not using register_user_attr_form_widgets().

Changes looks almost good to me. Though I have some minor additional review comments, I will address them on #431.
(See d6f3846 for details).

c-bata added a commit that referenced this pull request Apr 13, 2023
Follow-up #411 to refactor type definitions
@c-bata c-bata merged commit 031a1e1 into optuna:main Apr 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants