Skip to content

Commit

Permalink
#4292 - ollama-based recommender
Browse files Browse the repository at this point in the history
- Towards configurable advanced options
  • Loading branch information
reckart committed Nov 28, 2023
1 parent efb78fa commit 2f704b8
Show file tree
Hide file tree
Showing 14 changed files with 344 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_EMPTY;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
Expand Down Expand Up @@ -51,7 +52,7 @@ public class OllamaRecommenderTraits

private ExtractionMode extractionMode = ExtractionMode.RESPONSE_AS_LABEL;

private @JsonInclude(NON_EMPTY) Map<String, Object> options = new HashMap<String, Object>();
private @JsonInclude(NON_EMPTY) Map<String, Object> options = new LinkedHashMap<String, Object>();

public String getUrl()
{
Expand Down Expand Up @@ -122,4 +123,15 @@ public void setExtractionMode(ExtractionMode aExtractionMode)
{
extractionMode = aExtractionMode;
}

public Map<String, Object> getOptions()
{
return Collections.unmodifiableMap(options);
}

public void setOptions(Map<String, Object> aOptions)
{
options.clear();
options.putAll(aOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,26 @@
<textarea wicket:id="prompt" class="form-control" rows="10"/>
</div>
</div>
<form wicket:id="optionSettingsForm">
<div class="row form-row">
<label class="col-sm-3 col-form-label">
<wicket:message key="options"/>
</label>
<div class="col-sm-9">
<div class="input-group">
<select wicket:id="option" class="form-select"/>
<button wicket:id="addOption" class="btn btn-outline-secondary" type="button">Add</button>
</div>
<div wicket:id="optionSettingsContainer">
<div wicket:id="optionSettings">
<span wicket:id="option"/>
<input wicket:id="value" type="text" class="form-control"/>
<button wicket:id="removeOption" class="btn btn-outline-secondary" type="button">Remove</button>
</div>
</div>
</div>
</div>
</form>
</form>
</wicket:panel>
</html>
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,38 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.ollama;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.wicket.ajax.AjaxRequestTarget;
import org.apache.wicket.feedback.IFeedback;
import org.apache.wicket.markup.html.WebMarkupContainer;
import org.apache.wicket.markup.html.basic.Label;
import org.apache.wicket.markup.html.form.CheckBox;
import org.apache.wicket.markup.html.form.ChoiceRenderer;
import org.apache.wicket.markup.html.form.DropDownChoice;
import org.apache.wicket.markup.html.form.Form;
import org.apache.wicket.markup.html.form.TextArea;
import org.apache.wicket.markup.html.form.TextField;
import org.apache.wicket.markup.html.list.ListItem;
import org.apache.wicket.markup.html.list.ListView;
import org.apache.wicket.model.CompoundPropertyModel;
import org.apache.wicket.model.IModel;
import org.apache.wicket.model.Model;
import org.apache.wicket.model.PropertyModel;
import org.apache.wicket.model.util.ListModel;
import org.apache.wicket.spring.injection.annot.SpringBean;

import de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.AbstractTraitsEditor;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactory;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaGenerateRequest;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.Option;
import de.tudarmstadt.ukp.inception.support.lambda.LambdaAjaxFormComponentUpdatingBehavior;
import de.tudarmstadt.ukp.inception.support.lambda.LambdaAjaxLink;
import de.tudarmstadt.ukp.inception.support.lambda.LambdaAjaxSubmitLink;

public class OllamaRecommenderTraitsEditor
extends AbstractTraitsEditor
Expand All @@ -48,6 +62,9 @@ public class OllamaRecommenderTraitsEditor

private final IModel<OllamaRecommenderTraits> traits;

private final WebMarkupContainer optionSettingsContainer;
private final IModel<List<OptionSetting>> optionSettings;

public OllamaRecommenderTraitsEditor(String aId, IModel<Recommender> aRecommender,
IModel<List<Preset>> aPresets)
{
Expand All @@ -57,7 +74,7 @@ public OllamaRecommenderTraitsEditor(String aId, IModel<Recommender> aRecommende

traits = CompoundPropertyModel.of(toolFactory.readTraits(aRecommender.getObject()));

Form<OllamaRecommenderTraits> form = new Form<OllamaRecommenderTraits>(MID_FORM, traits)
var form = new Form<OllamaRecommenderTraits>(MID_FORM, traits)
{
private static final long serialVersionUID = -1;

Expand Down Expand Up @@ -95,7 +112,61 @@ protected void onSubmit()
form.add(new PromptingModeSelect("promptingMode"));
form.add(new ExtractionModeSelect("extractionMode"));
form.add(new OllamaResponseFormatSelect("format"));

add(form);

var optionSettingsForm = new Form<>("optionSettingsForm",
CompoundPropertyModel.of(new OptionSetting()));
optionSettingsForm.setVisibilityAllowed(false); // FIXME Not quite ready yet
form.add(optionSettingsForm);

optionSettingsForm.add(
new DropDownChoice<Option<?>>("option", OllamaGenerateRequest.getAllOptions()));
optionSettingsForm
.add(new LambdaAjaxSubmitLink<OptionSetting>("addOption", this::addOptionSetting));

optionSettingsContainer = new WebMarkupContainer("optionSettingsContainer");
optionSettingsContainer.setOutputMarkupPlaceholderTag(true);
optionSettingsForm.add(optionSettingsContainer);

optionSettings = new ListModel<>(traits.getObject().getOptions().entrySet().stream()
.map(e -> new OptionSetting(e.getKey(), String.valueOf(e.getValue())))
.collect(Collectors.toCollection(ArrayList::new)));

optionSettingsContainer.add(createOptionSettingsList("optionSettings", optionSettings));
}

private ListView<OptionSetting> createOptionSettingsList(String aId,
IModel<List<OptionSetting>> aOptionSettings)
{
return new ListView<OptionSetting>(aId, aOptionSettings)
{
private static final long serialVersionUID = 244305980337592760L;

@Override
protected void populateItem(ListItem<OptionSetting> aItem)
{
var optionSetting = aItem.getModelObject();

aItem.add(new Label("option", optionSetting.getOption()));
aItem.add(new TextField<>("value", PropertyModel.of(optionSetting, "value")));
aItem.add(new LambdaAjaxLink("removeOption",
_target -> removeOptionSetting(_target, aItem.getModelObject())));
}
};
}

private void addOptionSetting(AjaxRequestTarget aTarget, Form<OptionSetting> aForm)
{
optionSettings.getObject()
.add(new OptionSetting(aForm.getModel().getObject().getOption(), ""));

aTarget.addChildren(getPage(), IFeedback.class);
aTarget.add(optionSettingsContainer);
}

private void removeOptionSetting(AjaxRequestTarget aTarget, OptionSetting aBinding)
{
optionSettings.getObject().remove(aBinding);
aTarget.add(optionSettingsContainer);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ model=Model
prompt=Prompt
raw=Raw prompt
preset=Preset
options=Advanced options

promptingMode=Processing mode
PromptingMode.PER_ANNOTATION=Per annotation
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Technische Universität Darmstadt under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The Technische Universität Darmstadt
* licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.ollama;

import java.io.Serializable;

public class OptionSetting
implements Serializable
{
private static final long serialVersionUID = 639108348141364660L;

private String option;
private String value;

public OptionSetting()
{
// For serialization
}

public OptionSetting(String aOption, String aValue)
{
option = aOption;
value = aValue;
}

public String getOption()
{
return option;
}

public void setOption(String aOption)
{
option = aOption;
}

public String getValue()
{
return value;
}

public void setValue(String aValue)
{
value = aValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.net.http.HttpRequest.BodyPublishers;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

Expand Down Expand Up @@ -63,14 +64,7 @@ protected HttpResponse<InputStream> sendRequest(HttpRequest aRequest) throws IOE
try {
var response = client.send(aRequest, HttpResponse.BodyHandlers.ofInputStream());

// If the response indicates that the request was not successful,
// then it does not make sense to go on and try to decode the XMI
if (response.statusCode() >= HTTP_BAD_REQUEST) {
String responseBody = getResponseBody(response);
String msg = format("Request was not successful: [%d] - [%s]",
response.statusCode(), responseBody);
throw new IOException(msg);
}
handleError(response);

return response;
}
Expand Down Expand Up @@ -129,14 +123,7 @@ public String generate(String aUrl, OllamaGenerateRequest aRequest) throws IOExc

var response = sendRequest(request);

// If the response indicates that the request was not successful,
// then it does not make sense to go on and try to decode the XMI
if (response.statusCode() >= HTTP_BAD_REQUEST) {
String responseBody = getResponseBody(response);
String msg = format("Request was not successful: [%d] - [%s]", response.statusCode(),
responseBody);
throw new IOException(msg);
}
handleError(response);

var result = new StringBuilder();
try (var is = response.body()) {
Expand All @@ -149,4 +136,30 @@ public String generate(String aUrl, OllamaGenerateRequest aRequest) throws IOExc

return result.toString();
}

public List<OllamaModel> listModels(String aUrl) throws IOException
{
var request = HttpRequest.newBuilder() //
.uri(URI.create(appendIfMissing(aUrl, "/") + "api/tags")) //
.header(HttpHeaders.CONTENT_TYPE, "application/json").GET() //
.build();

var response = sendRequest(request);

handleError(response);

try (var is = response.body()) {
return objectMapper.readValue(is, OllamaTagsResponse.class).getModels();
}
}

private void handleError(HttpResponse<InputStream> response) throws IOException
{
if (response.statusCode() >= HTTP_BAD_REQUEST) {
String responseBody = getResponseBody(response);
String msg = format("Request was not successful: [%d] - [%s]", response.statusCode(),
responseBody);
throw new IOException(msg);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client;

import static java.util.Arrays.asList;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.annotation.JsonInclude;
Expand Down Expand Up @@ -69,6 +72,15 @@ public class OllamaGenerateRequest
"rope_frequency_scale");
public static final Option<Integer> NUM_THREAD = new Option<>(Integer.class, "num_thread");

public static List<Option<?>> getAllOptions()
{
return asList(NUM_KEEP, SEED, NUM_PREDICT, TOP_K, TOP_P, TFS_Z, TYPICAL_P, REPEAT_LAST_N,
TEMPERATURE, REPEAT_PENALTY, PRESENCE_PENALTY, FREQUENCY_PENALTY, MIROSTAT,
MIROSTAT_TAU, MIROSTAT_ETA, PENALIZE_NEWLINE, STOP, NUMA, NUM_CTX, NUM_BATCH,
NUM_GQA, NUM_GPU, MAIN_GPU, LOW_VRAM, F16_KV, LOGITS_ALL, VOCAB_ONLY, USE_MMAP,
USE_MLOCK, EMBEDDING_ONLY, ROPE_FREQUENCY_BASE, ROPE_FREQUENCY_SCALE, NUM_THREAD);
}

private String model;
private String prompt;
private boolean stream;
Expand Down
Loading

0 comments on commit 2f704b8

Please sign in to comment.