From 9f7cabacda2bc47c50d81b370e9aaa32c25c207c Mon Sep 17 00:00:00 2001
From: Ryan Ernst
Date: Thu, 28 Jul 2022 19:29:32 -0700
Subject: [PATCH 1/5] Add 8.5 migration docs (#88923)

This commit adds a migration docs file for 8.5. This was copied from the
8.4 file, which had no migration notes.
---
 docs/reference/migration/index.asciidoc       |  2 ++
 docs/reference/migration/migrate_8_5.asciidoc | 22 +++++++++++++++++++
 2 files changed, 24 insertions(+)
 create mode 100644 docs/reference/migration/migrate_8_5.asciidoc

diff --git a/docs/reference/migration/index.asciidoc b/docs/reference/migration/index.asciidoc
index 2a7a1b32131bc..58843a4736bb1 100644
--- a/docs/reference/migration/index.asciidoc
+++ b/docs/reference/migration/index.asciidoc
@@ -1,11 +1,13 @@
 include::migration_intro.asciidoc[]
 
+* <>
 * <>
 * <>
 * <>
 * <>
 * <>
 
+include::migrate_8_5.asciidoc[]
 include::migrate_8_4.asciidoc[]
 include::migrate_8_3.asciidoc[]
 include::migrate_8_2.asciidoc[]
diff --git a/docs/reference/migration/migrate_8_5.asciidoc b/docs/reference/migration/migrate_8_5.asciidoc
new file mode 100644
index 0000000000000..91404e7b18ec5
--- /dev/null
+++ b/docs/reference/migration/migrate_8_5.asciidoc
@@ -0,0 +1,22 @@
+[[migrating-8.5]]
+== Migrating to 8.5
+++++
+8.5
+++++
+
+This section discusses the changes that you need to be aware of when migrating
+your application to {es} 8.5.
+
+See also <> and <>.
+
+coming::[8.5.0]
+
+
+[discrete]
+[[breaking-changes-8.5]]
+=== Breaking changes
+
+// tag::notable-breaking-changes[]
+There are no breaking changes in {es} 8.5.
+// end::notable-breaking-changes[]
+

From be4d809b39346bfb2acd14721977314d26161031 Mon Sep 17 00:00:00 2001
From: Rory Hunter
Date: Fri, 29 Jul 2022 11:27:11 +0100
Subject: [PATCH 2/5] Add generateStubReleaseNotes task (#88933)

When we feature freeze Elasticsearch, we need to create stub documentation
for the next version. This turns out to be as simple as running the usual
`generateReleaseNotes` task without any inputs.
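The pattern the change applies is registering two Gradle tasks that share a single configuration
Action, with a boolean deciding whether changelog inputs are wired up. A minimal, self-contained
sketch of that pattern follows (hypothetical plugin class, simplified wiring; the real code lives
in ReleaseToolsPlugin in the diff below):

// Minimal sketch only: one configuration Action reused by two task registrations.
import org.gradle.api.Action;
import org.gradle.api.Plugin;
import org.gradle.api.Project;
import org.gradle.api.Task;

import java.util.function.Function;

public class StubReleaseNotesSketchPlugin implements Plugin<Project> {
    @Override
    public void apply(Project project) {
        // The boolean selects whether changelog inputs are wired up (normal run)
        // or left empty (stub run after feature freeze).
        Function<Boolean, Action<Task>> configure = withChangelogs -> task -> {
            task.setGroup("Documentation");
            task.setDescription(
                withChangelogs
                    ? "Generates release notes from changelog files held in this checkout"
                    : "Generates stub release notes e.g. after feature freeze"
            );
            // Common template and output-file wiring would go here in the real plugin.
        };

        project.getTasks().register("generateReleaseNotes").configure(configure.apply(true));
        project.getTasks().register("generateStubReleaseNotes").configure(configure.apply(false));
    }
}

With tasks registered this way, running `./gradlew generateStubReleaseNotes` would produce the stub
files without consuming any changelog inputs, which is what the feature-freeze workflow needs.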
--- .../internal/release/ReleaseToolsPlugin.java | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ReleaseToolsPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ReleaseToolsPlugin.java index fb6ddc5e1be16..c93320dc2b498 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ReleaseToolsPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ReleaseToolsPlugin.java @@ -12,6 +12,7 @@ import org.elasticsearch.gradle.VersionProperties; import org.elasticsearch.gradle.internal.conventions.precommit.PrecommitTaskPlugin; import org.elasticsearch.gradle.internal.precommit.ValidateYamlAgainstSchemaTask; +import org.gradle.api.Action; import org.gradle.api.Plugin; import org.gradle.api.Project; import org.gradle.api.file.Directory; @@ -22,6 +23,7 @@ import org.gradle.api.tasks.util.PatternSet; import java.io.File; +import java.util.function.Function; import javax.inject.Inject; @@ -67,10 +69,14 @@ public void apply(Project project) { task.dependsOn(validateChangelogsAgainstYamlTask); }); - project.getTasks().register("generateReleaseNotes", GenerateReleaseNotesTask.class).configure(task -> { + final Function> configureGenerateTask = shouldConfigureYamlFiles -> task -> { task.setGroup("Documentation"); - task.setDescription("Generates release notes from changelog files held in this checkout"); - task.setChangelogs(yamlFiles); + if (shouldConfigureYamlFiles) { + task.setChangelogs(yamlFiles); + task.setDescription("Generates release notes from changelog files held in this checkout"); + } else { + task.setDescription("Generates stub release notes e.g. after feature freeze"); + } task.setReleaseNotesIndexTemplate(projectDirectory.file(RESOURCES + "templates/release-notes-index.asciidoc")); task.setReleaseNotesIndexFile(projectDirectory.file("docs/reference/release-notes.asciidoc")); @@ -100,7 +106,12 @@ public void apply(Project project) { task.setMigrationIndexFile(projectDirectory.file("docs/reference/migration/index.asciidoc")); task.dependsOn(validateChangelogsTask); - }); + }; + + project.getTasks().register("generateReleaseNotes", GenerateReleaseNotesTask.class).configure(configureGenerateTask.apply(true)); + project.getTasks() + .register("generateStubReleaseNotes", GenerateReleaseNotesTask.class) + .configure(configureGenerateTask.apply(false)); project.getTasks().register("pruneChangelogs", PruneChangelogsTask.class).configure(task -> { task.setGroup("Documentation"); From bd624ba2dc53cbaf1bdc3b6240b568fa263a8603 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Fri, 29 Jul 2022 12:42:53 +0200 Subject: [PATCH 3/5] Speed up operations on BlobStoreIndexShardSnapshots (#88912) This fixes a couple of slow points in `BlobStoreIndexShardSnapshots`, which become performance critical when working with large repositories. 1. Fix `physicalFiles` containing the same `FileInfo` instances repeatedly for every snapshot that holds the file. Without this fix the map can hold lists as long as the number of snapshots for the shard for files common to all snapshots of the shard. Also, only lazy build the map since it's only used during snapshotting and internalize the logic into `BlobStoreIndexShardSnapshots` so we don't have to bother with wrapping as unmodifiable. 2. 
Add efficient copy constructors for all 3 operations on the shard to avoid expensive looping over all snapshots and their files in many cases. 3. Use list instead of redundant map in deserialization, we weren't using the map for any deduplication anyways and are safe here thanks to Jackson's duplicate name detection --- .../BlobStoreIndexShardSnapshots.java | 132 +++++++++++------- .../blobstore/BlobStoreRepository.java | 35 ++--- 2 files changed, 91 insertions(+), 76 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/snapshots/blobstore/BlobStoreIndexShardSnapshots.java b/server/src/main/java/org/elasticsearch/index/snapshots/blobstore/BlobStoreIndexShardSnapshots.java index f5f95b25a684d..113d3c8f28a19 100644 --- a/server/src/main/java/org/elasticsearch/index/snapshots/blobstore/BlobStoreIndexShardSnapshots.java +++ b/server/src/main/java/org/elasticsearch/index/snapshots/blobstore/BlobStoreIndexShardSnapshots.java @@ -10,20 +10,25 @@ import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo; +import org.elasticsearch.index.store.StoreFileMetadata; +import org.elasticsearch.snapshots.SnapshotId; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.Iterator; import java.util.List; import java.util.Map; - -import static java.util.Collections.unmodifiableMap; +import java.util.Set; +import java.util.stream.Collectors; /** * Contains information about all snapshots for the given shard in repository @@ -33,54 +38,53 @@ */ public class BlobStoreIndexShardSnapshots implements Iterable, ToXContentFragment { - public static final BlobStoreIndexShardSnapshots EMPTY = new BlobStoreIndexShardSnapshots(Collections.emptyList()); + public static final BlobStoreIndexShardSnapshots EMPTY = new BlobStoreIndexShardSnapshots(Map.of(), List.of()); private final List shardSnapshots; private final Map files; - private final Map> physicalFiles; - public BlobStoreIndexShardSnapshots(List shardSnapshots) { + private BlobStoreIndexShardSnapshots(Map files, List shardSnapshots) { this.shardSnapshots = List.copyOf(shardSnapshots); - // Map between blob names and file info + this.files = files; + } + + public BlobStoreIndexShardSnapshots withRetainedSnapshots(Set retainedSnapshots) { + if (retainedSnapshots.isEmpty()) { + return EMPTY; + } + final var survivingSnapshotNames = retainedSnapshots.stream().map(SnapshotId::getName).collect(Collectors.toSet()); + final ArrayList updatedSnapshots = new ArrayList<>(survivingSnapshotNames.size()); Map newFiles = new HashMap<>(); - // Map between original physical names and file info - Map> physicalFiles = new HashMap<>(); for (SnapshotFiles snapshot : shardSnapshots) { - // First we build map between filenames in the repo and their original file info - // this map will be used in the next loop + if (survivingSnapshotNames.contains(snapshot.snapshot()) == false) { + continue; + } + updatedSnapshots.add(snapshot); for (FileInfo fileInfo : snapshot.indexFiles()) { FileInfo oldFile = newFiles.put(fileInfo.name(), fileInfo); assert oldFile == null || oldFile.isSame(fileInfo); } - // We are 
doing it in two loops here so we keep only one copy of the fileInfo per blob - // the first loop de-duplicates fileInfo objects that were loaded from different snapshots but refer to - // the same blob - for (FileInfo fileInfo : snapshot.indexFiles()) { - physicalFiles.computeIfAbsent(fileInfo.physicalName(), k -> new ArrayList<>()).add(newFiles.get(fileInfo.name())); - } } - Map> mapBuilder = new HashMap<>(); - for (Map.Entry> entry : physicalFiles.entrySet()) { - mapBuilder.put(entry.getKey(), List.copyOf(entry.getValue())); - } - this.physicalFiles = unmodifiableMap(mapBuilder); - this.files = unmodifiableMap(newFiles); + return new BlobStoreIndexShardSnapshots(newFiles, updatedSnapshots); } - private BlobStoreIndexShardSnapshots(Map files, List shardSnapshots) { - this.shardSnapshots = shardSnapshots; - this.files = files; - Map> physicalFiles = new HashMap<>(); - for (SnapshotFiles snapshot : shardSnapshots) { - for (FileInfo fileInfo : snapshot.indexFiles()) { - physicalFiles.computeIfAbsent(fileInfo.physicalName(), k -> new ArrayList<>()).add(files.get(fileInfo.name())); + public BlobStoreIndexShardSnapshots withAddedSnapshot(SnapshotFiles snapshotFiles) { + Map updatedFiles = null; + for (FileInfo fileInfo : snapshotFiles.indexFiles()) { + final FileInfo known = files.get(fileInfo.name()); + if (known == null) { + if (updatedFiles == null) { + updatedFiles = new HashMap<>(files); + } + updatedFiles.put(fileInfo.name(), fileInfo); + } else { + assert fileInfo.isSame(known); } } - Map> mapBuilder = new HashMap<>(); - for (Map.Entry> entry : physicalFiles.entrySet()) { - mapBuilder.put(entry.getKey(), List.copyOf(entry.getValue())); - } - this.physicalFiles = unmodifiableMap(mapBuilder); + return new BlobStoreIndexShardSnapshots( + updatedFiles == null ? 
files : updatedFiles, + CollectionUtils.appendToCopyNoNullElements(shardSnapshots, snapshotFiles) + ); } /** @@ -102,7 +106,10 @@ public BlobStoreIndexShardSnapshots withClone(String source, String target) { if (sourceFiles == null) { throw new IllegalArgumentException("unknown source [" + source + "]"); } - return new BlobStoreIndexShardSnapshots(CollectionUtils.appendToCopy(shardSnapshots, sourceFiles.withSnapshotName(target))); + return new BlobStoreIndexShardSnapshots( + files, + CollectionUtils.appendToCopyNoNullElements(shardSnapshots, sourceFiles.withSnapshotName(target)) + ); } /** @@ -114,14 +121,40 @@ public List snapshots() { return this.shardSnapshots; } + // index of Lucene file name to collection of file info in the repository + // lazy computed because building this is map is rather expensive and only needed for the snapshot create operation + private Map> physicalFiles; + /** - * Finds reference to a snapshotted file by its original name + * Finds reference to a snapshotted file by its {@link StoreFileMetadata} * - * @param physicalName original name - * @return a list of file infos that match specified physical file or null if the file is not present in any of snapshots + * @param storeFileMetadata store file metadata to find file info for + * @return the file info that matches the specified physical file or null if the file is not present in any of snapshots */ - public List findPhysicalIndexFiles(String physicalName) { - return physicalFiles.get(physicalName); + public FileInfo findPhysicalIndexFile(StoreFileMetadata storeFileMetadata) { + var p = this.physicalFiles; + if (p == null) { + p = new HashMap<>(); + for (SnapshotFiles snapshot : shardSnapshots) { + for (FileInfo fileInfo : snapshot.indexFiles()) { + // we use identity hash set since we lookup all instances from the same map and thus equality == instance equality + // and we don't want to add the same file to the map multiple times + p.computeIfAbsent(fileInfo.physicalName(), k -> Collections.newSetFromMap(new IdentityHashMap<>())) + .add(files.get(fileInfo.name())); + } + } + physicalFiles = p; + } + final var found = p.get(storeFileMetadata.name()); + if (found == null) { + return null; + } + for (FileInfo fileInfo : found) { + if (fileInfo.isSame(storeFileMetadata)) { + return fileInfo; + } + } + return null; } /** @@ -228,7 +261,8 @@ public static BlobStoreIndexShardSnapshots fromXContent(XContentParser parser) t if (token == null) { // New parser token = parser.nextToken(); } - Map> snapshotsMap = new HashMap<>(); + // list of tuples of snapshot name and file ids in the snapshot + List>> snapshotsAndFiles = new ArrayList<>(); Map historyUUIDs = new HashMap<>(); Map files = new HashMap<>(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser); @@ -256,7 +290,9 @@ public static BlobStoreIndexShardSnapshots fromXContent(XContentParser parser) t token = parser.nextToken(); if (Fields.FILES.equals(currentFieldName)) { if (token == XContentParser.Token.START_ARRAY) { - snapshotsMap.put(snapshot, XContentParserUtils.parseList(parser, XContentParser::text)); + snapshotsAndFiles.add( + Tuple.tuple(snapshot, XContentParserUtils.parseList(parser, XContentParser::text)) + ); } } else if (Fields.SHARD_STATE_ID.equals(currentFieldName)) { historyUUIDs.put(snapshot, parser.text()); @@ -268,19 +304,17 @@ public static BlobStoreIndexShardSnapshots fromXContent(XContentParser parser) t } } - List snapshots = new ArrayList<>(snapshotsMap.size()); - for (Map.Entry> entry : 
snapshotsMap.entrySet()) { + List snapshots = new ArrayList<>(snapshotsAndFiles.size()); + for (Tuple> entry : snapshotsAndFiles) { List fileInfosBuilder = new ArrayList<>(); - for (String file : entry.getValue()) { + for (String file : entry.v2()) { FileInfo fileInfo = files.get(file); assert fileInfo != null; fileInfosBuilder.add(fileInfo); } - snapshots.add( - new SnapshotFiles(entry.getKey(), Collections.unmodifiableList(fileInfosBuilder), historyUUIDs.get(entry.getKey())) - ); + snapshots.add(new SnapshotFiles(entry.v1(), Collections.unmodifiableList(fileInfosBuilder), historyUUIDs.get(entry.v1()))); } - return new BlobStoreIndexShardSnapshots(files, Collections.unmodifiableList(snapshots)); + return new BlobStoreIndexShardSnapshots(files, snapshots); } } diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index 3b6f61aad09ee..69c01a51b337b 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -2696,18 +2696,7 @@ public void snapshotShard(SnapshotShardContext context) { logger.trace("[{}] [{}] Processing [{}]", shardId, snapshotId, fileName); final StoreFileMetadata md = metadataFromStore.get(fileName); - BlobStoreIndexShardSnapshot.FileInfo existingFileInfo = null; - List filesInfo = snapshots.findPhysicalIndexFiles(fileName); - if (filesInfo != null) { - for (BlobStoreIndexShardSnapshot.FileInfo fileInfo : filesInfo) { - if (fileInfo.isSame(md)) { - // a commit point file with the same name, size and checksum was already copied to repository - // we will reuse it for this snapshot - existingFileInfo = fileInfo; - break; - } - } - } + BlobStoreIndexShardSnapshot.FileInfo existingFileInfo = snapshots.findPhysicalIndexFile(md); // We can skip writing blobs where the metadata hash is equal to the blob's contents because we store the hash/contents // directly in the shard level metadata in this case @@ -2733,6 +2722,8 @@ public void snapshotShard(SnapshotShardContext context) { filesInShardMetadataSize += md.length(); } } else { + // a commit point file with the same name, size and checksum was already copied to repository + // we will reuse it for this snapshot indexCommitPointFiles.add(existingFileInfo); } } @@ -2756,12 +2747,9 @@ public void snapshotShard(SnapshotShardContext context) { final boolean writeShardGens = SnapshotsService.useShardGenerations(context.getRepositoryMetaVersion()); final boolean writeFileInfoWriterUUID = SnapshotsService.includeFileInfoWriterUUID(context.getRepositoryMetaVersion()); // build a new BlobStoreIndexShardSnapshot, that includes this one and all the saved ones - List newSnapshotsList = new ArrayList<>(); - newSnapshotsList.add(new SnapshotFiles(snapshotId.getName(), indexCommitPointFiles, context.stateIdentifier())); - for (SnapshotFiles point : snapshots) { - newSnapshotsList.add(point); - } - final BlobStoreIndexShardSnapshots updatedBlobStoreIndexShardSnapshots = new BlobStoreIndexShardSnapshots(newSnapshotsList); + final BlobStoreIndexShardSnapshots updatedBlobStoreIndexShardSnapshots = snapshots.withAddedSnapshot( + new SnapshotFiles(snapshotId.getName(), indexCommitPointFiles, context.stateIdentifier()) + ); final Runnable afterWriteSnapBlob; if (writeShardGens) { // When using shard generations we can safely write the index-${uuid} blob before writing out any of the 
actual data @@ -3253,19 +3241,12 @@ private ShardSnapshotMetaDeleteResult deleteFromShardSnapshotMeta( long indexGeneration ) { // Build a list of snapshots that should be preserved - List newSnapshotsList = new ArrayList<>(); - final Set survivingSnapshotNames = survivingSnapshots.stream().map(SnapshotId::getName).collect(Collectors.toSet()); - for (SnapshotFiles point : snapshots) { - if (survivingSnapshotNames.contains(point.snapshot())) { - newSnapshotsList.add(point); - } - } + final BlobStoreIndexShardSnapshots updatedSnapshots = snapshots.withRetainedSnapshots(survivingSnapshots); ShardGeneration writtenGeneration = null; try { - if (newSnapshotsList.isEmpty()) { + if (updatedSnapshots.snapshots().isEmpty()) { return new ShardSnapshotMetaDeleteResult(indexId, snapshotShardId, ShardGenerations.DELETED_SHARD_GEN, blobs); } else { - final BlobStoreIndexShardSnapshots updatedSnapshots = new BlobStoreIndexShardSnapshots(newSnapshotsList); if (indexGeneration < 0L) { writtenGeneration = ShardGeneration.newGeneration(); INDEX_SHARD_SNAPSHOTS_FORMAT.write(updatedSnapshots, shardContainer, writtenGeneration.toBlobNamePart(), compress); From 6b8dab7807fa56a228a39b77c6761488e5c1a8e7 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 29 Jul 2022 07:28:25 -0400 Subject: [PATCH 4/5] [ML] fix BERT and MPNet tokenization bug when handling unicode accents (#88907) When handling unicode accents, it may have been that BERT tokenizations removed the incorrect characters. This would result in an exceptionally strange result and possibly an error. closes #88900 --- docs/changelog/88907.yaml | 6 ++++++ .../deployment/DeploymentManager.java | 2 +- .../xpack/ml/inference/nlp/Vocabulary.java | 2 ++ .../nlp/tokenizers/BasicTokenFilter.java | 21 ++++++++++++------- .../nlp/tokenizers/BasicTokenFilterTests.java | 1 + 5 files changed, 24 insertions(+), 8 deletions(-) create mode 100644 docs/changelog/88907.yaml diff --git a/docs/changelog/88907.yaml b/docs/changelog/88907.yaml new file mode 100644 index 0000000000000..2d9cab22424ca --- /dev/null +++ b/docs/changelog/88907.yaml @@ -0,0 +1,6 @@ +pr: 88907 +summary: Fix BERT and MPNet tokenization bug when handling unicode accents +area: Machine Learning +type: bug +issues: + - 88900 diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 0d917debe3d02..6b984628f3b7b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -206,7 +206,7 @@ Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException { stream ) ) { - return Vocabulary.createParser(true).apply(parser, null); + return Vocabulary.PARSER.apply(parser, null); } catch (IOException e) { logger.error(() -> "failed to parse trained model vocabulary [" + hit.getId() + "]", e); throw e; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java index 7665c61b76ce5..6deb9a8b6d0fb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/Vocabulary.java @@ -45,6 +45,8 @@ public static ConstructingObjectParser 
createParser(boolean ig return parser; } + public static ConstructingObjectParser PARSER = createParser(true); + private final List vocab; private final List merges; private final String modelId; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilter.java index 8828efa4af1eb..3be4eded99894 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilter.java @@ -140,25 +140,30 @@ public boolean incrementToken() throws IOException { return false; } - void stripAccent() { + private void stripAccent() { accentBuffer.setLength(0); + boolean changed = false; if (normalizer.quickCheck(termAtt) != Normalizer.YES) { normalizer.normalize(termAtt, accentBuffer); + changed = true; + } else { + accentBuffer.append(termAtt); } List badIndices = new ArrayList<>(); List charCount = new ArrayList<>(); int index = 0; + int deletedIndices = 0; for (PrimitiveIterator.OfInt it = accentBuffer.codePoints().iterator(); it.hasNext();) { int cp = it.next(); if (Character.getType(cp) == Character.NON_SPACING_MARK) { - badIndices.add(index); + // When we iterate to delete accents, we need to account for previously deleted ones + badIndices.add(index - deletedIndices); charCount.add(Character.charCount(cp)); + deletedIndices++; + changed = true; } index++; } - if (badIndices.isEmpty()) { - return; - } for (int i = 0; i < badIndices.size(); i++) { int badIndex = badIndices.get(i); int count = charCount.get(i); @@ -166,12 +171,14 @@ void stripAccent() { accentBuffer.deleteCharAt(badIndex); } } - termAtt.setEmpty().append(accentBuffer); + if (changed) { + termAtt.setEmpty().append(accentBuffer); + } } private LinkedList split() { LinkedList splits = new LinkedList<>(); - int startOffset = offsetAtt.startOffset(); + final int startOffset = offsetAtt.startOffset(); int charIndex = 0; int lastCharSplit = 0; for (PrimitiveIterator.OfInt it = termAtt.codePoints().iterator(); it.hasNext();) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilterTests.java index 9199e2c776f2e..a3288baf65968 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BasicTokenFilterTests.java @@ -67,6 +67,7 @@ public void testSplitCJK() throws Exception { public void testStripAccents() throws Exception { Analyzer analyzer = basicAnalyzerFromSettings(true, true, List.of("[UNK]")); assertAnalyzesToNoCharFilter(analyzer, "HäLLo how are you", new String[] { "HaLLo", "how", "are", "you" }); + assertAnalyzesToNoCharFilter(analyzer, "ÎÎÎÏνÎÎÎαοÏ", new String[] { "IIIII½IIII±I", "¿", "I" }); } private static void assertAnalyzesToNoCharFilter(Analyzer a, String input, String[] output) throws IOException { From 9f2b96d82e7a67aa6dddabf0fd2ae310749c39fd Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 29 Jul 2022 11:14:51 -0400 Subject: [PATCH 5/5] [ML] add sentence-piece unigram tokenizer (#88858) Add internal unigram tokenizer. 
This tokenizer is the same that XLM-Roberta utilizes, along with many other cross-lingual models and tasks. This does not fully integrate (adding configuration, integrating into nlp tasks, etc.). But instead is just the internal tokenization and some tests showing how it runs with a precompiled charsmap. --- .../nlp/tokenizers/DelimitedToken.java | 22 + .../PrecompiledCharMapNormalizer.java | 95 +++- .../nlp/tokenizers/UnigramTokenizer.java | 493 ++++++++++++++++++ .../PrecompiledCharMapNormalizerTests.java | 55 +- .../nlp/tokenizers/UnigramTokenizerTests.java | 165 ++++++ 5 files changed, 790 insertions(+), 40 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java index ec84b1794fa84..32713997f3e8d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DelimitedToken.java @@ -7,10 +7,13 @@ package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; +import static org.elasticsearch.core.Strings.format; + public class DelimitedToken { static DelimitedToken mergeTokens(List tokens) { @@ -67,6 +70,25 @@ public String toString() { } public static class Encoded extends DelimitedToken { + static DelimitedToken.Encoded mergeEncodedTokens(List tokens) { + if (tokens.size() == 1) { + return tokens.get(0); + } + int startOffSet = tokens.get(0).startOffset(); + int endOffset = tokens.get(tokens.size() - 1).endOffset(); + final int encoding = tokens.get(0).encoding; + List sequences = new ArrayList<>(tokens.size()); + for (var t : tokens) { + if (t.encoding != encoding) { + throw new IllegalArgumentException( + format("all merged tokens must have the same encoding, expected [%s]; found [%s]", encoding, t.encoding) + ); + } + sequences.add(t.charSequence()); + } + return new DelimitedToken.Encoded(new MultiCharSequence(sequences), tokens.get(0).encoding, startOffSet, endOffset); + } + private final int encoding; public Encoded(CharSequence charSequence, int encoding, int startOffset, int endOffset) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizer.java index 4470a8629bf65..f20e836fcae87 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizer.java @@ -13,14 +13,22 @@ import com.ibm.icu.text.BreakIterator; +import org.apache.lucene.analysis.charfilter.BaseCharFilter; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefBuilder; +import org.apache.lucene.util.CharsRef; +import org.apache.lucene.util.CharsRefBuilder; import org.apache.lucene.util.UnicodeUtil; +import java.io.CharArrayReader; +import 
java.io.IOException; +import java.io.Reader; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Base64; +import java.util.List; import java.util.Locale; import java.util.Optional; import java.util.OptionalInt; @@ -39,10 +47,15 @@ * DARTS * * - SP normalizer + * + * We implement this as a char filter to take advantage of the underlying offset correction and because normalization needs to occur before + * tokenization (just like a charfilter) */ -public class PrecompiledCharMapNormalizer { +public class PrecompiledCharMapNormalizer extends BaseCharFilter { + + record Config(int[] offsets, String utf8str) {} - static PrecompiledCharMapNormalizer fromBase64Str(String s) { + static Config fromBase64Str(String s) { int offset = 0; byte[] bytes = Base64.getDecoder().decode(s); int trieSize = ByteBuffer.wrap(bytes, offset, 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt(); @@ -54,7 +67,7 @@ static PrecompiledCharMapNormalizer fromBase64Str(String s) { offset += 4; } String utf8Str = new String(bytes, offset, bytes.length - offset, StandardCharsets.UTF_8); - return new PrecompiledCharMapNormalizer(offsets, utf8Str); + return new Config(offsets, utf8Str); } // The offsets for each normalization piece. Used in DARTS algorithm to iterate and find appropriate section @@ -64,8 +77,12 @@ static PrecompiledCharMapNormalizer fromBase64Str(String s) { private final byte[] normalizedStrUtf8Bytes; // Continually reused to copy a single char into utf8 bytes private final byte[] reusableCharByteBuffer = new byte[4]; + // reusable char buffer for decoding utf8 bytes to determine char offset corrections + private final char[] reusableCharDecodeBuffer = new char[8]; + private Reader transformedInput; - public PrecompiledCharMapNormalizer(int[] offsets, String normalizedStr) { + public PrecompiledCharMapNormalizer(int[] offsets, String normalizedStr, Reader in) { + super(in); this.offsets = offsets; this.normalizedStrUtf8Bytes = normalizedStr.getBytes(StandardCharsets.UTF_8); } @@ -152,11 +169,7 @@ private Optional normalizePart(byte[] strBytes, int offset, int len) { return Optional.of(new BytesRef(normalizedStrUtf8Bytes, firstIndex, secondIndex - firstIndex)); } - String normalize(String str) { - return normalize((CharSequence) str).utf8ToString(); - } - - BytesRef normalize(CharSequence str) { + Reader normalize(CharSequence str) { // We need to iterate actual Unicode graphemes (this includes surrogate pairs, etc.) ByteBuffer byteBuffer = StandardCharsets.UTF_8.encode(CharBuffer.wrap(str)); byte[] strBytes = new byte[byteBuffer.limit()]; @@ -167,9 +180,10 @@ BytesRef normalize(CharSequence str) { // We iterate the whole string, so b.first() is always `0` int startIter = b.first(); int codePointPos = 0; - BytesRefBuilder strBuilder = new BytesRefBuilder(); + CharsRefBuilder strBuilder = new CharsRefBuilder(); strBuilder.grow(strBytes.length); int bytePos = 0; + int normalizedCharPos = 0; // Keep in mind, these break points aren't necessarily surrogate pairs, but also codepoints that contain a combining mark for (int end = b.next(); end != BreakIterator.DONE; startIter = end, end = b.next()) { int byteLen = 0; @@ -181,9 +195,15 @@ BytesRef normalize(CharSequence str) { // The trie only go up to a depth of 5 bytes. // So even looking at it for graphemes (with combining, surrogate, etc.) that are 6+ bytes in length is useless. 
if (byteLen < 6) { - Optional subStr = normalizePart(strBytes, bytePos, byteLen); - if (subStr.isPresent()) { - strBuilder.append(subStr.get()); + Optional maybeSubStr = normalizePart(strBytes, bytePos, byteLen); + if (maybeSubStr.isPresent()) { + BytesRef subStr = maybeSubStr.get(); + int numChars = UnicodeUtil.UTF8toUTF16(subStr.bytes, subStr.offset, subStr.length, reusableCharDecodeBuffer); + normalizedCharPos += numChars; + if (numChars != end - startIter) { + addOffCorrectMap(normalizedCharPos, getLastCumulativeDiff() + end - startIter - numChars); + } + strBuilder.append(reusableCharDecodeBuffer, 0, numChars); bytePos += byteLen; continue; } @@ -191,18 +211,53 @@ BytesRef normalize(CharSequence str) { int charByteIndex = 0; for (int i = startIter; i < end; i++) { int utf8CharBytes = numUtf8Bytes(str.charAt(i)); - Optional subStr = normalizePart(strBytes, charByteIndex + bytePos, utf8CharBytes); - if (subStr.isPresent()) { - strBuilder.append(subStr.get()); + Optional maybeSubStr = normalizePart(strBytes, charByteIndex + bytePos, utf8CharBytes); + if (maybeSubStr.isPresent()) { + BytesRef subStr = maybeSubStr.get(); + int numChars = UnicodeUtil.UTF8toUTF16(subStr.bytes, subStr.offset, subStr.length, reusableCharDecodeBuffer); + normalizedCharPos += numChars; + // Meaning we removed this char + if (numChars < 1) { + addOffCorrectMap(normalizedCharPos, getLastCumulativeDiff() + 1); + } else if (numChars > 1) { + addOffCorrectMap(normalizedCharPos, getLastCumulativeDiff() - 1); + } + strBuilder.append(reusableCharDecodeBuffer, 0, numChars); } else { - int numBytes = UnicodeUtil.UTF16toUTF8(str, i, 1, reusableCharByteBuffer); - strBuilder.append(reusableCharByteBuffer, 0, numBytes); + normalizedCharPos += 1; + strBuilder.append(str.charAt(i)); } charByteIndex += utf8CharBytes; } bytePos += byteLen; } - return strBuilder.get(); + return new CharArrayReader(strBuilder.chars(), 0, strBuilder.length()); + } + + @Override + public int read(char[] cbuf, int off, int len) throws IOException { + if (transformedInput == null) { + fill(); + } + + return transformedInput.read(cbuf, off, len); } + @Override + public int read() throws IOException { + if (transformedInput == null) { + fill(); + } + + return transformedInput.read(); + } + + private void fill() throws IOException { + List charArrays = new ArrayList<>(); + char[] temp = new char[1024]; + for (int cnt = input.read(temp); cnt > 0; cnt = input.read(temp)) { + charArrays.add(new CharsRef(Arrays.copyOfRange(temp, 0, cnt), 0, cnt)); + } + transformedInput = normalize(new MultiCharSequence(charArrays)); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java new file mode 100644 index 0000000000000..26f7f49d98565 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java @@ -0,0 +1,493 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.apache.lucene.analysis.CharArraySet; +import org.apache.lucene.analysis.CharacterUtils; +import org.apache.lucene.analysis.Tokenizer; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.UnicodeUtil; +import org.elasticsearch.common.util.Maps; +import org.elasticsearch.core.Nullable; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizerUtils.numUtf8Bytes; +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizerUtils.splitOutNeverSplit; + +/** + * Sentence-piece unigram tokenizer. + * + * Does whitespace tokenization with unigram tokenization on the resulting tokens. + * + * This cannot be a token-filter as it needs access to the offset correction logic provided by the upstream CharFilter. + * + * You may notice that the offsets are always matching the individual tokens position back to the original string. This is because + * there aren't "sub-word" tokens, per-se. So, we don't have tokens that share the same offsets as in WordPiece. + */ +public final class UnigramTokenizer extends Tokenizer { + private static final double K_UNK_PENALTY = 10.0; + static final String PREFIX = "▁"; + + private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class); + private final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class); + + static UnigramTokenizer build(List neverSplit, List dictionary, List scores, String unknownToken) { + if (dictionary.isEmpty()) { + throw new IllegalArgumentException("vocab empty"); + } + if (unknownToken == null) { + throw new IllegalArgumentException("unknown token ID"); + } + CharArraySet neverSplitSet = new CharArraySet(neverSplit, false); + CharTrie neverSplitTree = CharTrie.build(neverSplit); + if (dictionary.size() != scores.size()) { + throw new IllegalArgumentException( + format("provided vocabulary [%s] and scores [%s] must have the same size", dictionary.size(), scores.size()) + ); + } + int vocabSize = dictionary.size(); + BytesTrie vocabTrie = new BytesTrie(); + Map tokenToId = Maps.newHashMapWithExpectedSize(vocabSize); + int vocabIndex = 0; + double minScore = Double.POSITIVE_INFINITY; + double[] vocabScores = new double[vocabSize]; + for (String word : dictionary) { + minScore = Double.min(minScore, scores.get(vocabIndex)); + BytesRef vocab = new BytesRef(word); + vocabScores[vocabIndex] = scores.get(vocabIndex); + tokenToId.put(vocab, vocabIndex++); + vocabTrie.insert(vocab); + } + return new UnigramTokenizer( + minScore, + vocabScores, + neverSplitTree, + neverSplitSet, + tokenToId, + vocabTrie, + Optional.ofNullable(tokenToId.get(new BytesRef(unknownToken))) + .orElseThrow( + () -> new IllegalArgumentException("provided vocabulary does not contain the unknown token of [" + unknownToken + "]") + ) + ); + } + + private final LinkedList tokens; + private final List tokenizedValues; + private final SimpleWhitespaceTokenizer whitespaceTokenizer; + + private final double minScore; + 
// This may be configurable in the future + private final boolean fuseUnk = true; + private final double[] vocabScores; + private final CharTrie neverSplit; + private final CharArraySet neverSplitHash; + private final Map vocabToId; + private final BytesTrie vocabTrie; + private final int unknownTokenId; + // This is a buffer that is reused per token for decoding the normalized char-sequence into utf-8 bytes + // It's usage is NOT thread safe + private byte[] normalizedByteBuffer = new byte[128]; + + public UnigramTokenizer( + double minScore, + double[] vocabScores, + CharTrie neverSplit, + CharArraySet neverSplitHash, + Map vocabToId, + BytesTrie vocabTrie, + int unknownTokenId + ) { + super(); + this.tokens = new LinkedList<>(); + this.tokenizedValues = new ArrayList<>(); + this.minScore = minScore; + this.neverSplit = neverSplit; + this.neverSplitHash = neverSplitHash; + this.vocabToId = vocabToId; + this.vocabTrie = vocabTrie; + this.unknownTokenId = unknownTokenId; + this.vocabScores = vocabScores; + this.whitespaceTokenizer = new SimpleWhitespaceTokenizer(); + } + + @Override + public void reset() throws IOException { + super.reset(); + tokens.clear(); + tokenizedValues.clear(); + whitespaceTokenizer.reset(); + } + + @Override + public void end() throws IOException { + super.end(); + offsetAtt.setOffset(correctOffset(whitespaceTokenizer.finalOffset), correctOffset(whitespaceTokenizer.finalOffset)); + } + + @Override + public boolean incrementToken() throws IOException { + clearAttributes(); + if (tokens.isEmpty() == false) { + DelimitedToken.Encoded token = tokens.removeFirst(); + termAtt.setEmpty().append(token.charSequence()); + offsetAtt.setOffset(token.startOffset(), token.endOffset()); + return true; + } + // First, whitespace tokenize + DelimitedToken whitespaceToken = whitespaceTokenizer.next(); + if (whitespaceToken != null) { + if (neverSplitHash.contains(whitespaceToken.charSequence())) { + Integer maybeTokenized = vocabToId.get(new BytesRef(whitespaceToken.charSequence())); + tokenizedValues.add( + new DelimitedToken.Encoded( + whitespaceToken.charSequence().toString(), + Objects.requireNonNullElse(maybeTokenized, unknownTokenId), + correctOffset(whitespaceToken.startOffset()), + correctOffset(whitespaceToken.endOffset()) + ) + ); + offsetAtt.setOffset(correctOffset(whitespaceToken.startOffset()), correctOffset(whitespaceToken.endOffset())); + return true; + } + int inputOffsetStart = whitespaceToken.startOffset(); + // Split out our neverSplit tokens + LinkedList largeTokensWithNeverSplits = splitOutNeverSplit( + whitespaceToken.charSequence(), + neverSplit, + neverSplitHash + ); + // Encode each token, skipping our "never split" ones. + for (DelimitedToken token : largeTokensWithNeverSplits) { + if (neverSplitHash.contains(token.charSequence())) { + Integer tokenId = vocabToId.get(new BytesRef(token.charSequence())); + DelimitedToken.Encoded toAdd = tokenId == null + ? new DelimitedToken.Encoded( + token.charSequence().toString(), + unknownTokenId, + correctOffset(token.startOffset() + inputOffsetStart), + correctOffset(token.endOffset() + inputOffsetStart) + ) + : new DelimitedToken.Encoded( + token.charSequence().toString(), + tokenId, + correctOffset(token.startOffset() + inputOffsetStart), + correctOffset(token.endOffset() + inputOffsetStart) + ); + tokens.add(toAdd); + continue; + } + // We always prefix the initial sub-tokens + // e.g. 
" asdf-asdf " -> ['', '▁as', 'd', 'f', '', '▁-', 'as', 'd', 'f'] + IntToIntFunction offsetCorrectorFunction = i -> { + int adj = i + inputOffsetStart + token.startOffset(); + // if the passed offset to set is `0`, that means the tokenization probably matched on the meta-space character + // Meaning, the start and end offsets for that token will be the same and ultimately discarded when re-constituting + // tokenized results (if that is necessary for the task). + if (i > 0) { + // We always apply the prefix, so account for that when correcting the offsets, basically, the original + // normalization + // doesn't know about our prefix, so we should find out the correct offsets when not taking it into account. + adj -= PREFIX.length(); + } + return correctOffset(adj); + }; + List tokenList = tokenize( + MultiCharSequence.from(PREFIX, token.charSequence()), + offsetCorrectorFunction + ); + tokenizedValues.addAll(tokenList); + tokens.addAll(tokenList); + } + DelimitedToken.Encoded token = tokens.removeFirst(); + termAtt.setEmpty().append(token.charSequence()); + offsetAtt.setOffset(token.startOffset(), token.endOffset()); + return true; + } + return false; + } + + /** + * This algorithm does the following: + * + * - iterates all the prefixes for the given input sequence, byte by byte. + * - Keeps track of the best scores for the prefixes we find and reconstitutes the tokens from those prefixes + * + * This is derived from: + * https://github.com/google/sentencepiece/blob/901368e0752b57a408ac5c84bca0a219d62c648f/src/unigram_model.cc#L890 + * https://github.com/huggingface/tokenizers/blob/1f1f86dd320fa653924eb1560e51d1b287ab0613/tokenizers/src/models/unigram/model.rs#L229 + * + * @param inputSequence The sequence to encode, should have NO whitespace characters + * @param offsetCorrection Offset corrections to apply to the tokens. Should take into account any previous char-filtering and tokens. + * @return The list of delimited and encoded tokens + */ + List tokenize(CharSequence inputSequence, IntToIntFunction offsetCorrection) { + int bytelen = UnicodeUtil.calcUTF16toUTF8Length(inputSequence, 0, inputSequence.length()); + if (bytelen > normalizedByteBuffer.length) { + normalizedByteBuffer = new byte[bytelen + 1]; + } + int numBytes = UnicodeUtil.UTF16toUTF8(inputSequence, 0, inputSequence.length(), normalizedByteBuffer); + double unkScore = minScore - K_UNK_PENALTY; + BestPathNode[] bestPathNodes = new BestPathNode[numBytes + 1]; + int bytePos = 0; + int charPos = 0; + while (bytePos < numBytes) { + double bestScoreTillHere = bestPathNodes[bytePos] == null ? 
0 : bestPathNodes[bytePos].score; + int mblen = numUtf8Bytes(inputSequence.charAt(charPos)); + boolean hasSingleNode = false; + // Find the matching prefixes, incrementing by the chars, each time + for (BytesRef prefix : vocabTrie.matchingPrefixes(new BytesRef(normalizedByteBuffer, bytePos, numBytes - bytePos))) { + int pathKey = bytePos + prefix.length; + int tokenId = vocabToId.get(prefix); + double score = vocabScores[tokenId]; + BestPathNode node = bestPathNodes[pathKey]; + double candidateScore = score + bestScoreTillHere; + if (node == null || candidateScore > node.score) { + if (node == null) { + node = new BestPathNode(); + bestPathNodes[pathKey] = node; + } + node.id = tokenId; + node.score = candidateScore; + node.startsAtBytePos = bytePos; + node.startsAtCharPos = charPos; + } + hasSingleNode = hasSingleNode || (pathKey - bytePos) == mblen; + } + if (hasSingleNode == false) { + BestPathNode node = bestPathNodes[bytePos + mblen]; + double candidateScore = unkScore + bestScoreTillHere; + if (node == null || candidateScore > node.score) { + if (node == null) { + node = new BestPathNode(); + bestPathNodes[bytePos + mblen] = node; + } + node.id = unknownTokenId; + node.score = candidateScore; + node.startsAtBytePos = bytePos; + node.startsAtCharPos = charPos; + } + } + // Move our prefix search to the next char + bytePos += mblen; + ++charPos; + } + int endsAtBytes = numBytes; + int endsAtChars = inputSequence.length(); + List unknownTokens = new ArrayList<>(); + List results = new ArrayList<>(); + // Now we work our way backwards finding the best path nodes, using the `startAtBytePos` as backward links. + while (endsAtBytes > 0) { + BestPathNode node = bestPathNodes[endsAtBytes]; + int startsAtBytes = node.startsAtBytePos; + if (node.id == unknownTokenId && fuseUnk) { + unknownTokens.add( + new DelimitedToken.Encoded( + new String(normalizedByteBuffer, startsAtBytes, endsAtBytes - startsAtBytes, StandardCharsets.UTF_8), + unknownTokenId, + offsetCorrection.apply(node.startsAtCharPos), + offsetCorrection.apply(endsAtChars) + ) + ); + } else { + if (unknownTokens.isEmpty() == false) { + Collections.reverse(unknownTokens); + results.add(DelimitedToken.Encoded.mergeEncodedTokens(unknownTokens)); + unknownTokens.clear(); + } + results.add( + new DelimitedToken.Encoded( + new String(normalizedByteBuffer, startsAtBytes, endsAtBytes - startsAtBytes, StandardCharsets.UTF_8), + node.id, + offsetCorrection.apply(node.startsAtCharPos), + offsetCorrection.apply(endsAtChars) + ) + ); + } + endsAtBytes = startsAtBytes; + endsAtChars = node.startsAtCharPos; + } + if (unknownTokens.isEmpty() == false) { + Collections.reverse(unknownTokens); + results.add(DelimitedToken.Encoded.mergeEncodedTokens(unknownTokens)); + unknownTokens.clear(); + } + Collections.reverse(results); + return results; + } + + private static byte fromBytesRef(BytesRef bytesRef, int index) { + return bytesRef.bytes[index + bytesRef.offset]; + } + + /** + * This is a bytes-trie, this is used for gathering known matching prefixes given the original vocabulary. + * + * NOTE: it is possible for a node to be a "leaf" and have children. It being a "leaf", just means that it is the end of a possible + * vocab entry that matches a given prefix. 
+ */ + static class BytesTrie { + private final Map children; + private boolean isLeaf; + + BytesTrie() { + children = new HashMap<>(); + } + + private void setLeaf(boolean isLeaf) { + this.isLeaf = isLeaf; + } + + private boolean isLeaf() { + return isLeaf; + } + + List matchingPrefixes(BytesRef input) { + List prefixes = new ArrayList<>(); + int numMatchedChildren = 0; + BytesTrie node = this; + for (int i = input.offset; i < input.length + input.offset; i++) { + if (node == null) { + break; + } + if (node.isLeaf() && numMatchedChildren > 0) { + prefixes.add(new BytesRef(input.bytes, input.offset, numMatchedChildren)); + } + node = node.children.get(input.bytes[i]); + numMatchedChildren++; + } + if (node != null && node.isLeaf() && numMatchedChildren > 0) { + prefixes.add(new BytesRef(input.bytes, input.offset, numMatchedChildren)); + } + return prefixes; + } + + void insert(BytesRef bytes) { + if (bytes.length == 0) { + return; + } + BytesTrie currentNode = this; + int currentTokenIndex = 0; + + // find last child + while (currentTokenIndex < bytes.length) { + currentNode = currentNode.children.computeIfAbsent(fromBytesRef(bytes, currentTokenIndex), k -> new BytesTrie()); + currentTokenIndex++; + } + currentNode.setLeaf(true); + } + + public static BytesTrie build(Collection tokens) { + BytesTrie root = new BytesTrie(); + for (BytesRef token : tokens) { + root.insert(token); + } + return root; + } + } + + /** + * This keeps track of the best-path in the vocab for given prefixes + */ + private static class BestPathNode { + // Token Id, -1 if its unknown + private int id = -1; + // Token score + double score = 0.0; + // starts at byte position for walking back the best scoring node + private int startsAtBytePos = -1; + // Its char position for correctly identifying offsets related to the original input + private int startsAtCharPos = -1; + } + + @FunctionalInterface + public interface IntToIntFunction { + int apply(int value); + } + + /** + * This is a simple whitespace tokenizer that generates whitespace delimited tokens from the input stream + * + * This is effectively the lucene WhitespaceTokenizer, slightly adjusted for our needs here. 
+ */ + class SimpleWhitespaceTokenizer { + private int offset = 0, bufferIndex = 0, dataLen = 0, finalOffset = 0; + private static final int IO_BUFFER_SIZE = 4096; + private final CharacterUtils.CharacterBuffer ioBuffer = CharacterUtils.newCharacterBuffer(IO_BUFFER_SIZE); + + void reset() { + bufferIndex = 0; + offset = 0; + dataLen = 0; + finalOffset = 0; + ioBuffer.reset(); + } + + @Nullable + DelimitedToken next() throws IOException { + int length = 0; + int start = -1; // this variable is always initialized + int end = -1; + char[] buffer = termAtt.buffer(); + while (true) { + if (bufferIndex >= dataLen) { + offset += dataLen; + CharacterUtils.fill(ioBuffer, input); // read supplementary char aware with CharacterUtils + if (ioBuffer.getLength() == 0) { + dataLen = 0; // so next offset += dataLen won't decrement offset + if (length > 0) { + break; + } else { + finalOffset = offset; + return null; + } + } + dataLen = ioBuffer.getLength(); + bufferIndex = 0; + } + // use CharacterUtils here to support < 3.1 UTF-16 code unit behavior if the char based + // methods are gone + final int c = Character.codePointAt(ioBuffer.getBuffer(), bufferIndex, ioBuffer.getLength()); + final int charCount = Character.charCount(c); + bufferIndex += charCount; + if (Character.isWhitespace(c) == false) { // if it's a token char + if (length == 0) { // start of token + assert start == -1; + start = offset + bufferIndex - charCount; + end = start; + } else if (length >= buffer.length - 1) { // supplementary could run out of bounds? + // make sure a supplementary fits in the buffer + buffer = termAtt.resizeBuffer(2 + length); + } + end += charCount; + length += Character.toChars(c, buffer, length); + } else if (length > 0) { + break; + } + } + + termAtt.setLength(length); + assert start != -1; + return new DelimitedToken(termAtt, start, finalOffset = end); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizerTests.java index 8016ed2e02278..8541ccfb6c2cd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizerTests.java @@ -10,43 +10,58 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.OptionalInt; +import java.io.StringReader; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; public class PrecompiledCharMapNormalizerTests extends ESTestCase { public void testCommonPrefix() throws IOException { - PrecompiledCharMapNormalizer parsed = loadTestCharMap(); - OptionalInt local = parsed.commonPrefix("\uFB01".getBytes(StandardCharsets.UTF_8)); - assertThat(local.isPresent(), is(true)); - assertThat(local.getAsInt(), equalTo(2130)); - String transformed = parsed.normalize("\uFB01"); - assertThat(transformed, equalTo("fi")); - assertThat(parsed.normalize("𝔾"), equalTo("G")); - assertThat(parsed.normalize("\uD835\uDD60"), equalTo("o")); - assertThat(parsed.normalize("\u200D"), equalTo(" ")); - assertThat(parsed.normalize("เขาไม่ได้พูดสักคำ"), equalTo("เขาไม\u0E48ได\u0E49พ\u0E39ดส\u0E31กค\u0E4Dา")); + PrecompiledCharMapNormalizer.Config parsed = loadTestCharMap(); + assertNormalization("\u0008", parsed, ""); + 
assertNormalization("\uFB01", parsed, "fi"); + assertNormalization("𝔾", parsed, "G"); + assertNormalization("\uD835\uDD60", parsed, "o"); + assertNormalization("\u200D", parsed, " "); + assertNormalization("เขาไม่ได้พูดสักคำ", parsed, "เขาไม\u0E48ได\u0E49พ\u0E39ดส\u0E31กค\u0E4Dา"); } public void testAdverseScenario() throws IOException { - PrecompiledCharMapNormalizer parsed = loadTestCharMap(); - assertThat(parsed.normalize("คำ"), equalTo("ค\u0e4dา")); + PrecompiledCharMapNormalizer.Config parsed = loadTestCharMap(); + assertNormalization("คำ", parsed, "ค\u0e4dา"); } public void testAdverseScenarioHindi() throws IOException { - PrecompiledCharMapNormalizer parsed = loadTestCharMap(); - assertThat(parsed.normalize("ड़ी दुख"), equalTo("ड\u093cी द\u0941ख")); + PrecompiledCharMapNormalizer.Config parsed = loadTestCharMap(); + assertNormalization("ड़ी दुख", parsed, "ड\u093cी द\u0941ख"); } public void testTwoCharUnicode() throws IOException { - PrecompiledCharMapNormalizer parsed = loadTestCharMap(); - assertThat(parsed.normalize("آ"), equalTo("آ")); + PrecompiledCharMapNormalizer.Config parsed = loadTestCharMap(); + assertNormalization("آ", parsed, "آ"); } - private static PrecompiledCharMapNormalizer loadTestCharMap() throws IOException { + public void testWhitespaceScenario() throws IOException { + PrecompiledCharMapNormalizer.Config parsed = loadTestCharMap(); + assertNormalization("​​από", parsed, " από"); + } + + private void assertNormalization(String input, PrecompiledCharMapNormalizer.Config config, String expected) throws IOException { + PrecompiledCharMapNormalizer normalizer = new PrecompiledCharMapNormalizer( + config.offsets(), + config.utf8str(), + new StringReader(input) + ); + char[] output = new char[64]; + int read = normalizer.read(output, 0, 64); + if (read <= 0) { + assertThat("", equalTo(expected)); + } else { + assertThat(new String(output, 0, read), equalTo(expected)); + } + } + + static PrecompiledCharMapNormalizer.Config loadTestCharMap() throws IOException { PreCompiledCharMap map = PreCompiledCharMap.fromResource( "/org.elasticsearch.xpack.ml.inference.nlp.tokenizers/precompiled_char_map.json" ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java new file mode 100644 index 0000000000000..8f04ccf3dc0c2 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.nlp.tokenizers; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; +import org.apache.lucene.util.BytesRef; + +import java.io.IOException; +import java.io.Reader; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.UnigramTokenizer.PREFIX; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; + +public class UnigramTokenizerTests extends BaseTokenStreamTestCase { + private static final String UNKNOWN_TOKEN = ""; + private static final List NEVER_SPLIT = List.of(""); + + public void testSimpleTokenization() throws IOException { + TestNLPAnalyzer analyzer = new TestNLPAnalyzer( + List.of(UNKNOWN_TOKEN, PREFIX + "a", "b", "c", "d", "cd", PREFIX + "ab", PREFIX + "abc", PREFIX + "abcd", ""), + List.of(0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 5.0, 10.0, 0.0), + UNKNOWN_TOKEN, + new PrecompiledCharMapNormalizer.Config(new int[0], "") + ); + + assertAnalyzesToNoCharFilter(analyzer, "", new String[0]); + assertAnalyzesToNoCharFilter(analyzer, "abcd", new String[] { PREFIX + "abcd" }); + } + + public void testLessSimpleTokenization() throws IOException { + TestNLPAnalyzer analyzer = new TestNLPAnalyzer( + List.of(UNKNOWN_TOKEN, PREFIX + "ab", "cd", PREFIX + "abc", "a", "b", "c", "ABC", "abcdabcd", "q", "r", "qr", ""), + List.of(0.0, 0.0, -0.1, -0.2, -0.3, -0.4, -0.5, -0.5, 20.0, 20.5, 20.5, -0.5, 0.0), + UNKNOWN_TOKEN, + new PrecompiledCharMapNormalizer.Config(new int[0], "") + ); + + assertAnalyzesToNoCharFilter(analyzer, "", new String[0]); + assertAnalyzesToNoCharFilter(analyzer, "abcd", new String[] { PREFIX + "ab", "cd" }); + assertAnalyzesToNoCharFilter(analyzer, "abc", new String[] { PREFIX + "abc" }); + assertAnalyzesToNoCharFilter(analyzer, "AB", new String[] { PREFIX + "AB" }); + assertAnalyzesToNoCharFilter(analyzer, "abcc", new String[] { PREFIX + "abc", "c" }); + assertAnalyzesToNoCharFilter(analyzer, " \nabcd \n\n abcc \n", new String[] { PREFIX + "ab", "cd", PREFIX + "abc", "c" }); + } + + public void testLessSimpleTokenizationWithNeverSplit() throws IOException { + TestNLPAnalyzer analyzer = new TestNLPAnalyzer( + List.of( + UNKNOWN_TOKEN, + PREFIX + "ab", + "cd", + PREFIX + "cd", + PREFIX + "abc", + "a", + "b", + "c", + "ABC", + "abcdabcd", + "q", + "r", + "qr", + "" + ), + List.of(0.0, 0.0, -0.1, -0.2, -0.2, -0.3, -0.4, -0.5, -0.5, 20.0, 20.5, 20.5, -0.5, 0.0), + UNKNOWN_TOKEN, + new PrecompiledCharMapNormalizer.Config(new int[0], "") + ); + + assertAnalyzesToNoCharFilter(analyzer, "", new String[] { "" }); + assertAnalyzesToNoCharFilter(analyzer, "abcd", new String[] { "", PREFIX + "ab", "cd", "" }); + assertAnalyzesToNoCharFilter( + analyzer, + " \nabcd \n\n abcc \n", + new String[] { "", PREFIX + "ab", "", PREFIX + "cd", PREFIX + "abc", "c", "" } + ); + } + + public void testTriePrefixMatch() { + List inputs = new ArrayList<>( + List.of( + new BytesRef("a"), + new BytesRef("b"), + new BytesRef("c"), + new BytesRef("d"), + new BytesRef("cd"), + new BytesRef("ab"), + new BytesRef("abc"), + new BytesRef("abcd") + ) + ); + Collections.shuffle(inputs, random()); + UnigramTokenizer.BytesTrie bytesTrie = UnigramTokenizer.BytesTrie.build(inputs); + String input = "abcd"; + assertThat( + bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), + contains("a", "ab", "abc", "abcd") + ); + input = "bcd"; + 
assertThat(bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), contains("b")); + input = "cd"; + assertThat(bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), contains("c", "cd")); + input = "d"; + assertThat(bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), contains("d")); + input = ""; + assertThat(bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), empty()); + input = "zabcd"; + assertThat(bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), empty()); + input = "azbcd"; + assertThat(bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), contains("a")); + input = "abzcd"; + assertThat(bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), contains("a", "ab")); + input = "abcdz"; + assertThat( + bytesTrie.matchingPrefixes(new BytesRef(input)).stream().map(BytesRef::utf8ToString).toList(), + contains("a", "ab", "abc", "abcd") + ); + } + + private static class TestNLPAnalyzer extends Analyzer { + private final List dictionary; + private final List scores; + private final String unknownToken; + private final PrecompiledCharMapNormalizer.Config normalizer; + + TestNLPAnalyzer(List dictionary, List scores, String unknownToken, PrecompiledCharMapNormalizer.Config normalizer) { + this.dictionary = dictionary; + this.scores = scores; + this.unknownToken = unknownToken; + this.normalizer = normalizer; + } + + @Override + protected Reader initReader(String fieldName, Reader reader) { + if (normalizer.offsets().length > 0) { + return new PrecompiledCharMapNormalizer(normalizer.offsets(), normalizer.utf8str(), reader); + } + return reader; + } + + @Override + protected TokenStreamComponents createComponents(String fieldName) { + UnigramTokenizer tokenizer = UnigramTokenizer.build(NEVER_SPLIT, dictionary, scores, unknownToken); + return new TokenStreamComponents(tokenizer); + } + } + + private static void assertAnalyzesToNoCharFilter(Analyzer a, String input, String[] output) throws IOException { + assertTokenStreamContents(a.tokenStream("dummy", input), output, null, null, null, null, null, input.length()); + checkResetException(a, input); + // We don't allow the random char filter because our offsets aren't corrected appropriately due to "never_split" + // If we could figure out a way to pass "never_split" through whichever passed char_filter there was, then it would work + checkAnalysisConsistency(random(), a, false, input); + } + +}
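As a closing illustration of the best-path search that `UnigramTokenizer.tokenize()` performs
(described in its javadoc above), here is a character-level sketch with a toy vocabulary and
made-up scores; it is not the production implementation, which matches prefixes through a
bytes-trie over the UTF-8 form of the normalized input:

// Character-level sketch of the Viterbi-style best-path search. Vocabulary, scores and
// class name are illustrative only.
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class UnigramBestPathSketch {

    static List<String> segment(String input, Map<String, Double> vocabScores, double unkScore) {
        int n = input.length();
        double[] best = new double[n + 1];   // best score of any segmentation ending at position i
        int[] backPointer = new int[n + 1];  // where the winning piece ending at position i starts
        Arrays.fill(best, Double.NEGATIVE_INFINITY);
        best[0] = 0.0;
        for (int end = 1; end <= n; end++) {
            for (int start = 0; start < end; start++) {
                String piece = input.substring(start, end);
                Double score = vocabScores.get(piece);
                // Single characters missing from the vocabulary get a penalised "unknown" score,
                // so the search can always reach the end of the input.
                if (score == null && end - start == 1) {
                    score = unkScore;
                }
                if (score != null && best[start] + score > best[end]) {
                    best[end] = best[start] + score;
                    backPointer[end] = start;
                }
            }
        }
        // Walk the back-pointers from the end and reverse, mirroring how the tokenizer
        // reconstitutes its result list.
        List<String> tokens = new ArrayList<>();
        for (int end = n; end > 0; end = backPointer[end]) {
            tokens.add(0, input.substring(backPointer[end], end));
        }
        return tokens;
    }

    public static void main(String[] args) {
        Map<String, Double> vocab = Map.of("▁ab", -1.0, "cd", -1.5, "▁abc", -2.0, "c", -3.0, "d", -3.0);
        System.out.println(segment("▁abcd", vocab, -10.0)); // prints [▁ab, cd]
    }
}

The production tokenizer additionally applies the `▁` prefix to whitespace-delimited tokens,
corrects offsets back to the original input, and fuses adjacent unknown pieces into a single
token when `fuseUnk` is set.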