Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding blue print doc for cohere multi-modal model #3229

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@ public CohereMultiModalEmbeddingPreProcessFunction() {
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs == null || docs.isEmpty() || (docs.size() == 1 && docs.get(0) == null)) {
if (docs == null || docs.isEmpty() || docs.get(0) == null) {
throw new IllegalArgumentException("No image provided");
}
}

@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, String> parametersMap = new HashMap<>();
Map<String, Object> parametersMap = new HashMap<>();

/**
* Cohere multi-modal model expects either image or texts, not both.
* For image, customer can use this pre-process function. For texts, customer can use
* connector.pre_process.cohere.embedding
* Cohere expects An array of image data URIs for the model to embed. Maximum number of images per call is 1.
*/
parametersMap.put("images", inputData.getDocs().get(0));
parametersMap.put("images", inputData.getDocs());
return RemoteInferenceInputDataSet
.builder()
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void testProcess_whenCorrectInput_expectCorrectOutput() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
assertEquals(1, dataSet.getParameters().size());
assertEquals("imageString", dataSet.getParameters().get("images"));
assertEquals("[\"imageString\"]", dataSet.getParameters().get("images"));

}

Expand Down
Loading
Loading