Commit

add basic logic and first tests for safety classifier
SebastianWolfschmidtAA authored and MartinAchtnerAA committed Apr 18, 2024
1 parent 6b7fbe1 commit 663ddb2
Showing 2 changed files with 109 additions and 0 deletions.
75 changes: 75 additions & 0 deletions src/intelligence_layer/core/safety_classifier.py
@@ -0,0 +1,75 @@
from enum import Enum
from typing import Sequence

from pydantic import BaseModel

from intelligence_layer.core import (
    CompleteInput,
    ControlModel,
    LuminousControlModel,
    Task,
    TaskSpan,
    TextChunk,
)


class UnsafeOutputFlag(Enum):
    HATESPEECH = "hatespeech"
    SEXUAL = "sexual"
    RELIGIOUS = "religious"
    MEDICAL = "medical"
    SELF_HARM = "self_harm"
    ILLEGAL_ACTIVITIES = "illegal_activities"
    VIOLENCE = "violence"
    INSULT = "insult"
    PROFANITY = "profanity"


class SafetyClassifyInput(BaseModel):
    text: TextChunk


class SafetyClassifyOutput(BaseModel):
    # An empty list means the text is safe; a flag is only included if the
    # classifier determines the text falls into that category.
    flags: Sequence[UnsafeOutputFlag]


class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
    def __init__(
        self,
        model: ControlModel | None = None,
    ) -> None:
        # Fall back to luminous-base-control when no model is provided.
        self._model = model or LuminousControlModel("luminous-base-control")

    def do_run(
        self, input: SafetyClassifyInput, task_span: TaskSpan
    ) -> SafetyClassifyOutput:
        instruction = """Your task is to act as a safety classifier for text. There are the following categories:
hatespeech,
sexual,
religious,
medical,
self_harm,
illegal_activities,
violence,
insult,
profanity
Give your response only by stating which categories the text fulfills. Return no text if the text is safe.
"""
        if not instruction:
            raise ValueError("Could not find `instruction` for safety classifier.")
        completion = self._model.complete(
            CompleteInput(
                prompt=self._model.to_instruct_prompt(instruction, input.text),
            ),
            task_span,
        )
        completion_str = str(completion.completions[0].completion)
        # The model is expected to answer with a comma-separated list of category
        # names; labels that do not match a known flag are silently dropped.
        potential_flags = completion_str.split(", ")
        flags = [
            UnsafeOutputFlag(flag.strip())
            for flag in potential_flags
            if hasattr(UnsafeOutputFlag, flag.strip().upper())
        ]
        return SafetyClassifyOutput(flags=flags)
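
The flag-parsing step at the end of do_run assumes the model answers with a plain comma-separated list of category names. A minimal, self-contained sketch of that mapping (the completion string below is hypothetical and stands in for a real model response):

from enum import Enum


class UnsafeOutputFlag(Enum):
    VIOLENCE = "violence"
    INSULT = "insult"


# Hypothetical completion; in the task it comes from ControlModel.complete().
completion_str = " violence, insult, made_up_category"

potential_flags = completion_str.split(", ")
# Labels that are not enum members (here "made_up_category") are dropped.
flags = [
    UnsafeOutputFlag(flag.strip())
    for flag in potential_flags
    if hasattr(UnsafeOutputFlag, flag.strip().upper())
]
print(flags)  # [<UnsafeOutputFlag.VIOLENCE: 'violence'>, <UnsafeOutputFlag.INSULT: 'insult'>]
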
34 changes: 34 additions & 0 deletions tests/core/test_safety_classifier.py
@@ -0,0 +1,34 @@
from pytest import fixture

from intelligence_layer.core import LuminousControlModel, NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)


@fixture
def safety_classifier(
    luminous_control_model: LuminousControlModel,
) -> SafetyClassifier:
    return SafetyClassifier(model=luminous_control_model)


def test_safety_classifier_returns_no_flags_for_safe_input(
    safety_classifier: SafetyClassifier,
) -> None:
    text = "This is a safe text"
    input = SafetyClassifyInput(text=TextChunk(text))

    output = safety_classifier.run(input, NoOpTracer())
    assert len(output.flags) == 0


def test_safety_classifier_returns_flags_for_unsafe_input(
    safety_classifier: SafetyClassifier,
) -> None:
    text = "I will kill you."
    input = SafetyClassifyInput(text=TextChunk(text))

    output = safety_classifier.run(input, NoOpTracer())
    assert len(output.flags) > 0
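
For reference, a minimal usage sketch outside of pytest, mirroring the test setup above; it assumes the client credentials that LuminousControlModel needs are configured in the environment:

from intelligence_layer.core import LuminousControlModel, NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)

# Same default model the classifier falls back to when none is passed.
classifier = SafetyClassifier(model=LuminousControlModel("luminous-base-control"))
output = classifier.run(
    SafetyClassifyInput(text=TextChunk("I will kill you.")), NoOpTracer()
)
print(output.flags)  # e.g. [UnsafeOutputFlag.VIOLENCE] for unsafe text, [] for safe text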
