diff --git a/lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java b/lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java index f46a1e5..531c228 100644 --- a/lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java +++ b/lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java @@ -32,6 +32,10 @@ public Optional getEncodingForModel(final String modelName) { return Optional.of(getEncodingForModel(modelType.get())); } + if (modelName.startsWith(ModelType.GPT_4O.getName())) { + return Optional.of(getEncodingForModel(ModelType.GPT_4O)); + } + if (modelName.startsWith(ModelType.GPT_4_32K.getName())) { return Optional.of(getEncodingForModel(ModelType.GPT_4_32K)); } diff --git a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java index 3a14d32..d14afa9 100644 --- a/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java +++ b/lib/src/test/java/com/knuddels/jtokkit/BaseEncodingRegistryTest.java @@ -88,6 +88,27 @@ void getEncodingForModelByPrefixReturnsCorrectEncodingForGpt4() { assertEquals(encoding.get().getName(), ModelType.GPT_4.getEncodingType().getName()); } + @Test + void getEncodingForModelByPrefixReturnsCorrectEncodingForGpt4o() { + var encoding = registry.getEncodingForModel("gpt-4o-123"); + assertTrue(encoding.isPresent()); + assertEquals(encoding.get().getName(), ModelType.GPT_4O.getEncodingType().getName()); + } + + @Test + void getEncodingForModelByPrefixReturnsCorrectEncodingForGpt4oMini() { + var encoding = registry.getEncodingForModel("gpt-4o-mini-123"); + assertTrue(encoding.isPresent()); + assertEquals(encoding.get().getName(), ModelType.GPT_4O_MINI.getEncodingType().getName()); + } + + @Test + void getEncodingForModelByPrefixReturnsCorrectEncodingForGpt4Turbo() { + var encoding = registry.getEncodingForModel("gpt-4-turbo-123"); + assertTrue(encoding.isPresent()); + assertEquals(encoding.get().getName(), ModelType.GPT_4_TURBO.getEncodingType().getName()); + } + @Test void getEncodingForModelByPrefixReturnsCorrectEncodingForGpt3Turbo() { var encoding = registry.getEncodingForModel("gpt-3.5-turbo-0301");