diff --git a/Cargo.lock b/Cargo.lock index fa1a45a75f2..93e13056714 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -60,30 +60,30 @@ checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anstyle-parse" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +checksum = "a3a318f1f38d2418400f8209655bfd825785afd25aa30bb7ba6cc792e4596748" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.1" +version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" dependencies = [ "anstyle", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -333,9 +333,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.4.10" +version = "4.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fffed7514f420abec6d183b1d3acfd9099c79c3a10a06ade4f8203f1411272" +checksum = "bfaff671f6b22ca62406885ece523383b9b64022e341e53e009a62ebc47a45f2" dependencies = [ "clap_builder", "clap_derive", @@ -343,9 +343,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.9" +version = "4.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63361bae7eef3771745f02d8d892bec2fee5f6e34af316ba556e7f97a7069ff1" +checksum = "a216b506622bb1d316cd51328dce24e07bdff4a6128a47c7e7fad11878d5adbb" dependencies = [ "anstream", "anstyle", @@ -392,9 +392,9 @@ dependencies = [ [[package]] name = "core-foundation" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -402,9 +402,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "cpufeatures" @@ -549,9 +549,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3" +checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" dependencies = [ "powerfmt", ] @@ -1218,9 +1218,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" 
[[package]] name = "lock_api" @@ -1612,9 +1612,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.60" +version = "0.10.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79a4c6c3a2b158f7f8f2a2fc5a969fa3a068df6fc9dbb4a43845436e3af7c800" +checksum = "6b8419dc8cc6d866deb801274bba2e6f8f6108c1bb7fcc10ee5ab864931dbb45" dependencies = [ "bitflags 2.4.1", "cfg-if", @@ -1644,9 +1644,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.96" +version = "0.9.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f" +checksum = "c3eaad34cdd97d81de97964fc7f29e2d104f483840d906ef56daa1912338460b" dependencies = [ "cc", "libc", @@ -2309,15 +2309,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.25" +version = "0.38.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc99bc2d4f1fed22595588a013687477aedf3cdcfb26558c559edb67b4d9b22e" +checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" dependencies = [ "bitflags 2.4.1", "errno", "libc", "linux-raw-sys", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -2567,9 +2567,9 @@ dependencies = [ [[package]] name = "slotmap" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" dependencies = [ "version_check", ] @@ -3835,18 +3835,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.26" +version = "0.7.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e97e415490559a91254a2979b4829267a57d2fcd741a98eee8b722fb57289aa0" +checksum = "7d6f15f7ade05d2a4935e34a457b936c23dc70a05cc1d97133dc99e7a3fe0f0e" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.26" +version = "0.7.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd7e48ccf166952882ca8bd778a43502c64f33bf94c12ebe2a7f08e5a0f6689f" +checksum = "dbbad221e3f78500350ecbd7dfa4e63ef945c05f4c61cb7f4d3f84cd0bba649b" dependencies = [ "proc-macro2", "quote", diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 62abe8c6edb..9590e463214 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -67,6 +67,14 @@ Options: - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-fp4: Bitsandbytes 4bit. 
nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model +``` +## SPECULATE +```shell + --speculate + The number of input_ids to speculate on If using a medusa model, the heads will be picked up automatically Other wise, it will use n-gram speculation which is relatively free in terms of compute, but the speedup heavily depends on the task + + [env: SPECULATE=] + ``` ## DTYPE ```shell diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json new file mode 100644 index 00000000000..d8a298eb23a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json @@ -0,0 +1,98 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 338, + "logprob": -10.0078125, + "text": "is" + }, + { + "id": 21784, + "logprob": -15.515625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -2.8847656, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -4.140625, + "text": "?" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -1.1582031, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.23083496, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": 0.0, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": 0.0, + "special": false, + "text": " learning" + }, + { + "id": 29892, + "logprob": -0.61816406, + "special": false, + "text": "," + }, + { + "id": 607, + "logprob": -0.7089844, + "special": false, + "text": " which" + }, + { + "id": 508, + "logprob": -1.7724609, + "special": false, + "text": " can" + }, + { + "id": 367, + "logprob": 0.0, + "special": false, + "text": " be" + }, + { + "id": 5545, + "logprob": 0.0, + "special": false, + "text": " considered" + }, + { + "id": 408, + "logprob": -0.3869629, + "special": false, + "text": " as" + } + ] + }, + "generated_text": "What is Deep Learning?\nDeep learning, which can be considered as" +} diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json new file mode 100644 index 00000000000..413af1d7ee2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json @@ -0,0 +1,414 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2753906, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.48046875, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1845703, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.5727539, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.00010967255, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.04510498, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.00020992756, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.0046539307, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025844574, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1826172, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1826172, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2724609, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.47729492, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1826172, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.56689453, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108003616, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.018295288, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004711151, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00025892258, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json new file mode 100644 index 00000000000..15754b14956 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json @@ -0,0 +1,103 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1724, + "logprob": -10.734375, + "text": "What" + }, + { + "id": 338, + "logprob": -1.5488281, + "text": "is" + }, + { + "id": 21784, + "logprob": -9.2890625, + "text": "Deep" + }, + { + "id": 29257, + "logprob": -1.2753906, + "text": "Learning" + }, + { + "id": 29973, + "logprob": -0.48046875, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.1845703, + "special": false, + "text": "\n" + }, + { + "id": 2772, + "logprob": -0.5727539, + "special": false, + "text": "De" + }, + { + "id": 1022, + "logprob": -0.000108122826, + "special": false, + "text": "ep" + }, + { + "id": 6509, + "logprob": -0.1239624, + "special": false, + "text": " learning" + }, + { + "id": 338, + "logprob": -0.044433594, + "special": false, + "text": " is" + }, + { + "id": 263, + "logprob": -0.01852417, + "special": false, + "text": " a" + }, + { + "id": 11306, + "logprob": -0.45922852, + "special": false, + "text": " subset" + }, + { + "id": 310, + "logprob": -0.0002104044, + "special": false, + "text": " of" + }, + { + "id": 4933, + "logprob": -0.004787445, + "special": false, + "text": " machine" + }, + { + "id": 6509, + "logprob": -0.00026226044, + "special": false, + "text": " learning" + } + ] + }, + "generated_text": "\nDeep learning is a subset of machine learning" +} diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py new file mode 100644 index 00000000000..003409b0f53 --- /dev/null +++ b/integration-tests/models/test_flash_medusa.py @@ -0,0 +1,59 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_medusa_handle(launcher): + with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_medusa(flash_medusa_handle): + await flash_medusa_handle.health(300) + return flash_medusa_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_medusa_simple(flash_medusa, response_snapshot): + response = await flash_medusa.generate( + "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_medusa_all_params(flash_medusa, response_snapshot): + response = await flash_medusa.generate( + "What is Deep Learning?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot): + responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}" + assert responses[0].generated_text == '\nDeep learning is a subset of machine learning' + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py index 63cb09b5a21..7d21afd9bb7 100644 --- a/integration-tests/models/test_flash_mistral.py +++ b/integration-tests/models/test_flash_mistral.py @@ -21,6 +21,7 @@ async def test_flash_mistral(flash_mistral, response_snapshot): ) assert response.details.generated_tokens == 10 + assert response.generated_text == ": Let n = 10 - 1" assert response == response_snapshot @@ -55,6 +56,7 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho ) assert len(responses) == 4 - assert 
all([r.generated_text == responses[0].generated_text for r in responses]) + assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}" + assert responses[0].generated_text == ": Let n = 10 - 1" assert responses == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index b4fc86b7bee..4e230205a20 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -155,6 +155,13 @@ struct Args { #[clap(long, env, value_enum)] quantize: Option, + /// The number of input_ids to speculate on + /// If using a medusa model, the heads will be picked up automatically + /// Other wise, it will use n-gram speculation which is relatively free + /// in terms of compute, but the speedup heavily depends on the task. + #[clap(long, env)] + speculate: Option, + /// The dtype to be forced upon the model. This option cannot be used with `--quantize`. #[clap(long, env, value_enum)] dtype: Option, @@ -375,6 +382,7 @@ fn shard_manager( model_id: String, revision: Option, quantize: Option, + speculate: Option, dtype: Option, trust_remote_code: bool, uds_path: String, @@ -432,6 +440,11 @@ fn shard_manager( shard_args.push(quantize.to_string()) } + if let Some(speculate) = speculate { + shard_args.push("--speculate".to_string()); + shard_args.push(speculate.to_string()) + } + if let Some(dtype) = dtype { shard_args.push("--dtype".to_string()); shard_args.push(dtype.to_string()) @@ -882,6 +895,7 @@ fn spawn_shards( let shutdown_sender = shutdown_sender.clone(); let otlp_endpoint = args.otlp_endpoint.clone(); let quantize = args.quantize; + let speculate = args.speculate; let dtype = args.dtype; let trust_remote_code = args.trust_remote_code; let master_port = args.master_port; @@ -896,6 +910,7 @@ fn spawn_shards( model_id, revision, quantize, + speculate, dtype, trust_remote_code, uds_path, diff --git a/load_tests/common.js b/load_tests/common.js index be812e9b57b..5d71abeab1f 100644 --- a/load_tests/common.js +++ b/load_tests/common.js @@ -7,7 +7,9 @@ const seed = 0; const host = __ENV.HOST || '127.0.0.1:8000'; const timePerToken = new Trend('time_per_token', true); -const throughput = new Counter('tokens_per_s'); +const tokens = new Counter('tokens'); +const new_tokens = new Counter('new_tokens'); +const input_tokens = new Counter('input_tokens'); randomSeed(seed); // const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json")) @@ -19,7 +21,7 @@ export function get_options(reference_latency_ms){ thresholds: { http_req_failed: ['rate==0'], time_per_token: [{ - threshold: `p(50)<${3 * reference_latency_ms}`, + threshold: `p(50)<${5 * reference_latency_ms}`, abortOnFail: true, delayAbortEval: '10s' }], @@ -28,7 +30,7 @@ export function get_options(reference_latency_ms){ load_test: { executor: 'constant-arrival-rate', duration: '60s', - preAllocatedVUs: 100, + preAllocatedVUs: 10, rate: 10, timeUnit: '1s', }, @@ -48,17 +50,22 @@ export function run(host, generate_payload, max_new_tokens) { return; } + check(res, { 'Post status is 200': (r) => res.status === 200, }); - const n_tokens = max_new_tokens; - const timings = res.timings.duration; + const duration = res.timings.duration; if (res.status === 200) { - const latency_ms_per_token = timings / n_tokens; + const body = res.json(); + const n_tokens = body.details.tokens.length; + const latency_ms_per_token = duration / n_tokens; timePerToken.add(latency_ms_per_token); const latency_in_s = latency_ms_per_token / 1000; const individual_throughput = 1 / latency_in_s; - 
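
The `--speculate` option added in `launcher/src/main.rs` above is forwarded to every shard as `--speculate <n>` (it can also be set through the `SPECULATE` environment variable documented in the launcher docs hunk). With a Medusa checkpoint the extra heads supply the candidate tokens automatically; for any other model the server falls back to n-gram speculation, whose implementation is not part of this diff. The sketch below shows one simple way n-gram candidates can be proposed, matching the last token against an earlier occurrence in the context and copying what followed; the server's actual heuristic may differ.

```python
from typing import List


def ngram_speculate(input_ids: List[int], n_speculate: int) -> List[int]:
    """Propose up to `n_speculate` candidate tokens by copying the continuation
    of the most recent earlier occurrence of the last token in the context.

    Illustrative sketch only; the server's real n-gram heuristic is implemented
    elsewhere and may match longer n-grams or differ in other details.
    """
    if len(input_ids) < 2 or n_speculate == 0:
        return []
    last = input_ids[-1]
    # Scan backwards (excluding the final position itself) for a previous
    # occurrence of the last token.
    for pos in range(len(input_ids) - 2, -1, -1):
        if input_ids[pos] == last:
            # Copy the tokens that followed that occurrence as cheap candidates.
            return input_ids[pos + 1 : pos + 1 + n_speculate]
    return []


# The context already contains token 9, so the tokens that followed it are
# proposed as candidates for the next verification forward pass.
print(ngram_speculate([5, 9, 7, 3, 9, 7, 12, 9], n_speculate=2))  # -> [7, 12]
```

For example, `text-generation-launcher --model-id <model> --speculate 3` (or `SPECULATE=3` in the environment) enables speculation on three input ids per decode step.
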
throughput.add(individual_throughput); + const _input_tokens = body.details.prefill.length; + tokens.add(n_tokens + _input_tokens); + input_tokens.add(_input_tokens); + new_tokens.add(n_tokens); } } diff --git a/load_tests/tgi.js b/load_tests/tgi.js index 93a0e278137..1db4ab6fc43 100644 --- a/load_tests/tgi.js +++ b/load_tests/tgi.js @@ -1,13 +1,13 @@ import { get_options, run } from "./common.js"; -const reference_latency_ms = 30; +const reference_latency_ms = 70; const host = __ENV.HOST || '127.0.0.1:8000'; const max_new_tokens = 50; function generate_payload(gpt){ const input = gpt["conversations"][0]["value"]; - return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "temperature" : 0.5}} + return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}} } export const options = get_options(reference_latency_ms); diff --git a/proto/generate.proto b/proto/generate.proto index c873e6615ef..19ec059be11 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package generate.v1; +package generate.v2; service TextGenerationService { /// Model Info @@ -32,6 +32,7 @@ message InfoResponse { string dtype = 2; string device_type = 3; optional uint32 window_size = 4; + uint32 speculate = 5; } /// Empty request @@ -135,43 +136,27 @@ message GeneratedText { optional uint64 seed = 4; } -message PrefillTokens { - /// Prefill Token IDs +message Tokens { + /// Token IDs repeated uint32 ids = 1; - /// Prefill Logprobs + /// Logprobs repeated float logprobs = 2; - /// Prefill tokens + /// tokens repeated string texts = 3; -} - -message TopTokens { - /// Top Token IDs - repeated uint32 ids = 1; - /// Top Logprobs - repeated float logprobs = 2; - /// Top Token Texts - repeated string texts = 3; - /// If the tokens are special - repeated bool is_special = 6; + /// special + repeated bool is_special = 4; } message Generation { /// Request ID uint64 request_id = 1; /// Prefill tokens (optional) - PrefillTokens prefill_tokens = 2; - /// Token ID - uint32 token_id = 3; - /// Logprob - float token_logprob = 4; - /// Text - string token_text = 5; - /// Is it a special token - bool token_is_special = 6; + Tokens prefill_tokens = 2; + Tokens tokens = 3; /// Complete generated text - optional GeneratedText generated_text = 7; + optional GeneratedText generated_text = 4; /// Top tokens - TopTokens top_tokens = 8; + repeated Tokens top_tokens = 5; } message FilterBatchRequest { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 341e70fd588..1560f19c442 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -1,6 +1,6 @@ /// Single shard Client -use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; -use crate::pb::generate::v1::*; +use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; +use crate::pb::generate::v2::*; use crate::Result; use grpc_metadata::InjectTelemetryContext; use std::cmp::min; diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index f334be21759..c38b931b6f8 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -6,11 +6,11 @@ mod pb; mod sharded_client; pub use client::Client; -pub use pb::generate::v1::HealthResponse; -pub use pb::generate::v1::InfoResponse as ShardInfo; -pub use pb::generate::v1::{ +pub use pb::generate::v2::HealthResponse; +pub use pb::generate::v2::InfoResponse as ShardInfo; +pub use pb::generate::v2::{ Batch, CachedBatch, 
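
The protocol bump from `generate.v1` to `generate.v2` above replaces the single-token fields of `Generation`, together with the separate `PrefillTokens` and `TopTokens` messages, with one `Tokens` message holding parallel `ids`, `logprobs`, `texts` and `is_special` arrays, so a single decode step can carry several tokens when speculation accepts more than one. A hedged sketch of how a gRPC consumer could zip those arrays back into per-token records, mirroring what the Rust router now does in `send_responses`, is shown below; the generated module name `generate_pb2` and the client-side record type are assumptions, not part of this diff.

```python
# Sketch only: assumes protobuf bindings generated from proto/generate.proto
# are importable as `generate_pb2`; field names match the v2 schema above.
from dataclasses import dataclass
from typing import List


@dataclass
class DecodedToken:
    id: int
    logprob: float
    text: str
    special: bool


def unpack_tokens(tokens_msg) -> List[DecodedToken]:
    """Zip the parallel arrays of a generate.v2 `Tokens` message into records."""
    return [
        DecodedToken(id=i, logprob=lp, text=t, special=sp)
        for i, lp, t, sp in zip(
            tokens_msg.ids, tokens_msg.logprobs, tokens_msg.texts, tokens_msg.is_special
        )
    ]


def handle_generation(generation) -> List[DecodedToken]:
    """A v2 `Generation` carries one `Tokens` message with one or more entries."""
    return unpack_tokens(generation.tokens)
```
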
FinishReason, GeneratedText, Generation, NextTokenChooserParameters, - PrefillTokens, Request, StoppingCriteriaParameters, + Request, StoppingCriteriaParameters, Tokens, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/src/infer.rs b/router/src/infer.rs index aa6dc66460d..2e199ce2210 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -9,7 +9,7 @@ use std::sync::{ Arc, }; use text_generation_client::{ - Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, + Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens, }; use thiserror::Error; use tokio::sync::mpsc::error::SendError; @@ -50,10 +50,11 @@ impl Infer { max_concurrent_requests: usize, requires_padding: bool, window_size: Option, + speculate: u32, generation_health: Arc, ) -> Self { // Infer shared state - let queue = Queue::new(requires_padding, 16, window_size); + let queue = Queue::new(requires_padding, 16, window_size, speculate); let shared = Arc::new(Shared { batching_task: Notify::new(), }); @@ -523,50 +524,63 @@ fn send_responses( } // Create last Token - let token = Token { - id: generation.token_id, - text: generation.token_text, - logprob: generation.token_logprob, - special: generation.token_is_special, - }; - - // generation.top_tokens - - let mut top_tokens = Vec::new(); - if let Some(top_tokens_) = generation.top_tokens { - top_tokens.extend( + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs.into_iter()) + .zip(tokens_.texts.into_iter()) + .zip(tokens_.is_special.into_iter()) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { top_tokens_ .ids - .into_iter() - .zip(top_tokens_.logprobs.into_iter()) - .zip(top_tokens_.texts.into_iter()) - .zip(top_tokens_.is_special.into_iter()) - .map(|(((id, logprob), text), special)| Token { + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { id, - text, + text: text.to_string(), logprob, special, - }), - ) + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: generated_text.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } } - if let Some(generated_text) = generation.generated_text { - // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { - token, - top_tokens, - generated_text, - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }))?; - } else { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; - } Ok(stopped) } @@ -591,7 +605,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { #[derive(Debug)] pub(crate) enum InferStreamResponse { 
// Optional first message - Prefill(PrefillTokens), + Prefill(Tokens), // Intermediate messages Intermediate { token: Token, diff --git a/router/src/queue.rs b/router/src/queue.rs index bbb8db0e4ce..106cacc4b4c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -34,7 +34,12 @@ pub(crate) struct Queue { } impl Queue { - pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option) -> Self { + pub(crate) fn new( + requires_padding: bool, + block_size: u32, + window_size: Option, + speculate: u32, + ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); @@ -43,6 +48,7 @@ impl Queue { requires_padding, block_size, window_size, + speculate, queue_receiver, )); @@ -91,9 +97,10 @@ async fn queue_task( requires_padding: bool, block_size: u32, window_size: Option, + speculate: u32, mut receiver: mpsc::UnboundedReceiver, ) { - let mut state = State::new(requires_padding, block_size, window_size); + let mut state = State::new(requires_padding, block_size, window_size, speculate); while let Some(cmd) = receiver.recv().await { match cmd { @@ -136,10 +143,18 @@ struct State { /// Sliding window window_size: Option, + + /// Speculation amount + speculate: u32, } impl State { - fn new(requires_padding: bool, block_size: u32, window_size: Option) -> Self { + fn new( + requires_padding: bool, + block_size: u32, + window_size: Option, + speculate: u32, + ) -> Self { Self { entries: VecDeque::with_capacity(128), next_id: 0, @@ -147,6 +162,7 @@ impl State { requires_padding, block_size, window_size, + speculate, } } @@ -229,7 +245,7 @@ impl State { } if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens) > token_budget + || (prefill_tokens + decode_tokens + self.speculate) > token_budget { // Entry is over budget // Add it back to the front @@ -359,7 +375,7 @@ mod tests { #[test] fn test_append() { - let mut state = State::new(false, 1, None); + let mut state = State::new(false, 1, None, 0); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -375,7 +391,7 @@ mod tests { #[test] fn test_next_batch_empty() { - let mut state = State::new(false, 1, None); + let mut state = State::new(false, 1, None, 0); assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none()); @@ -383,7 +399,7 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None); + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -415,7 +431,7 @@ mod tests { #[test] fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None); + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -448,14 +464,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None); + let queue = Queue::new(false, 1, None, 0); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None); + let queue = Queue::new(false, 1, None, 0); assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); @@ -463,7 +479,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None); + let queue = Queue::new(false, 1, None, 0); let (entry1, 
_guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -496,7 +512,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -519,9 +535,28 @@ mod tests { assert_eq!(batch.size, 2); } + #[tokio::test] + async fn test_queue_next_batch_token_speculate() { + let queue = Queue::new(false, 1, None, 2); + let (entry1, _guard1) = default_entry(); + let (entry2, _guard2) = default_entry(); + queue.append(entry1); + queue.append(entry2); + + // Budget of 1 is not enough + assert!(queue.next_batch(None, 1, 1).await.is_none()); + + let (entries, batch, _) = queue.next_batch(None, 6, 6).await.unwrap(); + assert_eq!(entries.len(), 2); + assert!(entries.contains_key(&0)); + assert!(entries.contains_key(&1)); + assert_eq!(batch.id, 0); + assert_eq!(batch.size, 2); + } + #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None); + let queue = Queue::new(false, 1, None, 0); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/server.rs b/router/src/server.rs index f254afd8aa1..5f41fd5efb2 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -596,6 +596,7 @@ pub async fn run( max_concurrent_requests, shard_info.requires_padding, shard_info.window_size, + shard_info.speculate, generation_health, ); diff --git a/server/Makefile-vllm b/server/Makefile-vllm index ddb648ead34..c9c1d52047f 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,22 +1,25 @@ -build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git -build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc -build-vllm-cuda: BRANCH=main -build-vllm-cuda: build-vllm +vllm-cuda: + # Clone vllm + pip install -U ninja packaging --no-cache-dir + git clone https://github.com/vllm-project/vllm.git vllm + +build-vllm-cuda: vllm-cuda + cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc + cd vllm && python setup.py build -build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git -build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae -build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin -build-vllm-rocm: build-vllm +install-vllm-cuda: build-vllm-cuda + pip uninstall vllm -y || true + cd vllm && python setup.py install -vllm: +vllm-rocm: # Clone vllm pip install -U ninja packaging --no-cache-dir - git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm + git clone https://github.com/fxmarty/vllm-public.git vllm -build-vllm: vllm - cd vllm && git fetch && git checkout $(VLLM_COMMIT) +build-vllm-rocm: vllm-rocm + cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae cd vllm && python setup.py build -install-vllm: build-vllm +install-vllm-rocm: build-vllm-rocm pip uninstall vllm -y || true cd vllm && python setup.py install diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 71013cb60ca..1990ef8b2b0 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -133,8 +133,8 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert 
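
In the `queue.rs` hunks above, the batching budget now reserves headroom for speculation: an entry only joins the batch being built if the cumulative `prefill_tokens + decode_tokens + speculate` stays within `token_budget`, which is what the new `test_queue_next_batch_token_speculate` test exercises. The snippet below restates that arithmetic with hypothetical per-entry token counts; the exact accounting inside `default_entry` is not visible in this hunk.

```python
def fits_in_budget(prefill_tokens: int, decode_tokens: int, speculate: int, token_budget: int) -> bool:
    """Mirror of the router's check: cumulative prefill and decode tokens for the
    batch being built, plus the per-step speculation allowance, must stay within
    the token budget."""
    return prefill_tokens + decode_tokens + speculate <= token_budget


# Hypothetical accounting: a batch whose queued entries sum to 2 prefill and
# 2 decode tokens fits a budget of 6 when speculate = 2, but not a budget of 5.
assert fits_in_budget(prefill_tokens=2, decode_tokens=2, speculate=2, token_budget=6)
assert not fits_in_budget(prefill_tokens=2, decode_tokens=2, speculate=2, token_budget=5)
```
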
all([generation.token_id.item() == 10264 for generation in generations]) - assert all([generation.token_text == "Test" for generation in generations]) + assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids]) + assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts]) assert generations[0].request_id == 0 diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 0f9dab2ceba..f105ce6f5aa 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -129,8 +129,8 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 13 for generation in generations]) - assert all([generation.token_text == "." for generation in generations]) + assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids]) + assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts]) assert generations[0].request_id == 0 diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 299340f87d9..d553067e1ac 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -151,8 +151,8 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 259 for generation in generations]) - assert all([generation.token_text == " " for generation in generations]) + assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids]) + assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts]) assert generations[0].request_id == 0 diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 3abe86afd23..cb151173d7e 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -32,6 +32,7 @@ def serve( revision: Optional[str] = None, sharded: bool = False, quantize: Optional[Quantization] = None, + speculate: Optional[int] = None, dtype: Optional[Dtype] = None, trust_remote_code: bool = False, uds_path: Path = "/tmp/text-generation-server", @@ -81,7 +82,7 @@ def serve( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) server.serve( - model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path + model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path ) @@ -116,7 +117,7 @@ def download_weights( logger.info("Files are already present on the host. 
" "Skipping download.") return # Local files not found - except (utils.LocalEntryNotFoundError, FileNotFoundError): + except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError): pass is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv( @@ -137,6 +138,29 @@ def download_weights( except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass + try: + import json + medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt") + if auto_convert: + medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors") + if not medusa_sf.exists(): + utils.convert_files([Path(medusa_head)], [medusa_sf], []) + medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json") + with open(medusa_config, "r") as f: + config = json.load(f) + + model_id = config["base_model_name_or_path"] + revision = "main" + try: + utils.weight_files(model_id, revision, extension) + logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.") + return + # Local files not found + except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError): + pass + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + # Try to download weights from the hub try: filenames = utils.weight_hub_files(model_id, revision, extension) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ab3b25b7c89..27e3897d01c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -6,6 +6,7 @@ from transformers.models.auto import modeling_auto from typing import Optional +from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM @@ -77,12 +78,12 @@ if MISTRAL: __all__.append(FlashMistral) - def get_model( model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str], + speculate: Optional[int], dtype: Optional[str], trust_remote_code: bool, ) -> Model: @@ -97,6 +98,11 @@ def get_model( else: raise RuntimeError(f"Unknown dtype {dtype}") + if speculate is not None: + set_speculate(speculate) + else: + set_speculate(0) + if "facebook/galactica" in model_id: return GalacticaSharded( model_id, @@ -131,6 +137,33 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) + + use_medusa = None + if "medusa_num_heads" in config_dict: + use_medusa = model_id + medusa_config = config_dict + model_id = config_dict["base_model_name_or_path"] + revision = "main" + speculate_medusa = config_dict["medusa_num_heads"] + if speculate is not None: + if speculate > speculate_medusa: + raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match") + else: + set_speculate(speculate) + else: + set_speculate(speculate_medusa) + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + method = "medusa" + else: + method = "n-gram" + + speculate = get_speculate() + if speculate > 0: + logger.info(f"Using speculation {method} with {speculate} input ids.") + model_type = config_dict["model_type"] if model_type == "gpt_bigcode": @@ -206,6 
+239,7 @@ def get_model( quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code, + use_medusa=use_medusa ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 8056a8ecc18..c571a0220a4 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -10,10 +10,9 @@ from text_generation_server.models import Model from text_generation_server.models.types import ( Batch, - PrefillTokens, + Tokens, Generation, GeneratedText, - TopTokens, ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -676,8 +675,8 @@ def generate_token( clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - prefill_tokens = PrefillTokens( - prefill_token_ids, prefill_logprobs, prefill_texts + prefill_tokens = Tokens( + prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[] ) else: prefill_tokens = None @@ -691,7 +690,7 @@ def generate_token( special_toptokens = [ token_id in self.all_special_ids for token_id in top_token_ids ] - top_tokens = TopTokens( + top_tokens = Tokens( top_token_ids, top_token_logprobs, toptoken_texts, @@ -703,10 +702,12 @@ def generate_token( generation = Generation( request.id, prefill_tokens, - next_token_id_squeezed, - next_token_logprob, - next_token_text, - next_token_id_squeezed.item() in self.all_special_ids, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), generated_text, top_tokens, ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index f1a4854f9ea..79344ea1432 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -11,13 +11,13 @@ from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Union, Dict -from text_generation_server.models import Model +from text_generation_server.models import Model +from text_generation_server.utils.speculate import get_speculate from text_generation_server.models.types import ( Batch, - PrefillTokens, + Tokens, Generation, GeneratedText, - TopTokens, ) from text_generation_server.models.cache_manager import ( get_cache_manager, @@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch): # Decoder values input_ids: torch.Tensor position_ids: torch.Tensor + speculative_ids: torch.Tensor # Flash Attention values @@ -120,6 +121,7 @@ def from_pb( )["input_ids"] position_ids = [] + speculative_ids = [] cu_seqlen_prefill = [0] needed_blocks_slots = [] start_slots = [] @@ -163,6 +165,8 @@ def from_pb( input_length = len(tokenized_input) input_lengths.append(input_length) + + prefix_offsets.append(input_length - 5) read_offsets.append(input_length) @@ -186,7 +190,8 @@ def from_pb( # Paged attention # Remove one as the first token des not have a past - total_tokens = input_length + max_new_tokens - 1 + speculative_length = get_speculate() + total_tokens = input_length + max_new_tokens - 1 + speculative_length needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) blocks += needed_blocks needed_blocks_slots.append((needed_blocks, total_tokens)) @@ -224,7 +229,7 @@ def from_pb( cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, 
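
The model-loading code above imports `get_speculate` and `set_speculate` from `text_generation_server.utils.speculate`, a module whose body is not included in this diff. From its usage, `set_speculate` is called once while resolving the model in `get_model`, and `get_speculate` is later read by the batch code, so it behaves like a process-wide setting; under that assumption, a plausible minimal sketch looks like this.

```python
# Sketch of text_generation_server/utils/speculate.py, which this diff imports
# but does not show; assumed to be a simple module-level setting.

SPECULATE: int = 0


def set_speculate(speculate: int) -> None:
    """Record how many input_ids to speculate on for this process."""
    global SPECULATE
    SPECULATE = speculate


def get_speculate() -> int:
    """Read the speculation amount chosen at model-loading time."""
    return SPECULATE
```
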
input_length) max_blocks = max(max_blocks, needed_blocks) - max_length = max(max_length, input_length + max_new_tokens) + max_length = max(max_length, input_length + max_new_tokens + speculative_length) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device @@ -255,7 +260,6 @@ def from_pb( cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill, device=device, dtype=torch.int32 ) - position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) @@ -309,6 +313,7 @@ def from_pb( top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, + speculative_ids=None, ) @tracer.start_as_current_span("filter") @@ -419,6 +424,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] + speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -454,6 +460,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, + speculative_ids=speculative_ids, ) @classmethod @@ -473,6 +480,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch total_batch_size += len(b) total_slots += len(b.slots) blocks += b.blocks + speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 max_blocks = max(max_blocks, b.max_blocks) max_seqlen = max(max_seqlen, b.max_seqlen) max_length = max( @@ -480,6 +488,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch max( input_length + stopping_criteria.max_new_tokens + + speculative_length - stopping_criteria.current_tokens for input_length, stopping_criteria in zip( b.input_lengths, b.stopping_criterias @@ -577,6 +586,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch device=batches[0].next_token_chooser.device, ) + speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None + # Needed to avoid dropping blocks when the batches will go out of scope for b in batches: b.block_tables = None @@ -611,6 +622,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, + speculative_ids=speculative_ids ) def __del__(self): @@ -714,16 +726,55 @@ def warmup(self, batch: FlashCausalLMBatch): def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward + if batch.speculative_ids is not None: + input_ids=batch.input_ids + position_ids=batch.position_ids + cu_seqlen_prefill=batch.cu_seqlen_prefill + kv_cache=get_cache_manager().kv_cache + block_tables=batch.block_tables_tensor + slots=batch.slots[batch.slot_indices] + input_lengths=batch.input_lengths_tensor + max_s=batch.max_seqlen + lm_head_indices=batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + 
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + + # Add Copy the block tables for all members + block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous() + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids=batch.input_ids + position_ids=batch.position_ids + cu_seqlen_prefill=batch.cu_seqlen_prefill + kv_cache=get_cache_manager().kv_cache + block_tables=batch.block_tables_tensor + slots=batch.slots[batch.slot_indices] + input_lengths=batch.input_lengths_tensor + max_s=batch.max_seqlen + lm_head_indices=batch.prefill_head_indices + return self.model.forward( - input_ids=batch.input_ids, - position_ids=batch.position_ids, - cu_seqlen_prefill=batch.cu_seqlen_prefill, - kv_cache=get_cache_manager().kv_cache, - block_tables=batch.block_tables_tensor, - slots=batch.slots[batch.slot_indices], - input_lengths=batch.input_lengths_tensor, - max_s=batch.max_seqlen, - lm_head_indices=batch.prefill_head_indices, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + lm_head_indices=lm_head_indices, ) @tracer.start_as_current_span("generate_token") @@ -752,21 +803,32 @@ def generate_token( del batch raise e + if isinstance(out, tuple): + out, speculative_logits = out + else: + speculative_logits = None + + if prefill: next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) + if speculative_logits is not None: + speculative_logits = ( + speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits + ) else: next_token_logits = out - next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits + next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits ) batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs ) + speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1] if prefill: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs @@ -792,6 +854,7 @@ def generate_token( iterator = zip( batch.input_lengths, batch.all_input_ids, + accepted_ids ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second @@ -799,9 +862,11 @@ def generate_token( # It is faster if we delay this sync for the maximum amount of time # For each member of the batch + index = 0 for i, ( input_length, all_input_ids, + n_accepted_ids ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length @@ -830,15 +895,18 @@ def generate_token( start_index + 1 : start_index + out_length ] - batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] + for j in range(n_accepted_ids): + batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index] + index += 1 cumulative_length += input_length - # Set values in 
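
The `forward` changes above flatten the batch for verification: when `speculative_ids` is present, each of the B sequences is expanded into `speculative_length + 1` rows (the current token plus its candidates), and `position_ids`, `slots`, `input_lengths` and `block_tables` are expanded to match, so a single forward pass scores every candidate. The standalone snippet below replays that reshaping on dummy tensors to make the shapes explicit; it copies the arithmetic from the diff rather than calling into the server.

```python
import torch

B, S = 2, 3                                   # batch size, speculative tokens per sequence
input_ids = torch.tensor([10, 20])            # current token of each sequence, shape (B,)
speculative_ids = torch.arange(B * S).reshape(B, S)   # candidate ids, shape (B, S)
position_ids = torch.tensor([7, 4])           # next position of each sequence
slots = torch.tensor([70, 40], dtype=torch.int32)     # KV-cache slot of the next token

new_length = S + 1
# Interleave each current token with its candidates: shape (B * new_length,)
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length).unsqueeze(0)
# Consecutive positions / slots for the expanded rows of every sequence
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
new_slots = (slots.unsqueeze(-1).expand(B, new_length) + arange.to(torch.int32)).view(-1)

print(new_input_ids)     # tensor([10,  0,  1,  2, 20,  3,  4,  5])
print(new_position_ids)  # tensor([ 7,  8,  9, 10,  4,  5,  6,  7])
print(new_slots)         # tensor([70, 71, 72, 73, 40, 41, 42, 43], dtype=torch.int32)
```
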
batch - batch.input_ids = next_input_ids - batch.position_ids = next_position_ids + 1 - batch.input_lengths_tensor += 1 - batch.slot_indices += 1 + + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.input_lengths_tensor += accepted_ids + batch.slot_indices += accepted_ids if prefill and prefill_logprobs: # Get prefill logprobs @@ -851,7 +919,7 @@ def generate_token( # GPU <-> CPU sync next_token_logprobs = next_token_logprobs.tolist() - next_token_ids = batch.input_ids.tolist() + next_token_ids = next_input_ids.tolist() # Zipped iterator iterator = zip( @@ -864,13 +932,13 @@ def generate_token( batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, - next_token_ids, - next_token_logprobs, + accepted_ids, batch_top_token_ids, batch_top_token_logprobs, ) # For each member of the batch + index = 0 for i, ( request, input_length, @@ -881,29 +949,43 @@ def generate_token( do_sample, seed, top_n_tokens, - next_token_id, - next_token_logprob, + n_accepted_ids, top_token_ids, top_token_logprobs, ) in enumerate(iterator): # Append next token to all tokens - all_input_ids.append(next_token_id) + next_token_texts = [] + left = 0 + before = stopping_criteria.current_tokens + + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids, + prefix_offset, + read_offset, + ) + next_token_texts.append(next_token_text) - # Generated token - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped - if not stop: - stopped = False + _next_token_ids = next_token_ids[index: index+n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left] + index += n_accepted_ids # Shard generations # All generations will be appended in the rust sharded client @@ -943,8 +1025,9 @@ def generate_token( clean_up_tokenization_spaces=False, skip_special_tokens=False, ) - prefill_tokens = PrefillTokens( - prefill_token_ids, request_prefill_logprobs, prefill_texts + + prefill_tokens = Tokens( + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = [] ) else: prefill_tokens = None @@ -958,7 +1041,7 @@ def generate_token( special_toptokens = [ token_id in self.all_special_ids for token_id in top_token_ids ] - top_tokens = TopTokens( + top_tokens = Tokens( top_token_ids, top_token_logprobs, toptoken_texts, @@ -970,10 +1053,12 @@ def generate_token( generation = Generation( request.id, prefill_tokens, - next_token_id, - next_token_logprob, - next_token_text, - next_token_id in self.all_special_ids, + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), generated_text, top_tokens, ) @@ -981,7 +1066,9 @@ def generate_token( generations.append(generation) # Update values - batch.input_lengths[i] = input_length + 1 + batch.input_lengths[i] = input_length + 
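
Since `next_input_ids` is now a flat tensor holding a variable number of accepted tokens per sequence, `accepted_ids.cumsum(dim=-1) - 1` picks out, for each sequence, the flat index of its last accepted token, which becomes the next step's `input_ids`, while `position_ids`, `input_lengths` and `slot_indices` each advance by the per-sequence acceptance count instead of by one. A small illustration of that indexing:

```python
import torch

# Flat tensor of accepted tokens for a batch of 3 sequences that accepted
# 2, 1 and 3 tokens respectively in this decode step.
next_input_ids = torch.tensor([11, 12, 21, 31, 32, 33])
accepted_ids = torch.tensor([2, 1, 3])

# Index of the last accepted token of each sequence within the flat tensor.
last_token_index = accepted_ids.cumsum(dim=-1) - 1    # tensor([1, 2, 5])
next_step_input_ids = next_input_ids[last_token_index]
print(next_step_input_ids)                            # tensor([12, 21, 33])

# Each sequence's position / length advances by how many tokens it accepted.
position_ids = torch.tensor([5, 9, 3])
print(position_ids + accepted_ids)                    # tensor([ 7, 10,  6])
```
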
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index d2ed0b15a75..3a84b1b6a66 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -28,6 +28,7 @@ def __init__(
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -66,6 +67,18 @@ def __init__(
             weights._set_gptq_params(model_id)

         model = FlashLlamaForCausalLM(config, weights)
+        if use_medusa:
+            from text_generation_server.utils.medusa import MedusaModel
+            from huggingface_hub import hf_hub_download
+            import json
+            medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
+            medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
+            weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(config, weights, lm_head)

         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
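One caveat in the Medusa branch above: it downloads `medusa_lm_head.pt` but then builds `Weights` from a `.safetensors` path that is only derived by renaming, so that file is presumably expected to exist already in the cache or repository. If it does not, a one-off conversion along these lines could produce it (a hedged sketch, not part of this diff; `convert_medusa_checkpoint` is a hypothetical helper and the checkpoint is assumed to be a flat name-to-tensor state dict):

    import torch
    from safetensors.torch import save_file

    def convert_medusa_checkpoint(pt_path: str) -> str:
        # Rewrite the PyTorch checkpoint as safetensors next to the original,
        # mirroring the `medusa_head[:-len(".pt")] + ".safetensors"` convention above.
        sf_path = pt_path[:-len(".pt")] + ".safetensors"
        state_dict = torch.load(pt_path, map_location="cpu")
        # safetensors requires contiguous tensors
        state_dict = {name: tensor.contiguous() for name, tensor in state_dict.items()}
        save_file(state_dict, sf_path)
        return sf_path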
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 919e4625a9d..e103d9fc2e1 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -21,6 +21,7 @@
     FlashMistralForCausalLM,
     MistralConfig,
 )
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -132,7 +133,8 @@ def from_pb(
         # Paged attention
         # Remove one as the first token does not have a past
-        total_tokens = input_length + max_new_tokens - 1
+        speculative_length = get_speculate()
+        total_tokens = input_length + max_new_tokens - 1 + speculative_length

         # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
         needed_blocks = min(
@@ -183,7 +185,7 @@ def from_pb(
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -272,6 +274,7 @@ def from_pb(
             blocks=blocks,
             max_blocks=max_blocks,
             prefill_cache_indices=prefill_cache_indices,
+            speculative_ids=None
         )

@@ -340,17 +343,55 @@ def batch_type(self) -> Type[FlashMistralBatch]:
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids=batch.input_ids
+            position_ids=batch.position_ids
+            cu_seqlen_prefill=batch.cu_seqlen_prefill
+            kv_cache=get_cache_manager().kv_cache
+            block_tables=batch.block_tables_tensor
+            slots=batch.slots[batch.slot_indices]
+            input_lengths=batch.input_lengths_tensor
+            max_s=batch.max_seqlen
+            lm_head_indices=batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids=batch.input_ids
+            position_ids=batch.position_ids
+            cu_seqlen_prefill=batch.cu_seqlen_prefill
+            kv_cache=get_cache_manager().kv_cache
+            block_tables=batch.block_tables_tensor
+            slots=batch.slots[batch.slot_indices]
+            input_lengths=batch.input_lengths_tensor
+            max_s=batch.max_seqlen
+            lm_head_indices=batch.prefill_head_indices

         logits = self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=batch.prefill_head_indices,
+            lm_head_indices=lm_head_indices,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index dcad1fa9851..2f4bb139ae3 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -20,7 +20,7 @@
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
 )
@@ -791,8 +791,8 @@ def generate_token(
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -802,10 +802,12 @@ def generate_token(
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 17d2ea9b433..8552960dd28 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -6,6 +6,7 @@
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)
@@ -22,6 +23,7 @@ def __init__(
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
+        speculate: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
@@ -33,6 +35,10 @@ def __init__(
         self.world_size = world_size
         self.sliding_window = sliding_window

+        if speculate is None:
+            speculate = get_speculate()
+        self.speculate = speculate
+
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
@@ -50,6 +56,7 @@ def info(self) -> InfoResponse:
             dtype=str(self.dtype),
             device_type=self.device.type,
             window_size=self.sliding_window,
+            speculate=self.speculate
         )

     @property
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index d4d3cd19f97..279b5505f07 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -11,8 +11,7 @@
     GeneratedText,
     Batch,
     Generation,
-    PrefillTokens,
-    TopTokens,
+    Tokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -733,10 +732,11 @@ def generate_token(
             # Prefill
             if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = Tokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
                     [self.tokenizer.bos_token],
+                    [False]
                 )
             else:
                 prefill_tokens = None
@@ -750,7 +750,7 @@ def generate_token(
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -762,10 +762,12 @@ def generate_token(
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 0e27680d180..87c03d6362c 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -58,33 +58,15 @@ def to_pb(self) -> generate_pb2.GeneratedText:

 @dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class TopTokens:
+class Tokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
     is_special: List[bool]

-    def to_pb(self) -> generate_pb2.TopTokens:
-        return generate_pb2.TopTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts,
-            is_special=self.is_special,
+    def to_pb(self) -> generate_pb2.Tokens:
+        return generate_pb2.Tokens(
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
         )

     def __len__(self):
@@ -94,14 +76,11 @@ def __len__(self):

 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+    prefill_tokens: Optional[Tokens]
+    tokens: Tokens
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[TopTokens]
+    top_tokens: Optional[List[Tokens]]

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -109,10 +88,7 @@ def to_pb(self) -> generate_pb2.Generation:
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
             else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
+            tokens=self.tokens.to_pb(),
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
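With `PrefillTokens` and `TopTokens` folded into the single `Tokens` dataclass above, an ordinary decoding step is just a one-element `Tokens`, while a speculative step carries every accepted token; roughly (illustrative values, assuming the package is importable):

    from text_generation_server.models.types import Tokens

    # One accepted token per step (the usual case)
    single = Tokens([42], [-0.12], [" the"], [False])

    # Three tokens accepted in one speculative step
    multi = Tokens(
        [42, 57, 13],
        [-0.12, -0.40, -1.05],
        [" the", " quick", " fox"],
        [False, False, False],
    )

    assert len(single) == 1 and len(multi) == 3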
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index fa831682928..ebe066e32e8 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -132,6 +132,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
@@ -141,6 +142,7 @@ async def serve_inner(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        speculate: Optional[int] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
@@ -157,7 +159,7 @@ async def serve_inner(

         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -205,5 +207,5 @@ async def serve_inner(
         await server.stop(0)

     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
     )
diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py
new file mode 100644
index 00000000000..029de122fde
--- /dev/null
+++ b/server/text_generation_server/utils/medusa.py
@@ -0,0 +1,51 @@
+import torch
+from dataclasses import dataclass
+from text_generation_server.utils.layers import TensorParallelHead, FastLinear
+
+@dataclass
+class Output:
+    logits: torch.FloatTensor = None
+    speculative_logits: torch.FloatTensor = None
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(
+        self,
+        config,
+        weights,
+        lm_head
+    ):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
+        )
+        self.lm_head = lm_head
+
+    def forward(self, x):
+        logits = self.lm_head(x)
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return logits, speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
+        n = len(self.blocks)
+        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
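For intuition about the shapes `MedusaModel` produces above: for hidden states of shape `(tokens, hidden)`, `lm_head` yields `(tokens, vocab)` logits and the stacked heads yield `(tokens, medusa_num_heads, vocab)` speculative logits; downstream, `Greedy()(speculative_scores)` amounts to an argmax over the vocabulary, giving one candidate token per head. A toy sketch with plain `nn.Linear` stand-ins for the repository's `FastLinear` (assumed sizes, random weights):

    import torch

    hidden, vocab, num_heads, tokens = 16, 32, 3, 5

    lm_head = torch.nn.Linear(hidden, vocab, bias=False)
    heads = torch.nn.ModuleList(torch.nn.Linear(hidden, vocab, bias=False) for _ in range(num_heads))

    x = torch.randn(tokens, hidden)
    logits = lm_head(x)                                                   # (5, 32)
    speculative_logits = torch.stack([head(x) for head in heads], dim=1)  # (5, 3, 32)

    speculative_ids = speculative_logits.argmax(dim=-1)                   # (5, 3): one candidate per head
    print(logits.shape, speculative_logits.shape, speculative_ids.shape)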
diff --git a/server/text_generation_server/utils/speculate.py b/server/text_generation_server/utils/speculate.py
new file mode 100644
index 00000000000..38a91972056
--- /dev/null
+++ b/server/text_generation_server/utils/speculate.py
@@ -0,0 +1,12 @@
+
+SPECULATE = None
+
+def get_speculate() -> int:
+    global SPECULATE
+    return SPECULATE
+
+def set_speculate(speculate: int):
+    global SPECULATE
+    SPECULATE = speculate
+
+
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 0ff07417145..a34c5afc184 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -16,7 +16,6 @@
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor

-
 class NextTokenChooser:
     def __init__(
         self,
@@ -146,6 +145,20 @@ def from_pb(
             pb.ignore_eos_token,
         )

+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
+    # Very trivial approach, find first match in the string.
+    # This is much less refined than actual n-gram but seems to work
+    # relatively OK in grounded mode and is by far much faster with
+    # much less worst case complexity as everything happens on device.
+    B = accepted_ids.shape[0]
+    device = input_ids.device
+    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
+    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
+
+    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
+    return speculative_ids

 class HeterogeneousNextTokenChooser:
     def __init__(
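To see what `create_n_gram_speculation` above proposes: it finds the first earlier occurrence of the token that was just generated and speculates that the same continuation will follow. A toy run for a single sequence (illustrative values, assuming the module is importable; `accepted_ids` is all ones as in regular decoding):

    import torch
    from text_generation_server.utils.tokens import create_n_gram_speculation

    # Prompt so far: 5 6 7 5 8, and the model has just produced token 7 again.
    input_ids = torch.tensor([[5, 6, 7, 5, 8]])
    next_ids = torch.tensor([7])
    accepted_ids = torch.ones(1, dtype=torch.long)

    speculative_ids = create_n_gram_speculation(
        input_ids, next_ids, accepted_ids, speculate=2, verbose=False
    )
    print(speculative_ids)  # tensor([[5, 8]]) -- the two tokens that followed 7 the first time around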
@@ -215,20 +228,79 @@ def __init__(
         self.dtype = dtype
         self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
-        if self.watermark_processor is not None:
-            scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor is not None:
-            scores = self.repetition_processor(input_ids, scores)
-
-        for warper in self.warpers:
-            scores = warper(input_ids, scores)
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
+        if speculated_ids is not None:
+            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            scores = scores.view(B, S, -1)
+        else:
+            B = scores.shape[0]
+            S = 1
+            scores = scores.view(B, S, -1)
+
+        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+        for j in range(S):
+            _scores = scores[:, j]
+            if self.watermark_processor is not None:
+                _scores = self.watermark_processor(input_ids, _scores)
+            if self.repetition_processor is not None:
+                _scores = self.repetition_processor(input_ids, _scores)
+
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
+
+
+            _next_ids = self.choice(_scores)
+            scores[:, j] = _scores
+            next_ids[:, j] = _next_ids
+        next_ids = next_ids.view(B * S)
+        scores = scores.view(B * S, -1)
+
+        if speculated_ids is not None:
+            accepted_ids = []
+            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            indices = []
+            for i in range(B):
+                _next_ids = next_ids[i * S: (i + 1) * S]
+                _speculated_ids = speculated_ids[i]
+                validate_speculative = _next_ids[:-1] == _speculated_ids
+                index = i * S
+                accepted = 1
+                # First is always valid
+                indices.append(index)
+                for valid in validate_speculative.tolist():
+                    if valid:
+                        index += 1
+                        accepted += 1
+                        indices.append(index)
+                    else:
+                        break
+                accepted_ids.append(accepted)
+
+            accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
+            next_ids = next_ids[indices]
+            scores = scores[indices]
+            indices = torch.arange(B, device=input_ids.device) * S
+            if speculative_scores is not None:
+                speculative_scores = speculative_scores[indices + accepted_ids - 1]
+        else:
+            accepted_ids = torch.ones_like(next_ids)

-        next_ids = self.choice(scores)
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
-        return next_ids, next_logprobs, logprobs
+
+        if speculate > 0:
+            if speculative_scores is not None:
+                # Medusa provided some scores
+                speculative_ids = Greedy()(speculative_scores)
+            else:
+                # n-gram
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None: