fix: broken evaluation tutorial
Task: IL-438
NiklasKoehneckeAA committed May 30, 2024
1 parent f698c83 commit 111968a
Showing 3 changed files with 111 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -30,6 +30,7 @@
- The document index client now correctly URL-encodes document names in its queries.
- The `ArgillaEvaluator` now properly supports `dataset_name`
- Update broken README links to Read The Docs
- The `evaluation` tutorial contained a broken multi-label classify example. This was fixed.

### Deprecations
...
61 changes: 45 additions & 16 deletions src/documentation/evaluation.ipynb
@@ -210,7 +210,7 @@
" )\n",
" for item in data\n",
" ],\n",
" dataset_name=\"tweet_topic_multi\",\n",
" dataset_name=\"tweet_topic_single\",\n",
")"
]
},
@@ -279,10 +279,39 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multilabel classification\n",
"\n",
"As an alternative to the `PromptBasedClassify` we now gonne use the `EmbeddingBasedClassify` for multi label classifications.\n",
"In this case, we have to provide some example for each class.\n",
"In this case, we have to provide some example for each class and our examples need to contain a list of classes instead of a single class\n",
"\n",
"We can even reuse our data repositories:"
"First, we will create a new dataset with more expected classes and a slightly different format:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset_repository.create_dataset(\n",
" examples=[\n",
" Example(\n",
" input=ClassifyInput(chunk=TextChunk(item[\"text\"]), labels=all_labels),\n",
" expected_output=item[\n",
" \"label_name\"\n",
" ], # <- difference here, we take all labels instead of a single one\n",
" )\n",
" for item in data\n",
" ],\n",
" dataset_name=\"tweet_topic_multi\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then create examples for our labels and put them into our `EmbeddingBasedClassify`"
]
},
{
@@ -308,8 +337,16 @@
" LabelWithExamples(name=name, examples=examples)\n",
" for name, examples in build_labels_and_examples(all_data[25:]).items()\n",
" ],\n",
")\n",
"eval_logic = MultiLabelClassifyEvaluationLogic(threshold=0.60)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_logic = MultiLabelClassifyEvaluationLogic(threshold=0.7)\n",
"aggregation_logic = MultiLabelClassifyAggregationLogic()\n",
"\n",
"embedding_based_classify_runner = Runner(\n",
@@ -352,23 +389,15 @@
" embedding_based_classify_evaluation_result.id\n",
" )\n",
")\n",
"embedding_based_classify_aggregation_result.raise_on_evaluation_failure()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embedding_based_classify_aggregation_result.raise_on_evaluation_failure()\n",
"embedding_based_classify_aggregation_result.statistics.macro_avg"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Apparently, our method has a great recall value, i.e. all relevant labels are retrieved. However, the low precision value indicates that we tend to falsely predict labels at times.\n",
"Apparently, our method has a great recall value, i.e. all relevant labels are retrieved. However, the low precision value indicates that we tend to falsely predict labels at times. This gives us an indicator that the threshold for our evaluation logic is probably too low with `0.6`\n",
"\n",
"Note, that the evaluation criteria for the multiple label approach are a lot harsher; we evaluate whether we correctly predict all labels & not just one of the correct ones!"
]
@@ -402,7 +431,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.12.2"
}
},
"nbformat": 4,
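The notebook's recall/precision discussion comes down to how the evaluation threshold turns per-label scores into predicted labels. Below is a minimal sketch of that trade-off, using invented scores and plain set arithmetic rather than the intelligence-layer API; the labels and numbers are illustrative only.

```python
# Illustrative only: made-up scores, not the tutorial's data or the library's code.
# A label counts as predicted when its score reaches the threshold, so a lower
# threshold predicts more labels (higher recall, lower precision).
scores = {"news": 0.75, "sports": 0.68, "politics": 0.65, "music": 0.30}
expected = {"news", "sports"}

for threshold in (0.6, 0.7):
    predicted = {label for label, score in scores.items() if score >= threshold}
    true_positives = predicted & expected
    precision = len(true_positives) / len(predicted) if predicted else 0.0
    recall = len(true_positives) / len(expected) if expected else 0.0
    print(f"threshold={threshold}: predicted={sorted(predicted)}, "
          f"precision={precision:.2f}, recall={recall:.2f}")

# threshold=0.6: precision 0.67, recall 1.00 (politics is a false positive)
# threshold=0.7: precision 1.00, recall 0.50 (sports is missed)
```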
65 changes: 65 additions & 0 deletions tests/use_cases/classify/test_classify.py
@@ -19,6 +19,7 @@
InMemoryEvaluationRepository,
InMemoryRunRepository,
Runner,
SuccessfulExampleOutput,
)
from intelligence_layer.examples import (
AggregatedLabelInfo,
@@ -31,6 +32,7 @@
MultiLabelClassifyEvaluation,
MultiLabelClassifyEvaluationLogic,
MultiLabelClassifyOutput,
Probability,
)


@@ -216,6 +218,69 @@ def classify_runner(
)


def test_multi_label_eval_logic_works_correctly() -> None:
threshold = 0.5
eval_logic = MultiLabelClassifyEvaluationLogic(threshold=threshold)
tp = "aaaa"
tn = "bbbb"
fp = "cccc"
fn = "dddd"
expected_output: Sequence[str] = [tp, fn]
input_example = Example(
input=ClassifyInput(chunk=TextChunk("..."), labels=frozenset([tp, tn, fp, fn])),
expected_output=expected_output,
)
input_output = SuccessfulExampleOutput(
run_id="",
example_id="",
output=MultiLabelClassifyOutput(
scores={
tp: Probability(threshold + 0.1),
tn: Probability(threshold - 0.1),
fp: Probability(threshold + 0.1),
fn: Probability(threshold - 0.1),
}
),
)
res = eval_logic.do_evaluate(input_example, input_output)
assert tp in res.tp
assert tn in res.tn
assert fp in res.fp
assert fn in res.fn


def test_multi_label_eval_logic_works_if_everything_is_over_threshold() -> None:
threshold = 0.5
eval_logic = MultiLabelClassifyEvaluationLogic(threshold=threshold)
tp = "aaaa"
tn = "bbbb"
fp = "cccc"
fn = "dddd"
expected_output: Sequence[str] = [tp, fn]

input_example = Example(
input=ClassifyInput(chunk=TextChunk("..."), labels=frozenset([tp, tn, fp, fn])),
expected_output=expected_output,
)
input_output = SuccessfulExampleOutput(
run_id="",
example_id="",
output=MultiLabelClassifyOutput(
scores={
tp: Probability(threshold + 0.1),
tn: Probability(threshold + 0.1),
fp: Probability(threshold + 0.1),
fn: Probability(threshold + 0.1),
}
),
)
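# With every score above the threshold, every label is predicted, so the expected
# labels end up in res.tp and the unexpected ones in res.fp (nothing in tn or fn).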
res = eval_logic.do_evaluate(input_example, input_output)
assert tp in res.tp
assert tn in res.fp
assert fp in res.fp
assert fn in res.tp


def test_multi_label_classify_evaluator_single_example(
single_entry_dataset_name: str,
classify_evaluator: Evaluator[
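Taken together, the two new tests pin down the set arithmetic behind `MultiLabelClassifyEvaluationLogic.do_evaluate`. The following is a rough, self-contained paraphrase of that bookkeeping — a sketch under the assumption that thresholding plus set operations is all that happens, not the library's actual implementation; `classify_outcomes` is a hypothetical helper.

```python
# Sketch of the tp/tn/fp/fn bookkeeping the tests above assert; the real
# MultiLabelClassifyEvaluationLogic may differ in details.
def classify_outcomes(
    scores: dict[str, float], expected: set[str], threshold: float
) -> dict[str, set[str]]:
    predicted = {label for label, score in scores.items() if score >= threshold}
    return {
        "tp": predicted & expected,                # predicted and expected
        "fp": predicted - expected,                # predicted but not expected
        "fn": expected - predicted,                # expected but missed
        "tn": set(scores) - predicted - expected,  # neither predicted nor expected
    }


# Mirrors the second test: every score is above the threshold, so every label is
# predicted; "dddd" becomes a true positive and "bbbb" a false positive.
res = classify_outcomes(
    {"aaaa": 0.6, "bbbb": 0.6, "cccc": 0.6, "dddd": 0.6},
    expected={"aaaa", "dddd"},
    threshold=0.5,
)
assert res["tp"] == {"aaaa", "dddd"}
assert res["fp"] == {"bbbb", "cccc"}
assert not res["fn"] and not res["tn"]
```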
