diff --git a/CHANGELOG.md b/CHANGELOG.md index c8b3c62b7..dc0f4e80b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ - The document index client now correctly URL-encodes document names in its queries. - The `ArgillaEvaluator` not properly supports `dataset_name` - Update broken README links to Read The Docs + - The `evaluation` tutorial contained a broken multi-label classify example. This was fixed. ### Deprecations ... diff --git a/src/documentation/evaluation.ipynb b/src/documentation/evaluation.ipynb index 75eda7620..4289d6335 100644 --- a/src/documentation/evaluation.ipynb +++ b/src/documentation/evaluation.ipynb @@ -210,7 +210,7 @@ " )\n", " for item in data\n", " ],\n", - " dataset_name=\"tweet_topic_multi\",\n", + " dataset_name=\"tweet_topic_single\",\n", ")" ] }, @@ -279,10 +279,39 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "## Multilabel classification\n", + "\n", "As an alternative to the `PromptBasedClassify` we now gonne use the `EmbeddingBasedClassify` for multi label classifications.\n", - "In this case, we have to provide some example for each class.\n", + "In this case, we have to provide some example for each class and our examples need to contain a list of classes instead of a single class\n", "\n", - "We can even reuse our data repositories:" + "First, we will create a new dataset with more expected classes and a slightly different format:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = dataset_repository.create_dataset(\n", + " examples=[\n", + " Example(\n", + " input=ClassifyInput(chunk=TextChunk(item[\"text\"]), labels=all_labels),\n", + " expected_output=item[\n", + " \"label_name\"\n", + " ], # <- difference here, we take all labels instead of a single one\n", + " )\n", + " for item in data\n", + " ],\n", + " dataset_name=\"tweet_topic_multi\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can then create examples for our labels and put them into our `EmbeddingBasedClassify`" ] }, { @@ -308,8 +337,16 @@ " LabelWithExamples(name=name, examples=examples)\n", " for name, examples in build_labels_and_examples(all_data[25:]).items()\n", " ],\n", - ")\n", - "eval_logic = MultiLabelClassifyEvaluationLogic(threshold=0.60)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "eval_logic = MultiLabelClassifyEvaluationLogic(threshold=0.7)\n", "aggregation_logic = MultiLabelClassifyAggregationLogic()\n", "\n", "embedding_based_classify_runner = Runner(\n", @@ -352,15 +389,7 @@ " embedding_based_classify_evaluation_result.id\n", " )\n", ")\n", - "embedding_based_classify_aggregation_result.raise_on_evaluation_failure()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "embedding_based_classify_aggregation_result.raise_on_evaluation_failure()\n", "embedding_based_classify_aggregation_result.statistics.macro_avg" ] }, @@ -368,7 +397,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Apparently, our method has a great recall value, i.e. all relevant labels are retrieved. However, the low precision value indicates that we tend to falsely predict labels at times.\n", + "Apparently, our method has a great recall value, i.e. all relevant labels are retrieved. However, the low precision value indicates that we tend to falsely predict labels at times. This gives us an indicator that the threshold for our evaluation logic is probably too low with `0.6`\n", "\n", "Note, that the evaluation criteria for the multiple label approach are a lot harsher; we evaluate whether we correctly predict all labels & not just one of the correct ones!" ] @@ -402,7 +431,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/tests/use_cases/classify/test_classify.py b/tests/use_cases/classify/test_classify.py index 80fddae2a..882c61c04 100644 --- a/tests/use_cases/classify/test_classify.py +++ b/tests/use_cases/classify/test_classify.py @@ -19,6 +19,7 @@ InMemoryEvaluationRepository, InMemoryRunRepository, Runner, + SuccessfulExampleOutput, ) from intelligence_layer.examples import ( AggregatedLabelInfo, @@ -31,6 +32,7 @@ MultiLabelClassifyEvaluation, MultiLabelClassifyEvaluationLogic, MultiLabelClassifyOutput, + Probability, ) @@ -216,6 +218,69 @@ def classify_runner( ) +def test_multi_label_eval_logic_works_correctly() -> None: + threshold = 0.5 + eval_logic = MultiLabelClassifyEvaluationLogic(threshold=threshold) + tp = "aaaa" + tn = "bbbb" + fp = "cccc" + fn = "dddd" + expected_output: Sequence[str] = [tp, fn] + input_example = Example( + input=ClassifyInput(chunk=TextChunk("..."), labels=frozenset([tp, tn, fp, fn])), + expected_output=expected_output, + ) + input_output = SuccessfulExampleOutput( + run_id="", + example_id="", + output=MultiLabelClassifyOutput( + scores={ + tp: Probability(threshold + 0.1), + tn: Probability(threshold - 0.1), + fp: Probability(threshold + 0.1), + fn: Probability(threshold - 0.1), + } + ), + ) + res = eval_logic.do_evaluate(input_example, input_output) + assert tp in res.tp + assert tn in res.tn + assert fp in res.fp + assert fn in res.fn + + +def test_multi_label_eval_logic_works_if_everything_is_over_threshold() -> None: + threshold = 0.5 + eval_logic = MultiLabelClassifyEvaluationLogic(threshold=threshold) + tp = "aaaa" + tn = "bbbb" + fp = "cccc" + fn = "dddd" + expected_output: Sequence[str] = [tp, fn] + + input_example = Example( + input=ClassifyInput(chunk=TextChunk("..."), labels=frozenset([tp, tn, fp, fn])), + expected_output=expected_output, + ) + input_output = SuccessfulExampleOutput( + run_id="", + example_id="", + output=MultiLabelClassifyOutput( + scores={ + tp: Probability(threshold + 0.1), + tn: Probability(threshold + 0.1), + fp: Probability(threshold + 0.1), + fn: Probability(threshold + 0.1), + } + ), + ) + res = eval_logic.do_evaluate(input_example, input_output) + assert tp in res.tp + assert tn in res.fp + assert fp in res.fp + assert fn in res.tp + + def test_multi_label_classify_evaluator_single_example( single_entry_dataset_name: str, classify_evaluator: Evaluator[