Skip to content

Commit

Permalink
#5140 - Support access to tag sets in prompt template
Browse files Browse the repository at this point in the history
-Add tagset access to the ChatGPT recommender
  • Loading branch information
reckart committed Nov 6, 2024
1 parent 7fd1374 commit 7237323
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
package de.tudarmstadt.ukp.inception.recommendation.imls.chatgpt;

import static de.tudarmstadt.ukp.inception.recommendation.imls.support.llm.prompt.PromptContextGenerator.VAR_EXAMPLES;
import static de.tudarmstadt.ukp.inception.recommendation.imls.support.llm.prompt.PromptContextGenerator.VAR_TAGS;
import static de.tudarmstadt.ukp.inception.recommendation.imls.support.llm.prompt.PromptContextGenerator.getPromptContextGenerator;
import static de.tudarmstadt.ukp.inception.recommendation.imls.support.llm.response.ResponseExtractor.getResponseExtractor;
import static java.util.stream.Collectors.toMap;
import static org.apache.commons.lang3.exception.ExceptionUtils.getRootCauseMessage;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.Map;
import java.util.LinkedHashMap;
import java.util.Objects;

import org.apache.uima.cas.CAS;
import org.slf4j.Logger;
Expand All @@ -39,6 +42,7 @@
import de.tudarmstadt.ukp.inception.recommendation.imls.chatgpt.client.ResponseFormat;
import de.tudarmstadt.ukp.inception.recommendation.imls.support.llm.prompt.JinjaPromptRenderer;
import de.tudarmstadt.ukp.inception.rendering.model.Range;
import de.tudarmstadt.ukp.inception.schema.api.AnnotationSchemaService;
import de.tudarmstadt.ukp.inception.security.client.auth.apikey.ApiKeyAuthenticationTraits;
import de.tudarmstadt.ukp.inception.support.logging.LogMessage;

Expand All @@ -52,15 +56,17 @@ public class ChatGptRecommender
private final ChatGptRecommenderTraits traits;

private final ChatGptClient client;
private final AnnotationSchemaService schemaService;
private final JinjaPromptRenderer promptRenderer;

public ChatGptRecommender(Recommender aRecommender, ChatGptRecommenderTraits aTraits,
ChatGptClient aClient)
ChatGptClient aClient, AnnotationSchemaService aSchemaService)
{
super(aRecommender);

traits = aTraits;
client = aClient;
schemaService = aSchemaService;
promptRenderer = new JinjaPromptRenderer();
}

Expand All @@ -69,8 +75,18 @@ public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
var responseExtractor = getResponseExtractor(traits.getExtractionMode());
var globalBindings = new LinkedHashMap<String, Object>();
var examples = responseExtractor.generate(this, aCas, MAX_FEW_SHOT_EXAMPLES);
var globalBindings = Map.of(VAR_EXAMPLES, examples);
globalBindings.put(VAR_EXAMPLES, examples);

var tagset = getRecommender().getFeature().getTagset();
if (tagset != null) {
var tags = schemaService.listTags(tagset).stream() //
.collect(toMap( //
tag -> tag.getName(), //
tag -> Objects.toString(tag.getDescription(), "")));
globalBindings.put(VAR_TAGS, tags);
}

getPromptContextGenerator(traits.getPromptingMode())
.generate(this, aCas, aBegin, aEnd, globalBindings).forEach(promptContext -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactoryImplBase;
import de.tudarmstadt.ukp.inception.recommendation.imls.chatgpt.client.ChatGptClient;
import de.tudarmstadt.ukp.inception.schema.api.AnnotationSchemaService;
import de.tudarmstadt.ukp.inception.support.io.WatchedResourceFile;
import de.tudarmstadt.ukp.inception.support.yaml.YamlUtil;
import de.tudarmstadt.ukp.inception.ui.core.docanno.layer.DocumentMetadataLayerSupport;
Expand All @@ -53,12 +54,14 @@ public class ChatGptRecommenderFactory
private final static Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private final ChatGptClient client;
private final AnnotationSchemaService schemaService;

private WatchedResourceFile<ArrayList<Preset>> presets;

public ChatGptRecommenderFactory(ChatGptClient aClient)
public ChatGptRecommenderFactory(ChatGptClient aClient, AnnotationSchemaService aSchemaService)
{
client = aClient;
schemaService = aSchemaService;

var presetsResource = getClass().getResource("presets.yaml");
presets = new WatchedResourceFile<>(presetsResource, is -> YamlUtil.getObjectMapper()
Expand All @@ -83,7 +86,7 @@ public String getName()
public RecommendationEngine build(Recommender aRecommender)
{
ChatGptRecommenderTraits traits = readTraits(aRecommender);
return new ChatGptRecommender(aRecommender, traits, client);
return new ChatGptRecommender(aRecommender, traits, client, schemaService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import de.tudarmstadt.ukp.inception.recommendation.imls.chatgpt.ChatGptRecommenderFactory;
import de.tudarmstadt.ukp.inception.recommendation.imls.chatgpt.client.ChatGptClient;
import de.tudarmstadt.ukp.inception.recommendation.imls.chatgpt.client.ChatGptClientImpl;
import de.tudarmstadt.ukp.inception.schema.api.AnnotationSchemaService;

@Configuration
@ConditionalOnProperty(prefix = "recommender.chatgpt", name = "enabled", //
Expand All @@ -37,8 +38,9 @@ public ChatGptClient chatGptClient()
}

@Bean
public ChatGptRecommenderFactory chatGptRecommenderFactory(ChatGptClient aClient)
public ChatGptRecommenderFactory chatGptRecommenderFactory(ChatGptClient aClient,
AnnotationSchemaService aSchemaService)
{
return new ChatGptRecommenderFactory(aClient);
return new ChatGptRecommenderFactory(aClient, aSchemaService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,48 @@
prompt: |-
Identify all named entities in the following text.
{% if tags %}
Label each entity using one of the following labels:
{% for tag, description in tags.items() %}
* {{- tag -}} {%if description %}: {{- description -}}{% endif %}{% endfor %}
{% endif %}
Text:
```
{{ text }}
```
- name: Extract named entities from sentenes (dynamic few-shot)
promptingMode: per-sentence
format: json
extractionMode: mentions-from-json
prompt: |-
Identify all named entities in the following text.
Identify all named entities in the following text and return them as JSON.
{% if tags %}
Label each entity using one of the following labels:
{% for tag, description in tags.items() %}
* {{- tag -}} {%if description %}: {{- description -}}{% endif %}{% endfor %}
{% endif %}
{% if examples %}
{% for example in examples %}
Text:
'''
```
{{ example.getText() }}
'''
```
Response:
{{ example.getLabelledMentions() | tojson }}
{% endfor %}
Text:
{% endif %}
'''
```
{{ text }}
'''
```
{% if examples %}
Response:
{% endif %}
Expand All @@ -41,8 +56,9 @@
promptingMode: per-sentence
extractionMode: response-as-label
prompt: |-
Summarize the following sentence in a single word.
Summarize the following text in a single word.
Text:
```
{{ text }}
```
Expand All @@ -53,6 +69,7 @@
prompt: |-
Briefly describe what the following text is about.
Text:
```
{{ text }}
```
Expand All @@ -63,6 +80,7 @@
prompt: |-
Briefly summarize the following text.
Text:
```
{% for x in cas.select('custom.Span') %}
{{ x }}
Expand All @@ -73,8 +91,9 @@
promptingMode: per-annotation
extractionMode: response-as-label
prompt: |-
Very briefly describe the meaning of `{{ text }}` in the following sentence.
Very briefly describe the meaning of `{{ text }}` in the following text.
Text:
```
{{ sentence }}
```
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public interface PromptContextGenerator
static final String VAR_DOCUMENT = "document";
static final String VAR_EXAMPLES = "examples";
static final String VAR_CAS = "cas";
static final String VAR_TAGS = "tags";

Stream<PromptContext> generate(RecommendationEngine aEngine, CAS aCas, int aBegin, int aEnd,
Map<String, ? extends Object> aBindings);
Expand Down

0 comments on commit 7237323

Please sign in to comment.