Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial inference data export script #2975

Merged
merged 11 commits into from
May 5, 2023
3 changes: 2 additions & 1 deletion backend/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
LabelValues,
)
from oasst_shared.schemas.protocol import TextLabel
from oasst_shared.utils import Anonymizer
from sqlmodel import Session, func


Expand Down Expand Up @@ -182,7 +183,7 @@ def export_trees(
anonymizer_seed: Optional[str] = None,
) -> None:
message_labels: dict[UUID, LabelValues] = {}
anonymizer = tree_export.Anonymizer(anonymizer_seed) if anonymizer_seed else None
anonymizer = Anonymizer(anonymizer_seed) if anonymizer_seed else None
if user_id:
# when filtering by user we don't have complete message trees, export as list
result = fetch_tree_messages_and_avg_labels(
Expand Down
28 changes: 1 addition & 27 deletions backend/oasst_backend/utils/tree_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import contextlib
import gzip
import hashlib
import json
import sys
import uuid
Expand All @@ -21,32 +20,7 @@
ExportMessageTree,
LabelValues,
)


def sha256_hash(key: str, seed: int) -> str:
    """Return the hexadecimal SHA-256 digest of ``key`` immediately followed by ``seed``.

    Both values are formatted into one string (no separator) and encoded as
    UTF-8 before hashing.
    """
    hasher = hashlib.sha256()
    hasher.update("{}{}".format(key, seed).encode("UTF-8"))
    return hasher.hexdigest()


class Anonymizer:
def __init__(self, seed, value_generator=lambda key, seed: sha256_hash(key, seed)):
self._map = {}
self._values = set()
self._seed = seed
self._gen = value_generator

def __getitem__(self, key):
if key not in self._map:
new_value = self._gen(key, self._seed)
if new_value in self._values:
raise ValueError("Generated value already exists. Try a different seed or value generator.")
self._map[key] = new_value
self._values.add(new_value)
return self._map[key]

def anonymize(self, collection: str, key: str | None) -> str | None:
if key is None:
return None
return self[f"{collection}:{key}"]
from oasst_shared.utils import Anonymizer


def prepare_export_message_node(
Expand Down
7 changes: 5 additions & 2 deletions docker/inference/Dockerfile.server
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,18 @@ ARG APP_USER


COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared ${SHARED_LIBS_BASE}/oasst-shared
COPY --chown="${APP_USER}:${APP_USER}" ./oasst-data ${SHARED_LIBS_BASE}/oasst-data

USER root
RUN --mount=type=cache,target=/var/cache/pip,from=build \
pip install \
--cache-dir=/var/cache/pip \
-e "${SHARED_LIBS_BASE}/oasst-shared"
-e "${SHARED_LIBS_BASE}/oasst-shared" "${SHARED_LIBS_BASE}/oasst-data"
USER ${APP_USER}


VOLUME [ "${APP_BASE}/lib/oasst-shared" ]
VOLUME [ "${APP_BASE}/lib/oasst-data" ]


CMD uvicorn main:app --reload --host 0.0.0.0 --port "${PORT}"
Expand All @@ -85,11 +87,12 @@ ARG APP_USER


COPY --chown="${APP_USER}:${APP_USER}" ./oasst-shared /tmp/lib/oasst-shared
COPY --chown="${APP_USER}:${APP_USER}" ./oasst-data /tmp/lib/oasst-data
RUN --mount=type=cache,target=/var/cache/pip,from=dev \
pip install \
--cache-dir=/var/cache/pip \
--target="${APP_LIBS}" \
/tmp/lib/oasst-shared
/tmp/lib/oasst-shared /tmp/lib/oasst-data

COPY --chown="${APP_USER}:${APP_USER}" ./inference/server/server_main.sh /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]
242 changes: 242 additions & 0 deletions inference/server/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import argparse
import asyncio
import contextlib
import gzip
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import TextIO

import sqlalchemy
import sqlmodel
from fastapi.encoders import jsonable_encoder
from oasst_data import (
ExportMessageEvent,
ExportMessageEventReport,
ExportMessageEventScore,
ExportMessageNode,
ExportMessageTree,
)
from oasst_inference_server import deps
from oasst_inference_server.database import AsyncSession
from oasst_inference_server.models import DbChat, DbMessage
from oasst_shared.utils import Anonymizer


# see https://stackoverflow.com/questions/17602878/how-to-handle-both-with-open-and-sys-stdout-nicely
@contextlib.contextmanager
def smart_open(filename: str = None) -> TextIO:
    """Yield a writable text handle for ``filename``, falling back to ``sys.stdout``.

    A falsy ``filename`` or the literal ``"-"`` selects standard output, which
    is left open on exit. Any other value is opened for UTF-8 text writing and
    closed when the context ends.
    """
    wants_stdout = not filename or filename == "-"
    handle = sys.stdout if wants_stdout else open(filename, "wt", encoding="UTF-8")

    try:
        yield handle
    finally:
        # Never close the process-wide stdout stream.
        if handle is not sys.stdout:
            handle.close()


def maybe_anonymize(anonymizer: Anonymizer | None, collection: str, key: str) -> str:
    """Return ``key`` anonymized within ``collection``, or unchanged without an anonymizer.

    Args:
        anonymizer: Object providing ``anonymize(collection, key)``; a falsy
            value disables anonymization.
        collection: Namespace passed through to the anonymizer.
        key: Value to anonymize.
    """
    if not anonymizer:
        return key
    return anonymizer.anonymize(collection, key)


def prepare_export_events(
    chat: DbChat,
    message: DbMessage,
    anonymizer: Anonymizer | None = None,
) -> dict[str, list[ExportMessageEvent]]:
    """Collect the report and score events attached to ``message``.

    Args:
        chat: Chat the message belongs to; its ``user_id`` is recorded on
            every event.
        message: Message whose ``reports`` and ``score`` are exported.
        anonymizer: Optional anonymizer applied to the user ID.

    Returns:
        A dict keyed by event type. ``"report"`` is present only when the
        message has reports and ``"score"`` only when the score is truthy;
        the dict is empty when neither applies.
    """
    # This has to be a dict: it is filled by string key below. (It used to be
    # initialised as a list, which raised TypeError on the first assignment.)
    export_events: dict[str, list[ExportMessageEvent]] = {}

    if message.reports:
        export_events["report"] = [
            ExportMessageEventReport(
                user_id=maybe_anonymize(anonymizer, "user", str(chat.user_id)),
                report_type=str(db_report.report_type),
                reason=db_report.reason,
            )
            for db_report in message.reports
        ]

    if message.score:
        # A single score is wrapped in a list so every event type has the same shape.
        export_events["score"] = [
            ExportMessageEventScore(
                user_id=maybe_anonymize(anonymizer, "user", str(chat.user_id)),
                score=message.score,
            )
        ]

    return export_events


def prepare_export_message_tree(
    chat: DbChat,
    anonymizer: Anonymizer | None = None,
) -> ExportMessageTree:
    """Build the export tree for ``chat`` from its messages that have content.

    Messages are converted to export nodes, grouped by parent ID, and linked
    into a tree rooted at the first node whose parent ID is ``None``.

    Args:
        chat: Chat whose messages are exported.
        anonymizer: Optional anonymizer forwarded to node preparation.

    Returns:
        An ``ExportMessageTree`` whose ID is the chat ID and whose tree state
        is ``"not_applicable"``.
    """
    children_of = defaultdict(list)
    for db_message in chat.messages:
        # Exclude messages without content (e.g. work still in progress or aborted)
        if not db_message.content:
            continue
        export_node = prepare_export_message_node(chat, db_message, anonymizer=anonymizer)
        children_of[export_node.parent_id].append(export_node)

    def assign_replies(node: ExportMessageNode) -> ExportMessageNode:
        # Attach the direct children, then descend into each of them.
        node.replies = children_of[node.message_id]
        for reply in node.replies:
            assign_replies(reply)
        return node

    root = assign_replies(children_of[None][0])
    return ExportMessageTree(message_tree_id=str(chat.id), tree_state="not_applicable", prompt=root)


def prepare_export_message_node(
    chat: DbChat,
    message: DbMessage,
    anonymizer: Anonymizer | None = None,
) -> ExportMessageNode:
    """Convert one database message into an ``ExportMessageNode``.

    Args:
        chat: Chat the message belongs to; supplies the user ID.
        message: Message to export.
        anonymizer: Optional anonymizer applied to message, parent and user IDs.

    Returns:
        The export node. ``model_name`` is taken from the message's worker
        config when one is present, otherwise it is ``None``.
    """
    worker_config = message.worker_config
    model_name = worker_config.model_config.model_id if worker_config else None

    events: dict[str, list[ExportMessageEvent]] = prepare_export_events(chat, message, anonymizer=anonymizer)

    anonymized_message_id = maybe_anonymize(anonymizer, "message", message.id)
    anonymized_parent_id = maybe_anonymize(anonymizer, "message", message.parent_id)
    anonymized_user_id = maybe_anonymize(anonymizer, "user", chat.user_id)

    return ExportMessageNode(
        message_id=anonymized_message_id,
        parent_id=anonymized_parent_id,
        user_id=anonymized_user_id,
        created_date=message.created_at,
        text=message.content,
        role=message.role,
        # Chat prompts are human-written, responses are synthetic
        synthetic=(message.role == "assistant"),
        model_name=model_name,
        events=events,
    )


def write_messages_to_file(
    file: Path,
    chats: list[DbChat],
    use_compression: bool = True,
    write_trees: bool = True,
    anonymizer: Anonymizer | None = None,
) -> None:
    """Write ``chats`` as JSON lines to ``file``, or to stdout when uncompressed.

    Args:
        file: Output path. Required when ``use_compression`` is true; without
            compression a falsy value sends the output to stdout.
        chats: Chats to export.
        use_compression: Write gzip-compressed text when true.
        write_trees: Write one message tree per chat when true, otherwise one
            line per message.
        anonymizer: Optional anonymizer forwarded to the export helpers.

    Raises:
        RuntimeError: if compression is requested without a file name.

    Chats with no content-bearing messages are skipped in both modes.
    """
    out_buff: TextIO

    if use_compression:
        if not file:
            raise RuntimeError("File name must be specified when using compression.")
        out_buff = gzip.open(file, "wt", encoding="UTF-8")
    else:
        out_buff = smart_open(file)

    with out_buff as f:
        for chat in chats:
            # Exclude messages without content (e.g. work still in progress or aborted)
            messages = [message for message in chat.messages if message.content]
            if not messages:
                # Nothing to export. A tree cannot be built without a root
                # prompt, and attempting it would abort the whole export
                # with an IndexError.
                continue

            if write_trees:
                records = [prepare_export_message_tree(chat, anonymizer=anonymizer)]
            else:
                records = (
                    prepare_export_message_node(chat, message, anonymizer=anonymizer) for message in messages
                )

            for record in records:
                json.dump(jsonable_encoder(record, exclude_none=True), f)
                f.write("\n")


async def fetch_eligible_chats(session_generator) -> list[DbChat]:
    """Fetch chats which are not opted out of data collection.

    Args:
        session_generator: Callable returning an async context manager that
            yields a database session.

    Returns:
        All chats with ``allow_data_use`` set, loaded with a wildcard
        ``joinedload`` and de-duplicated.
    """
    session: AsyncSession
    async with session_generator() as session:
        eager_load_everything = sqlalchemy.orm.joinedload("*")
        statement = sqlmodel.select(DbChat).filter(DbChat.allow_data_use).options(eager_load_everything)
        result = await session.exec(statement)
        # unique() collapses the duplicate parent rows produced by the joined loads.
        return result.unique().all()


def export_chats(
    session_generator,
    export_path: Path,
    use_compression: bool = True,
    write_trees: bool = True,
    anonymizer_seed: int | str | None = None,
) -> None:
    """Fetch all eligible chats and write them to ``export_path``.

    Args:
        session_generator: Callable returning an async context manager that
            yields a database session.
        export_path: Destination passed to ``write_messages_to_file``.
        use_compression: Write gzip-compressed output when true.
        write_trees: Write message trees when true, flat messages otherwise.
        anonymizer_seed: Seed for the anonymizer. ``None`` or an empty string
            disables anonymization; any other value, including ``0``,
            enables it.
    """
    eligible_chats: list[DbChat] = asyncio.run(fetch_eligible_chats(session_generator))

    # Compare explicitly rather than by truthiness: the CLI parses the seed as
    # an int, and a seed of 0 must not silently turn anonymization off.
    anonymizer = None
    if anonymizer_seed is not None and anonymizer_seed != "":
        anonymizer = Anonymizer(anonymizer_seed)

    write_messages_to_file(
        export_path,
        eligible_chats,
        write_trees=write_trees,
        use_compression=use_compression,
        anonymizer=anonymizer,
    )


def parse_args() -> argparse.Namespace:
    """Parse the export script's command-line options from ``sys.argv``.

    Returns:
        A namespace with ``export_file``, ``no_compression``, ``write_flat``
        and ``anonymizer_seed``.
    """
    option_specs = (
        (
            "--export-file",
            {
                "type": str,
                "help": "Name of file to export chats to. If not provided, output will be sent to STDOUT",
            },
        ),
        (
            "--no-compression",
            {
                "action": "store_true",
                "help": "Disable compression when writing to file.",
            },
        ),
        (
            "--write-flat",
            {
                "action": "store_true",
                "help": "Write chats as individual messages rather than trees.",
            },
        ),
        (
            "--anonymizer-seed",
            {
                "type": int,
                "help": "Seed for the anonymizer. If not specified, no anonymization will be performed.",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    # TODO: filters: reported, score, user ID, chat ID, etc
    # TODO: date range?
    return cli.parse_args()


def main():
    """Script entry point: parse the CLI options and run the chat export."""
    cli_args = parse_args()

    # No export file means no path; the writer then decides where output goes.
    export_path = None
    if cli_args.export_file:
        export_path = Path(cli_args.export_file)

    compress = not cli_args.no_compression
    as_trees = not cli_args.write_flat

    export_chats(
        deps.manual_create_session,
        export_path,
        use_compression=compress,
        write_trees=as_trees,
        anonymizer_seed=cli_args.anonymizer_seed,
    )


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions oasst-data/oasst_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
ExportMessageEventEmoji,
ExportMessageEventRanking,
ExportMessageEventRating,
ExportMessageEventReport,
ExportMessageEventScore,
ExportMessageNode,
ExportMessageTree,
LabelAvgValue,
Expand All @@ -19,6 +21,8 @@
"ExportMessageEventEmoji",
"ExportMessageEventRating",
"ExportMessageEventRanking",
"ExportMessageEventReport",
"ExportMessageEventScore",
"ExportMessageNode",
"ExportMessageTree",
"read_message_trees",
Expand Down
13 changes: 12 additions & 1 deletion oasst-data/oasst_data/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from typing import Literal, Optional

from pydantic import BaseModel
from pydantic import BaseModel, conint


class LabelAvgValue(BaseModel):
Expand Down Expand Up @@ -38,6 +38,17 @@ class ExportMessageEventRanking(ExportMessageEvent):
not_rankable: Optional[bool] # flawed, factually incorrect or unacceptable


class ExportMessageEventReport(ExportMessageEvent):
    """Export event of type ``"report"``, carrying a report type and a free-text reason."""

    # Fixed to "report" so the event kind can be recognised from the data.
    type: Literal["report"] = "report"
    report_type: str
    reason: str


class ExportMessageEventScore(ExportMessageEvent):
    """Export event of type ``"score"``, carrying an integer score."""

    # Fixed to "score" so the event kind can be recognised from the data.
    type: Literal["score"] = "score"
    # Constrained integer: only values from -1 to 1 (inclusive) validate.
    score: conint(ge=-1, le=1)


class DetoxifyRating(BaseModel):
toxicity: float
severe_toxicity: float
Expand Down
Loading