Skip to content

Commit

Permalink
Merge pull request #5175 from inception-project/feature/5174-Add-supp…
Browse files Browse the repository at this point in the history
…ort-for-additional-LLM-options

#5174 - Add support for additional LLM options
  • Loading branch information
reckart authored Nov 22, 2024
2 parents 016aa9b + 92dcdbd commit 7f68e4f
Show file tree
Hide file tree
Showing 26 changed files with 965 additions and 319 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,17 @@
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.commons.lang3.StringUtils.appendIfMissing;
import static org.springframework.http.HttpHeaders.CONTENT_TYPE;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpRequest.BodyPublishers;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.commons.io.IOUtils;
import org.springframework.http.HttpHeaders;

import com.fasterxml.jackson.databind.ObjectMapper;

Expand All @@ -50,12 +46,12 @@ public class AzureAiOpenAiClientImpl

public AzureAiOpenAiClientImpl()
{
this.client = HttpClient.newBuilder().build();
client = HttpClient.newBuilder().build();
}

public AzureAiOpenAiClientImpl(HttpClient aClient)
{
this.client = aClient;
client = aClient;
}

protected HttpResponse<InputStream> sendRequest(HttpRequest aRequest) throws IOException
Expand All @@ -75,40 +71,10 @@ protected HttpResponse<InputStream> sendRequest(HttpRequest aRequest) throws IOE
protected String getResponseBody(HttpResponse<InputStream> response) throws IOException
{
if (response.body() != null) {
return IOUtils.toString(response.body(), StandardCharsets.UTF_8);
}
else {
return "";
}
}

protected <T> T deserializeResponse(HttpResponse<String> response, Class<T> aType)
throws IOException
{
try {
return objectMapper.readValue(response.body(), aType);
}
catch (IOException e) {
throw new IOException("Error while deserializing server response!", e);
}
}

protected String urlEncodeParameters(Map<String, String> aParameters)
{
if (aParameters.isEmpty()) {
return "";
}
StringBuilder uriBuilder = new StringBuilder();
for (Entry<String, String> param : aParameters.entrySet()) {
if (uriBuilder.length() > 0) {
uriBuilder.append("&");
}
uriBuilder.append(URLEncoder.encode(param.getKey(), UTF_8));
uriBuilder.append('=');
uriBuilder.append(URLEncoder.encode(param.getValue(), UTF_8));
return IOUtils.toString(response.body(), UTF_8);
}

return uriBuilder.toString();
return "";
}

@Override
Expand All @@ -117,7 +83,7 @@ public String generate(String aUrl, ChatCompletionRequest aRequest) throws IOExc
var request = HttpRequest.newBuilder() //
.uri(URI.create(
appendIfMissing(aUrl, "/") + "chat/completions?api-version=2023-05-15")) //
.header(HttpHeaders.CONTENT_TYPE, "application/json") //
.header(CONTENT_TYPE, "application/json") //
.header("api-key", aRequest.getApiKey()) //
.POST(BodyPublishers.ofString(JSONUtil.toJsonString(aRequest), UTF_8)) //
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.llm.azureaiopenai.client;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;
import static java.util.Arrays.asList;

import java.util.HashMap;
Expand All @@ -25,33 +26,46 @@

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;

import de.tudarmstadt.ukp.inception.recommendation.imls.llm.support.traits.DoubleOption;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.support.traits.Option;

public class ChatCompletionRequest
{
// See https://platform.openai.com/docs/api-reference/chat/create
public static final Option<Integer> MAX_TOKENS = new Option<>(Integer.class, "max_tokens");
public static final Option<Integer> SEED = new Option<>(Integer.class, "seed");
public static final Option<Integer> N = new Option<>(Integer.class, "n");
public static final Option<Double> FREQUENCY_PENALTY = new DoubleOption("frequency_penalty",
-2.0d, 2.0d);
public static final Option<Double> TEMPERATURE = new DoubleOption("temperature", 0.0d, 2.0d);

public static List<Option<?>> getAllOptions()
{
return asList(MAX_TOKENS, SEED, N);
return asList(MAX_TOKENS, SEED, FREQUENCY_PENALTY, TEMPERATURE);
}

private final @JsonIgnore String apiKey;
private final @JsonIgnore String model;

private final @JsonInclude(NON_NULL) GenerateResponseFormat format;
private final @JsonInclude(NON_NULL) @JsonProperty("frequency_penalty") Double frequencyPenalty;
private final @JsonInclude(NON_NULL) @JsonProperty("temperature") Double temperature;
private final @JsonInclude(NON_NULL) @JsonProperty("seed") Integer seed;

private final List<ChatCompletionMessage> messages;
private final @JsonInclude(Include.NON_NULL) GenerateResponseFormat format;

private ChatCompletionRequest(Builder builder)
{
messages = asList(new ChatCompletionMessage("user", builder.prompt));
format = builder.format;

model = builder.model;
apiKey = builder.apiKey;

format = builder.format;
frequencyPenalty = FREQUENCY_PENALTY.get(builder.options);
temperature = TEMPERATURE.get(builder.options);
seed = SEED.get(builder.options);
}

public String getApiKey()
Expand Down Expand Up @@ -85,43 +99,43 @@ public static final class Builder
private String apiKey;
private String prompt;
private GenerateResponseFormat format;
private Map<String, Object> options = new HashMap<>();
private Map<Option<?>, Object> options = new HashMap<>();

private Builder()
{
}

public Builder withModel(String aModel)
{
this.model = aModel;
model = aModel;
return this;
}

public Builder withApiKey(String aApiKey)
{
this.apiKey = aApiKey;
apiKey = aApiKey;
return this;
}

public Builder withPrompt(String aPrompt)
{
this.prompt = aPrompt;
prompt = aPrompt;
return this;
}

public Builder withFormat(GenerateResponseFormat aFormat)
{
this.format = aFormat;
format = aFormat;
return this;
}

public <T> Builder withOption(Option<T> aOption, T aValue)
{
if (aValue != null) {
this.options.put(aOption.getName(), aValue);
options.put(aOption, aValue);
}
else {
this.options.remove(aOption.getName());
options.remove(aOption);
}
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ protected String exchange(String aPrompt) throws IOException
var request = ChatCompletionRequest.builder() //
.withApiKey(((ApiKeyAuthenticationTraits) traits.getAuthentication()).getApiKey()) //
.withPrompt(aPrompt) //
.withOptions(traits.getOptions()) //
.withModel(traits.getModel());

if (traits.getFormat() == JSON) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,39 +94,45 @@ private Builder()

public Builder withModel(String aModel)
{
this.model = aModel;
model = aModel;
return this;
}

public Builder withApiKey(String aApiKey)
{
this.apiKey = aApiKey;
apiKey = aApiKey;
return this;
}

public Builder withPrompt(String aPrompt)
{
this.prompt = aPrompt;
prompt = aPrompt;
return this;
}

public Builder withResponseFormat(ResponseFormat aFormat)
{
this.format = aFormat;
format = aFormat;
return this;
}

public <T> Builder withOption(Option<T> aOption, T aValue)
{
if (aValue != null) {
this.options.put(aOption.getName(), aValue);
options.put(aOption.getName(), aValue);
}
else {
this.options.remove(aOption.getName());
options.remove(aOption.getName());
}
return this;
}

public Builder withOptions(Map<String, Object> aOptions)
{
options.putAll(aOptions);
return this;
}

public ChatCompletionRequest build()
{
return new ChatCompletionRequest(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public class ChatCompletionResponse
private @JsonProperty("model") String model;
private @JsonProperty("created") long createdAt;
private @JsonProperty("choices") List<ChatCompletionChoice> choices;
private @JsonProperty("usage") ChatCompletionUsage usage;
private @JsonProperty("time_info") ChatCompletionTimeInfo timeInfo;

public String getModel()
{
Expand Down Expand Up @@ -58,4 +60,24 @@ public void setChoices(List<ChatCompletionChoice> aChoices)
{
choices = aChoices;
}

public ChatCompletionUsage getUsage()
{
return usage;
}

public void setUsage(ChatCompletionUsage aUsage)
{
usage = aUsage;
}

public ChatCompletionTimeInfo getTimeInfo()
{
return timeInfo;
}

public void setTimeInfo(ChatCompletionTimeInfo aTimeInfo)
{
timeInfo = aTimeInfo;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Technische Universität Darmstadt under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The Technische Universität Darmstadt
* licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.llm.chatgpt.client;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

@JsonIgnoreProperties(ignoreUnknown = true)
public class ChatCompletionTimeInfo
{
private @JsonProperty("queue_time") double queueTime;
private @JsonProperty("prompt_time") double promptTime;
private @JsonProperty("completion_time") double completionTime;
private @JsonProperty("total_time") double totalTime;
private @JsonProperty("created") long created;

public double getQueueTime()
{
return queueTime;
}

public void setQueueTime(double aQueueTime)
{
queueTime = aQueueTime;
}

public double getPromptTime()
{
return promptTime;
}

public void setPromptTime(double aPromptTime)
{
promptTime = aPromptTime;
}

public double getCompletionTime()
{
return completionTime;
}

public void setCompletionTime(double aCompletionTime)
{
completionTime = aCompletionTime;
}

public double getTotalTime()
{
return totalTime;
}

public void setTotalTime(double aTotalTime)
{
totalTime = aTotalTime;
}

public long getCreated()
{
return created;
}

public void setCreated(long aCreated)
{
created = aCreated;
}
}
Loading

0 comments on commit 7f68e4f

Please sign in to comment.