Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama/GPTNeoX: add RoPE scaling #24653

Merged
merged 9 commits into from
Jul 13, 2023
Merged

Llama/GPTNeoX: add RoPE scaling #24653

merged 9 commits into from
Jul 13, 2023

Conversation

gante
Copy link
Member

@gante gante commented Jul 4, 2023

What does this PR do?

This is an experimental PR for discussion, so we can decide whether to add this pattern.

Context

In the past week, there have been several developments about scaling RoPE (Rotary Position Embeddings, i.e. Llama's position embeddings) so as to be able to extrapolate beyond 2048 tokens. Without any scaling and/or finetuning, the perplexity quickly explodes when we go beyond 2048 tokens. Here's the sequence of RoPE scaling improvements, announced mostly on Reddit:

  1. Linear scaling -- Simply divide the position index by a scaling factor. Needs fine-tuning to observe the best results. Discussed in this lmsys blog post. Credits to the reddit user /u/kaiokendev.
  2. NTK-aware scaling -- proposed in this reddit thread. Scaling the RoPE Fourier space linearly is not optimal to evenly distribute information, so this can be seen as a improved linear scaling. Works okay without fine-tuning, but seems to benefit from it. Credits to the reddit user /u/bloc97. EDIT: following the comments in this thread, this technique will not be added!
  3. Dynamic NTK scaling -- proposed in this reddit thread. It's a form of NTK-aware scaling that a) works quite well without fine-tuning; b) doesn't degrade the performance if the model is used with short sequences; c) gracefully scales to long sequences, under a fixed parameterization. Credits to the reddit user /u/emozilla.

Changes in the PR

The goal of this PR is to debate whether we want to include RoPE scaling support, with working code as reference. The field is evolving quite fast, so I've added it in a way that we can quicky add to new scaling strategies and keep surfing the wave 🏄 Of course, the implementation itself is up for discussion! (An alternative implementation would be to have separate classes for the scalable RoPEs)

Pros:

  • Flexible implementation that allows adding new scaling methods in minutes;
  • Works quite well with pre-trained models (see example below), through dynamic NTK scaling;
  • Supports strategies that are compatible with fine-tuning (it is unclear whether dynamic NTK works well with fine-tuning, and it seems like Linear scaling is better after fine-tuning)

Cons:

  • rope_scaling is a dictionary input, which is somewhat undesirable;
  • additional if/else branches in RoPE

Example

Consider the following prompt from a paper transcript, containing ~6k tokens:

prompt built from the transcript of https://arxiv.org/abs/2306.15595
prompt = '''
You are given this machine learning research paper, please read it carefully and answer the follow up question.

=== BEGIN ===

2306.15595v2 [cs.CL] 28 Jun 2023

arXiv

EXTENDING CONTEXT WINDOW OF LARGE LAN-
GUAGE MODELS VIA POSITION INTERPOLATION

Shouyuan Chen Sherman Wong Liangjian Chen  Yuandong Tian
Meta Platforms Inc.
{chenshouyuan, shermanwong, cli, yuandong}@meta . com

1 INTRODUCTION

Large language models (LLMs) typically come with a pre-defined context window size. For exam-
ple, inputs to LLaMA models (Touvron et al., 2023) must be fewer than 2048 tokens. This pre-set
context window limit is frequently exceeded in applications such as conducting long conversations,
summarizing long documents, or executing long-term planning. For these applications, LLMs with
longer context windows are preferred. However, training an LLM from scratch with long context
windows requires significant investments. This naturally leads to a question: Can we extend the
context window of an existing pre-trained LLM?

One straightforward approach is to fine-tune an existing pre-trained Transformer with a longer con-
text window. However, empirically, we found that models trained this way adapt to long context
windows very slowly. After training for more than 10000 batches, the effective context window
saw a minimal increase, moving from 2048 to 2560 (Table 4). This suggests that such method is
inefficient for extending to substantially longer context windows.

While certain techniques such as ALiBi (Press et al., 2022) and LeX (Sun et al., 2022) enable length
extrapolation of Transformers, i.e. train on short context windows and inference on longer ones,
many existing pre-trained LLMs, including LLaMA (Touvron et al., 2023), use positional encodings
that have weak extrapolation properties (e.g., RoPE (Su et al., 2021)). Therefore, the applicability
of these techniques for extending the context window sizes of such LLMs remains limited.

In this work, we introduce Position Interpolation to enable context window extensions for certain
existing pre-trained LLMs, including LLaMA. The key idea is, instead of extrapolation, we directly
down-scale the position indices so that the maximum position index matches the previous context
window limit in the pre-training stage. See Figure 1 for an illustration. In other words, to accom-
modate more input tokens, we interpolate the position encodings at neighboring integer positions,
utilizing the fact that position encodings can be applied on non-integer positions, as opposed to
extrapolating outside the trained positions, which may lead to catastrophic values. We verify our
approach theoretically, by showing that the interpolated attention score has a much smaller upper

bound (~ 600x smaller in LLaMA 7B setting) than the extrapolated one, and is thus much more
stable. Therefore, interpolated position encodings are easier for the model to adapt.

Empirically, we found that Position Interpolation is highly effective and efficient, requiring only a
very short period of fine-tuning for the model to fully adapt to greatly extended context windows.
We present experimental results for extending the context window to up to 32768 from the initial
2048 across 7B to 65B LLaMA models using Position Interpolation. Our results show that

1. Position Interpolation can easily enable very long context windows (e.g. 32768), requiring
only fine-tuning for 1000 steps on the Pile (Gao et al., 2020) to achieve a good quality.
The cost of fine-tuning is negligible compared to the pre-training costs. This confirms
our hypothesis that it is relatively easy for the models to adapt to interpolated position
encodings.

2. Position Interpolation generates strong models that can effectively make use of much ex-
tended context window. We show that models extended by Position Interpolation enjoy
significant perplexity gains from greatly extended context windows for text modeling, and
we show that the perplexity reduces graceful with the enlargement of context windows.
We also applied Position Interpolation in a long text summarization task, and demonstrate
competitive performances.

3. Position Interpolation preserves model quality relatively well for tasks within its original
context window sizes. We present a variety of evaluation results for the extended LLaMA
models on the original LLaMA benchmark. Compared with original LLaMA models, the
extended LLLaM A models saw a minor degradation on several standard benchmarks within
a 2048 token limit.

Our results highlight the innate ability of Transformer models to “extrapolate to sequence lengths
longer than the ones encountered during training” as hypothesized in the seminal work of Vaswani
et al. (2017). We reaffirm this hypothesis and suggest that the previously known weakness of ex-
trapolating to longer sequences for language modeling (Press et al., 2022) may be due to direct

extrapolation of positional encodings and it can be largely mitigated by interpolating position en-
codings instead.

Concurrent work. Right before our release, we are informed with a concurrent blogpost (Super-
HOT kaiokendev (2023)) that also interpolates positional encoding in RoPE to extend the context
window from 2K to 8K. Recently, open source community picks it up in Reddit post ! and Github
Issues 2, which shows that fine-tuning with LoRA (Hu et al., 2021) also seems to work well. Our
paper shows a full fine-tuning with up to 65B model work well with Position Interpolation, and we
also give theoretical explanations why interpolation achieves much more stable results than extrap-
olation, by showing that the upper bound of interplated attention score is much lower than that of
extrapolated ones.

2 METHOD

2.1 BACKGROUND: ROTARY POSITION EMBEDDING (ROPE)

Transformer models require explicit positional information to be injected, typically in the form of
positional encodings, to represent the order of inputs. We consider Rotary Position Embedding
(ROPE) (Su et al., 2021), which is the position encoding used in the LLLaMA model (Touvron et al.,
2023). Given a position index m € [0, ¢) and an embedding vector x := [zg, 71,..., 241], Where
d is the dimension of the attention head, RoPE defines a vector-valued complex function f{x, m) as
follows

Using RoPE, the self-attention score
is only dependent on relative position m — 7 through trigonometric functions. Here q and k are the
query and key vector for a specific attention head. At each layer, RoPE is applied on both query and
key embeddings for computing attention scores.

2.2 DIRECT EXTRAPOLATION

While the attention score in RoPE only depends on the relative positions, which is what we want,
its extrapolation performance is not great . In particular, when directly extending to larger context
windows unseen in the training, the perplexity may shoot up to very high numbers (i.e., > 10%),
comparable to untrained models.

Ideally, we want to see the model trained on a context window of size L = 2048 to still work
reasonably well on longer context window, but may not have the capability to leverage information
that appears beyond L. For example, to answer a question located at 3000, the model trained on
maximal window size of I = 2048 cannot leverage evidences provided at location 0, but still
can leverage the evidences provided at location 2900. In contrast, in reality we see catastrophic
behaviors, i.e., question at location 3000 cannot be answered correctly, even if the evidences are
located at location 2900.

What is the reason behind? How could this happen if the attention score a,,,—,, decays as the relative
distance |m — n/| increases, according to Section 3.4.3 of (Su et al., 2021), and content from very
far distances should not matter that much? It turns out that the upper bound derived in Section 3.4.3
of (Su et al., 2021) may be too loose: while it indeed decays with respect to |m — nl, the bound
can still be quite large (i.e., the bound can be critically depends on the magnitude of v;) and thus
vacuous. In fact, if we treat all trigonometric functions as basis functions (i.e, ¢;(s) := #93), and
think about Eqn. 2 as basis expansion as the following:

where s is the positional span between a query and a key and h; := (ga; + igaj+1){k2j — tk2j+1)
are complex coefficients depending on q and k (here the definition of h; is exactly the same as the
definition of k; in Sec 3.4.3 in RoPE (Su et al., 2021)). Now the the issue becomes clear: as shown
in Fig. 2, a, can be small in magnitude in the range of [0, 2048], but gives huge values out of the
region. The underlying reason is that the trigonometric family {¢;} (with sufficiently large d) is
a universal approximator and can fit any arbitrary functions. Therefore, for a, there always exist
coefficients {h;} (i.e. key and query) that corresponds to small function values in [0, 2048] but

much larger in regions beyond.

2.3 PROPOSED APPROACH: POSITION INTERPOLATION (PI)

In Fig. 2, thanks to the smoothness of bases functions ¢; interpolation is much more stable and will
not lead to wild values. Therefore, instead of extrapolate the attention score in Eqn. 3 to s > L,
how about we define an attention score a{s) = a(Ls/L’) where L’ is the longer context window?
Formally, we replace RoPE f by {’ defined as follows

We call this transformation on the position encoding Position Interpolation. In this step, we reduce
position indices from [0, L') to [0, L) to match the original range of indices before computing RoPE.
Consequently, as inputs to RoPE, the maximum relative distance between any two tokens has been
reduced from I’ to L. Since we align the ranges of position indices and relative distances before
and after extension, we mitigate the effect on attention score computation due to context window
extensions, which can allow the model easier to adapt. To further demonstrate this is the case, in the
following theorem, we show that the interpolated attention score is well-behaved:

While there is no close form for B(s) := 4/21 |Ag41(s)|, numerically it is at least larger than d, and for many positional difference s, B(s) is much larger than d
(check Appendix B for the plot). Therefore, the interpolation bound is at least 2 - 294.73 ~ 600 x
smaller than the extrapolation bound, and thus the interpolated attention score is much more stable
than extrapolated one.

Notably, our method of rescaling of position indices does not introduce extra weight, or modify
the model architecture in any way. This makes it attractive in practical applications, since most
infrastructure and optimization for the original model can be reused after the extension.

Fine-tuning. We can further fine-tune the interpolated model using the next token prediction task
with interpolated position encodings on the extended context window size using a pre-training cor-
pus such as the Pile (Gao et al., 2020). In the next section, we show that our fine-tuning process
only needs tens to hundreds thousands of examples. We also find that the result of the fine-tuning
is not sensitive to the choice of examples. The reason may be that the model is only adapting to the
new context window during the fine-tuning phase, starting from a good initialization, as opposed to
acquiring new knowledge.

Other ways to reduce interpolation/extrapolation bound. From the expression of the interpola-
tion (Eqn. 5) and extrapolation bound (Eqn. 8), a common term is max; ||, which is the maximal
magnitude of query/key products. If we enforce a regularization on || during LLM training, it is
possible that the catastrophic extrapolation error can be mitigated or even resolved. In fact, if we
apply ridge regression with proper regularization to fit a curve in Fig. 2, the magnitude of extrapo-
lated a(s) when s > L can be comparable to that within [0, L]. To our knowledge, we are not aware
of existing LLM pre-training techniques that leverage this regularization and will leave it for future
work.

3 EXPERIMENTS

We show Position Interpolation can effectively extend context window up to 32 times of the original
size, and such extension can be done with only several hundreds of training steps. We show the
resulting models are strong LLMs with fully effective long context windows. We demonstrate its
performance in a number of tasks including language modeling, passkey retrieval, and long doc-
ument summarization. We also present benchmark results of the extended models on the original
LLaMA evaluation benchmarks.
3.1 SETUP

Model Variants. We extended the pre-trained 7B, 13B, 33B and 65B LLaMA models (Touvron
et al., 2023) to various context window of sizes up to 32768, using either direct fine-tuning or
Position Interpoloation method. Except for rescaling the position indices for models extended with
Position Interpolation, we did not modify LLaMA model architectures (Touvron et al., 2023) in any
ways.

Training Procedure. We fine-tune all model variants using the next token prediction objective. We
use AdamW (Loshchilov & Hutter, 2019) with 5; = 0.9 and 2 = 0.95. We use a linear learning
rate warmup of 20 steps starting from 10% of the maximum learning rate. For 7B and 13B models,
we set the learning rate to 2 x 1075 and for 33B and 65B models we set the learning rate to 1072. We
set the weight decay to zero. For extending 7B, 13B and 33B models to the 8192 context window
size, we use 32 A100 GPUs and 64 global batch size. For all other cases we use 128 A100 GPUs and
128 global batch size. We note that the main need of using more GPUs is memory limitation during
fine-tuning, and it is possible to use fewer GPUs in certain cases. We train all models using PyTorch
(Paszke et al., 2019) with Fully Sharded Data Parallel (Zhao et al., 2023) and Flash Attention (Dao
et al., 2022).

If not specified otherwise, for the Position Interpolation method, we fine-tune the models for 1000
steps. For the direct fine-tuning method, we use 10000 steps. We primarily fine-tune using the Pile
training dataset (Gao et al., 2020). In Section 3.4 we also compared fine-tuning performance on the
RedPajama dataset (Computer, 2023).

3.2 LONG SEQUENCE LANGUAGE MODELING

We evaluate the long sequence language modeling performance of our extended models and base-
lines on two datasets: book corpus (PG-19) (Rae et al., 2020) and cleaned Arxiv Math proof-pile
dataset (Azerbayev et al., 2022).

We use the test splits of PG19 (Rae et al., 2020) and proof-pile (Azerbayev et al., 2022). For PG19,
we use the whole test split consisting of 100 documents. For the proof-pile dataset, we use a random
subsample of 128 documents with at least 32768 SentencePiece (Kudo & Richardson, 2018) tokens
and truncate to the first 32768 tokens for each test document. We evaluate perplexity at various
context window size by using a sliding window approach following Press et al. (2022) with stride
S = 256.

In Table 1 and Table 2, we report the perplexity results for our models and baselines on the datasets.
From the results, we found that models extended with our method enjoy a significantly improved
perplexity from longer context window sizes. By increasing the context window size from 2048 to
16384, we observed -0.28 and -0.5 reductions of perplexity for extending LLaMA 7B models on
both datasets, -0.27 and -0.48 reductions for extending LL.aMA 13B models, and -0.14 and -0.42
reductions for extending LLaMA 33B models. For LLaMA 65B models, we observed -0.12 and
-0.3 reductions of perplexity by extending to the 8192 context window size.

In general, we observed a consistent trend of our models achieving better perplexity with longer
context windows. This indicates our models can effectively make use of the longer context windows
to better predict next tokens in language modeling tasks. Moreover, we found this trend extends to
32768 window size without diminishing on the PG19 dataset for LLaMA 7B and 13B models. This
indicates that our method may enable extension to even longer context windows.

In contrast, we observed that models extended via the direct fine-tuning method has shown regres-
sion (up to +0.48) or minor improvement (up to -0.12) on the perplexity at longer context windows.
This indicates that models extended this way have limited capability of making use of context win-
dows longer than their pre-trained settings.

We saw a minor degradation of the perplexity on the original context window of 2048 for our ex-
tended models in some cases. For example, on the Proof-pile dataset, we saw a degradation ranging
from 0.01 to 0.05 across all models with extended with Position Interpolation. A small degradation
of performance within original evaluation context window is expected since Position Interpolation
forces position encodings in original context window to reside in a much narrower region, which
may negatively affect the language model’s performance. We present more benchmark results on
the original context window size in Section 3.4.

In Table 3 we report the relationship between perplexity and the number of fine-tuning steps for
LLaMA 7B model extending to 8192 and 16384 context window sizes using Position Interpolation
evaluated on the PG19 dataset. We can see without fine-tuning (at step 0) the model can exhibit
certain language modeling capability, as indicated by < 20 perplexity for extending to 8192 context
window (in contrast, the direct extrapolation method leads to > 10% perplexity). With fine-tuning,
we observed that the perplexity improves quickly. At 200 steps the models surpassed the original
model’s perplexity on 2048 context window size, indicating the models gaining ability of effectively
using sequences longer than the pre-training settings for language modeling. At 1000 steps, we can
see the models have improved steadily and achieve a significantly better perplexity.

3.3 MEASURING EFFECTIVE CONTEXT WINDOW SIZE THROUGH PASSKEY RETRIEVAL

We study the effective context window size, i.e. the maximum distance of a token can effectively
attend to during inference, of our models after extension. To measure this, we follow a synthetic
evaluation task of passkey retrieval proposed by Mohtashami & Jaggi (2023). In this task, the models
are asked to recover a random passkey hidden in a long document. See Figure 3 for the format of
the document.

Given a language model, we estimate the upper and lower bounds of effective context windows as
follows. Suppose the random passkey is k tokens away from the end of the input. When a model
persistently fails to retrieve the correct passkey value across several independent attempts, it suggests
that the effective context window size of the model is less than k. Conversely, if a model consistently
succeeds in retrieving the correct passkey value, we deduce that the effective context window size
of the model is at least k.

We evaluate the 7B and 33B LLaMA model variants that are extended via Position Interpolation or
direct fine-tuning. For each model, we use 32 different &£ uniformly spaced in the targeted context
window L’ and run the above tests for 10 times for each k, where each time a random passkey of 5
random digits is used. In Table 4, we report kyax as a function of the number of fine-tuning steps,

We can see that models extended via Position Interpolation all successfully attain their desired ex-
tension objectives in terms of effective context window sizes, indicating by the effective context
window size reaching maximum kp, = L/, after merely fine-tuning for 200 steps, consistently
across both 7B and 33B model sizes and up to 32768 context windows. In contrast, LLLaMA models
that are extended via direct fine-tuning only saw a minimal increase of the effective context win-
dow size kay from 2048 to 2560, even after fine-tuning for more than 10000 steps, with no clear
indication of an acceleration in the increase of window size.

3.4 BENCHMARKS ON ORIGINAL CONTEXT WINDOW SIZE

We evaluate the models extended by Position Interpolation on several standard benchmark tasks
within the original context window size of 2048. The evaluation results are listed in Table 5. From
the results, we saw that models extended to 8192 produce comparable results on the original bench-
mark which is designed for a much smaller context window, with a degradation of up to 2% on
the benchmark tasks, for both 7B and 33B model sizes. Models extended to longer context win-
dows regressed more on the benchmarks, but still in reasonable ranges for most tasks. We also note
that the choice of fine-tuning datasets does not seem to lead significant difference in the benchmark
performances, which may be due to the limited number of fine-tuning steps used in our method.
The regression on benchmark tasks is consistent with our observation on perplexity regression in
Section 3.2.

3.5 LONG DOCUMENT SUMMARIZATION

In this task, we evaluate our models’ performance on the long document summarization task. In
particular, we consider the GovReport (Huang et al., 2021) dataset, which contains 17457 documents
for training and 972 documents for evaluation. Each document comes with a human generated
summary. We truncate all input documents to their first 15000 tokens.

We fine-tune the LL.aMA models extended with Position Interpolation with a context window of
16384. Note the rescaling of position indices are still required during this fine-tuning step. We first
Model Size Context Window Fine-tune on  BoolQ PIQA Race-M Race-H WinoGrande

format the raw document using the prompt template in Figure 4, and then concatenate the prompt
with the ground-truth summary (truncate to 1000 tokens) associated with each document. We fine-
tune the model using the next token prediction task with the above setup for 10 epochs. The losses
from the input prompt proportion of training examples are excluded during our fine-tuning.

We use a generation temperature of 0.5 and top, = 0.95 as our inference parameter to generate a
summarization of each document in the test set. The final output is truncated at 1000 tokens. We
used the ROUGE-1/ROUGE-2/ROUGE-L scores (Lin, 2004) as the evaluation metrics to evaluate
the models’ outputs vs the ground-truth summaries.

In Table 6 we report our evaluation results. We have also included results from two baselines in
existing SCROLLS Leaderboard (Shaham et al., 2022; Ainslie et al., 2023). In general, we have
obtained competitive R1 score among other models with minimal tuning of hyper-parameters. This
result suggests our models with 16384 context window can effectively handle the long document
summarization task.

=== END OF FILE ===

'''

If we place it in the following example

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    load_in_8bit=True,
    device_map="auto",
)

prompt = ...
question = "Question: What is the paper about?"

inputs = tokenizer(prompt + question, return_tensors="pt").to("cuda")

print(inputs.input_ids.shape)
gen_out = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.batch_decode(gen_out)[0])

we get:

Question: What is the paper about? a a a a a a a a a a a a b: a a a a a a: a in a a a a a a [(b a b. a [b [b [b. [b [b [( [( [( [( [( [( [b [(b [b [b
[b [(( [((: [(: [: [: [((((((0:(((((al:

However, if we add rope_scaling={"type": "dynamic", "factor": 2.0} in from_pretrained, we now get:

Question: What is the paper about?
Answer: The paper is about extending the context window of Transformer models.
Answer: The paper is about extending the context window of Transformer models.
Answer: The paper is about extending the context window of Transformer models.
Answer: The paper is about extending the context window of Transformer models.
Answer: The paper is about extending the context window of Transformer models.
Answer: The paper is about extending the context window of Transformer models.
Answer: The

Better generation parameterization can definitely be selected, but you get the idea -- with these changes, models with RoPE can handle much larger contexts right out of the box 🔥

@gante
Copy link
Member Author

gante commented Jul 4, 2023

(Of course, tests are missing. Proper validation of whether the feature is working as expected is also missing. I'll add them if we decide to move forward with this feature!)

@@ -67,6 +67,12 @@ class OpenLlamaConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(changes in open_llama are copy/paste)

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jul 4, 2023

The documentation is not available anymore as the PR was closed or merged.

@TheBloke
Copy link

TheBloke commented Jul 5, 2023

Having this in transformers would be excellent!

I've uploaded a bunch of fp16 and GPTQ repos to HF using @jquesnelle 's trust_remote_code Llama modelling patch that implements RoPE using @kaiokendev's method, and I know there are quite a number of people using those already, and I've had a few requests to put out more. And even more are using RoPE outside of transformers via the ExLlama GPTQ implementation.

So there's a great deal of appetite for this feature amongst users, understandably.

@versae
Copy link
Contributor

versae commented Jul 6, 2023

Could this also be applied to GPT-J models?

@gante
Copy link
Member Author

gante commented Jul 6, 2023

@versae yes, it can :) The code needs to be modified there as well, but the concept can be applied to any model with rotary position embeddings

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len

if self.scaling_name == "dynamic":

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For dynamic, we want to do the scaling only when seq_len > max_position_embeddings (i.e. when we're going past the model's pre-trained length). My original code did this by just having the scaling in the forward() code that re-calculated the frequency cache when seq_len > self.max_seq_len_cached but not in the __init__. Since this code has now been deduplicated (makes sense!), I think this needs to be

if self.scaling_name == "dynamic" and seq_len > self.max_position_embeddings:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jquesnelle that's a great catch, and completely missed in my attempts to unify and refactor the cos/sin cache!

@bloc97
Copy link

bloc97 commented Jul 9, 2023

Thank you for your work! Just letting you know that I've improved the NTK-aware method in this PR. jquesnelle/yarn#1 It decreases non-finetuned PPL even further (preliminary testing shows 4.9 -> 4.6 PPL at 8192 context size) and theoretically will significantly improve a finetune's convergence/stability compared to previous NTK-aware method.

Also because the alpha hyperparameter was difficult to use when predicting effective context size (alpha=4 was something close to ~6400 context size instead of 8192), that problem was fixed and it is now changed to a "scale" factor, which can be used the same way to the "scale" in linear RoPE scaling. (eg. for LLaMA scale=2 is 4096 and scale=4 is 8192)

I hope this improved method might be also considered one day as it is one more step towards extending context size for all LLMs! 🚀

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as this can be used with all existing checkpoints, I have no problem adding this as a new feature. Not too sure about the API you picked as a dictionary though. Since there are two arguments, I would just have put two arguments as a dict can be complicated. Or if we expect more arguments to arrive, having some dataclass to hold them and accept both dict/dataclass at init (though this can also be done later since we are warning this is an experimental feature subject to changes).

And as was mentioned in the comments, this should also be added to other models with Rotary Embeddings :-)

@@ -64,6 +64,13 @@ class LlamaConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Experimental feature -- dictionary containing the scaling configuration for the RoPE embeddings. Currently
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would put "Experimental feature" as a warning at the end of the docstring.

This is an experimental feature, subject to breaking API changes in future versions.

@gante
Copy link
Member Author

gante commented Jul 12, 2023

Hey @bloc97 @jquesnelle 👋

Looking at your recent PR (this one) -- am I right in saying that

  1. There is no way to parameterize the new class such that it is equivalent to the original NTK-aware scaling?
  2. @bloc97's PR and @jquesnelle's dynamic implementation are slightly different, in the sense that @bloc97's targets a specific length (but can extrapolate) and @jquesnelle's dynamically adjusts to the maximum observed length?
  3. Because @jquesnelle's implementation base may suddenly change due to a longer sequence, it is less friendly to fine-tune?

I'm trying to determine how to integrate and document the goodies, while keeping the diff size manageable 🤗

@gante
Copy link
Member Author

gante commented Jul 12, 2023

The technique also seems to work out-of-the-box with GPTNeoX models 🔥 With the latest commit, running the script below

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-1.4b-deduped",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    rope_scaling={"type": "dynamic", "factor": 2.0},
)

prompt = ...  # see PR header for the prompt, >5k tokens
question = "Question: What is the paper about?"

inputs = tokenizer(prompt + question, return_tensors="pt").to("cuda")

print(inputs.input_ids.shape)
gen_out = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.batch_decode(gen_out)[0])

gets us

Question: What is the paper about?

3.6 CONCLUSION

We have shown that Position Interpolation can extend the context window of pre-trained models to substantially longer context windows. We have
demonstrated that the models can be effectively extended to longer context windows, and

Without the rope_scaling argument, we get

Question: What is the paper about? The. The.
. The.
.
.
.
.
4.
. The. The. The. The. The. The. The.al.s. The. The.... a. The.

This is particularly amazing, since we're talking about a 1.4B model 👀

(cc @bloc97 @jquesnelle, you may be interested in this finding)

@gante
Copy link
Member Author

gante commented Jul 12, 2023

@amyeroberts @sgugger I'd like to request the review of you two, since this is an unusual post hoc modeling change PR.

A few key points:

  1. The technique also works well with gptneox, see this comment for a cool example on a 1.4B model
  2. Adding the functionality to gptneox implied a minor modeling change -- the causal mask was limited to the original maximum sequence size, but there is no reason for that limitation. It's just a triangular matrix with ones.
  3. Decided NOT to implement on gptneox-japanese and esm, the two other models with rotary embeddings. I'm not sure if their usage justifies the implementation cost (it takes some time to validate everything is working correctly, as there are variations in the expected usage), so I'd suggest letting demand speak for itself :)
  4. RoPE scaling is parameterized by a dict, and not a dataclass. A dataclass would be better, as @sgugger suggested, but it complicates (de)serialization, needing extra code. I'd like to first work on the config file base class I've mentioned on slack, if you're okay with it -- it would make the new dataclass a ~50 line change, as opposed to a >200 one!
  5. There are new scaling strategies in the works, as mentioned in the comments above, so we can quickly add them in follow up PRs if their results are superior. As it stands, we can already hack llama and gptneox beyond their original maximum length without fine-tuning 🔥

@gante gante changed the title Llama: add RoPE scaling Llama/GPTNeoX: add RoPE scaling Jul 12, 2023
Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! As said before, still not a fan of the API but since you are planning to address this in a follow-up PR, that works for me.

@jquesnelle
Copy link

(For brevity, I'll refer to the new NTK-By-Parts method as NTKv2)

NTKv2 is an improved version of NTK. We found that NTK did not perform well when fine-tuned; the reason for this was that the resulting embedding matrix still contained some extrapolated out-of-population values that model had not seen during training. Dynamic NTK hid this by continually scaling base so that you never actually got to this part of the embedding values.

NTKv2 is parameterized by scale, which has the same meaning as linear interpolation, e.g. you set it to 4 to target 8K context length. We've found that this method, when fine-tuned, beats fine-tuned linear interpolation, which is to say it gives even better results than the recent Meta paper.

In the repository there is also now a Dynamic NTKv2, which is the same idea as the previous dynamic method, i.e. scale the hyperparamter relative to the ratio between the current context length and the model's original trained context length, while using the original embedding values when under the native pre-trained length. This also beats Dynamic NTK in the no-fine-tuning scenario.

image

In the above graph, LLongMA are the fine-tuned OpenLLaMA models we've released, trained on 1B extra tokens (v2 still in the process of training)

  1. There is no way to parameterize the new class such that it is equivalent to the original NTK-aware scaling?

Unfortunately no. I understand these different methods can get unwieldly quickly, but NTKv2 appears to be strictly better than original NTK -- I would potentially just advocate replacing the original NTK with this, but that could also be done in a follow-up PR too; the results that this gives you is already Very Good (TM).

FWIW the LLongMA models use the exact modeling code here to maintain compatibility without needing trust_remote_code if/when this PR gets merged 🙂

@bloc97
Copy link

bloc97 commented Jul 12, 2023

Hey @bloc97 @jquesnelle 👋

Looking at your recent PR (this one) -- am I right in saying that

  1. There is no way to parameterize the new class such that it is equivalent to the original NTK-aware scaling?
  2. @bloc97's PR and @jquesnelle's dynamic implementation are slightly different, in the sense that @bloc97's targets a specific length (but can extrapolate) and @jquesnelle's dynamically adjusts to the maximum observed length?
  3. Because @jquesnelle's implementation base may suddenly change due to a longer sequence, it is less friendly to fine-tune?

I'm trying to determine how to integrate and document the goodies, while keeping the diff size manageable 🤗

  1. Unfortunately "NTK v1" was just not good for fine-tuning unless alpha is set correctly, so I think going forward people should strictly use "v2" for fine-tuning, and consider v1 to be only for inference. However it is possible for me to parameterize the "v2" class so that you can make it equivalent to original NTK scaling, but it will take additional effort that is probably best used elsewhere. There are only few "NTK v1" finetunes are out there.
  2. For points 2 and 3, finetuning with Dynamic method will need additional consideration in the code on the training side, because training happens on all the tokens at once, dynamic implemented as is (for inference) will probably not be applied correctly. We are still working on the theoretical side of potentially training a dynamic model.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice! 🤩 Thanks for adding ❤️

+1 on @sgugger's comment on using a dataclass, also happy for it to be done in a follow-up.

@gante
Copy link
Member Author

gante commented Jul 13, 2023

@bloc97 @jquesnelle thank you for your input -- and excited to hear about the performance of NTK-By-Parts!

Based on your comments, I will:
1 - Delete the ntk approach, as NTK-By-Parts is superior;
2 - Merge what I have now -- we are going to have a release early next week, so this would already be included in v4.31;
3 - Open a follow-up PR with NTK-By-Parts 🤗 Or, if you're interested in contributing with the technique, we'd highly appreciate it! Just let me know over the next days.

⚠️ Note -- the latest commits have changed the structure of the modeling code from overloading the existing RoPE class to inheriting from the original implementation, so we don't risk ending up with a Frankenstein class as we add more strategies. The parameterization stayed nearly the same, so you probably only need to make minor adjustments to the model config files to load without trust_remote_code! (changed from {"name": scaling type, "factor": scaling factor} to {"type": scaling type, "factor": scaling factor}, as name is often attributed to an instance name in transformers)

@fahadh4ilyas
Copy link

Hi, I'm very glad to see that transformers supports RoPE scaling! Experiments show low ppl on long input sequences.

But in the current implementation, would there be a mismatch in generation? Here are my thoughts.

Since the seq_len increases during the generation, the base is scaled in every generation step with different scaling factor. Since the history key_states are store in the kv_cache , they are not scaled with the new base. The scaling only affects the state of the current token.

For example, if the input sequence is of length 2048, after generating the first token, the new input length is 2049, and we scale the base with seq_len=2049. After generating the second token, the new input length is 2050, and we scale the base with seq_len=2050. But during the generation, the kv_cache is used and thus the key_states before position 2049 are not scaled according to the new length.

Should all the key_states be scaled with the same base? Would it be a problem?

I have question similar to this. The graph showing dynamic scaling in this reddit post showing that the perplexity of the model with dynamic scaling are same with model without scaling until 2048 tokens length (Of course this must be because the base value did not change before 2048 tokens).

This got me thinking, If I first generate with long context (say 4096 tokens), the base value would change accordingly (which is around 35000). Then, if I next generate with short context like 1024 context, the sin_cache and cos_cache will not be reverted back when the base value still 10000 hence the perplexity is raised. Should there be changed to forward call especially for dynamic scaled embeddings?

@airaria
Copy link

airaria commented Jul 17, 2023

This got me thinking, If I first generate with long context (say 4096 tokens), the base value would change accordingly (which is around 35000). Then, if I next generate with short context like 1024 context, the sin_cache and cos_cache will not be reverted back when the base value still 10000 hence the perplexity is raised. Should there be changed to forward call especially for dynamic scaled embeddings?

I have the same concern. In the dynamic scaling, the sin and os may should not be cached

@guozhiyao
Copy link

guozhiyao commented Jul 17, 2023

Hi
I try to test ntk effect on my trained neox model. Using dynamic ntk(https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/). However, it is found that the ppl will oscillate. What is the reason for this?
image

here is my test code. I modified from https://huggingface.co/docs/transformers/perplexity .

import json
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
from tqdm import tqdm
import traceback

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
config.rope_scaling = {
    "type": "dynamic",
    "factor": 2,
}

model = AutoModel.from_pretrained(model_dir, config=config, trust_remote_code=True, torch_dtype=torch.float16)
model.eval()
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

with torch.inference_mode():
    kv = {}
    try:
        for value in tqdm(range(32, 12000, 32)):
            max_length = stride = value

            with open("gov_report_test.json") as f:
                data = json.load(f)

            ppls = []
            for idx, line in enumerate(data):
                if idx >= 1:
                    break
                encodings = tokenizer(line, return_tensors="pt")
                seq_len = encodings.input_ids.size(1)

                nlls = []
                prev_end_loc = 0
                for begin_loc in range(0, seq_len, stride):
                    end_loc = min(begin_loc + max_length, seq_len)
                    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
                    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
                    target_ids = input_ids.clone()
                    target_ids[:, :-trg_len] = -100

                    outputs = model(input_ids, labels=target_ids)

                    # loss is calculated using CrossEntropyLoss which averages over valid labels
                    # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
                    # to the left by 1.
                    neg_log_likelihood = outputs.loss

                    nlls.append(neg_log_likelihood)

                    prev_end_loc = end_loc
                    if end_loc == seq_len:
                        break

                ppl = torch.exp(torch.stack(nlls).mean())
                ppls.append(ppl)

            total_ppl = torch.stack(ppls).mean()
            kv[value] = total_ppl.item()
            print(value, total_ppl.item())
    except Exception as e:
        print(e)
        print(value, seq_len)
        print(traceback.format_exc())

@gante
Copy link
Member Author

gante commented Jul 17, 2023

@guozhiyao Nothing immediately comes to mind, it could be even a model "feature" (looking at the plot for the original model, which also has the periodicity).

Would you be able to a) run the same script for LLaMA and b) repeat your experiment using the script @jquesnelle used (this one)? a) should rule out model-specific issues and b) should rule out code-specific issues.

@guozhiyao
Copy link

@guozhiyao Nothing immediately comes to mind, it could be even a model "feature" (looking at the plot for the original model, which also has the periodicity).

Would you be able to a) run the same script for LLaMA and b) repeat your experiment using the script @jquesnelle used (this one)? a) should rule out model-specific issues and b) should rule out code-specific issues.

@gante Thanks a lot. It is solved by using the code.

@guozhiyao
Copy link

This got me thinking, If I first generate with long context (say 4096 tokens), the base value would change accordingly (which is around 35000). Then, if I next generate with short context like 1024 context, the sin_cache and cos_cache will not be reverted back when the base value still 10000 hence the perplexity is raised. Should there be changed to forward call especially for dynamic scaled embeddings?

I have the same concern. In the dynamic scaling, the sin and os may should not be cached

@airaria I had the same problem, not only cos and sin, inv_freq also don't cache. The _set_cos_sin_cache of GPTNeoXDynamicNTKScalingRotaryEmbedding can be changed to the following form, but the efficiency is not optimized.

    def _set_cos_sin_cache(self, seq_len, device):
        self.max_seq_len_cached = 0

        base = self.base
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))

        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

@guozhiyao
Copy link

This got me thinking, If I first generate with long context (say 4096 tokens), the base value would change accordingly (which is around 35000). Then, if I next generate with short context like 1024 context, the sin_cache and cos_cache will not be reverted back when the base value still 10000 hence the perplexity is raised. Should there be changed to forward call especially for dynamic scaled embeddings?

I have the same concern. In the dynamic scaling, the sin and os may should not be cached

@airaria I had the same problem, not only cos and sin, inv_freq also don't cache. The _set_cos_sin_cache of GPTNeoXDynamicNTKScalingRotaryEmbedding can be changed to the following form, but the efficiency is not optimized.

    def _set_cos_sin_cache(self, seq_len, device):
        self.max_seq_len_cached = 0

        base = self.base
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))

        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

There is a precision difference between the inv_freq here and the inv_freq defined in __init__, and the reason is not found. In order to ensure the same performance as the original when seq_len <= self.max_position_embeddings, it can only be modified to this form.

    def _set_cos_sin_cache(self, seq_len, device):
        self.max_seq_len_cached = 0

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                    (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        else:
            inv_freq = self.inv_freq

        t = torch.arange(max(seq_len, self.max_position_embeddings), device=device, dtype=inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

Narsil added a commit to huggingface/text-generation-inference that referenced this pull request Jul 31, 2023
# What does this PR do?


- Adds Rope NTK scaling.

Done because
#529 was
closed
Took some code from
huggingface/transformers#24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).


Fixes #512




<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
oobabooga added a commit to oobabooga/text-generation-webui that referenced this pull request Aug 9, 2023
@gante gante mentioned this pull request Aug 30, 2023
verdant621 pushed a commit to verdant621/text-generation-inference that referenced this pull request Oct 19, 2023
# What does this PR do?


- Adds Rope NTK scaling.

Done because
huggingface/text-generation-inference#529 was
closed
Took some code from
huggingface/transformers#24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).


Fixes #512




<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
* add rope_scaling

* tmp commit

* add gptneox

* add tests

* GPTNeoX can now handle long inputs, so the pipeline test was wrong

* Update src/transformers/models/open_llama/configuration_open_llama.py

Co-authored-by: amyeroberts <[email protected]>

* remove ntk

* remove redundant validation

---------

Co-authored-by: amyeroberts <[email protected]>
cr313 added a commit to cr313/text-generation-inference-load-test that referenced this pull request Apr 19, 2024
# What does this PR do?


- Adds Rope NTK scaling.

Done because
huggingface/text-generation-inference#529 was
closed
Took some code from
huggingface/transformers#24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).


Fixes #512




<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
alfredgui2 pushed a commit to mlsys-io/kv.run that referenced this pull request Jul 6, 2024
# What does this PR do?


- Adds Rope NTK scaling.

Done because
huggingface/text-generation-inference#529 was
closed
Took some code from
huggingface/transformers#24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).


Fixes #512




<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
tjluyao added a commit to mlsys-io/kv.run that referenced this pull request Jul 7, 2024
Init

fix: cleanup

Add load testing

Refactored gRPC interface
Added validation logic

ValidationError was not correctly handled

Use axum

feat: Docker image

feat: Add AML deployment

Update aml deployment

feat: Improve error handling

feat: Add arguments to CLI

v0.1.0

fix(validation): Fix error messages

feat(router): Add max_waiting_tokens

Create LICENSE (#2)

feat(server): Use safetensors

Co-authored-by: OlivierDehaene <[email protected]>

feat(client): Simplify sharded logic

feat(server): Support bitsandbytes

feat(server): Support all AutoModelForCausalLM on a best effort basis

feat: Use json formatter by default in docker image

fix(models): Revert buggy support for AutoModel

feat(server): Support generic AutoModelForCausalLM

feat(server): Support AutoModelForSeq2SeqLM

feat(launcher): Pass CUDA_VISIBLE_DEVICES to the shard

feat(server): Improved doc

fix(server): Fix Transformers fork version

feat(server): Clarify CausalLMBatch concatenate method

feat(rust): Update to 1.65

fix(router): Fix HTTP status codes

fix(readme): Typo

fix(router): Handle tokenizer errors

feat(server): Support Galactica (#4)

fix(batching): Avoid theoretical hang in batcher loop (#5)

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the router generate method
- Keep attention mask tensors as integers
- Remove num_heads attribute

Co-authored-by: OlivierDehaene <[email protected]>

feat(server): Add model tests (#6)

fix(server): Only pad to multiple of 8 on GPUs

feat: Support stop sequences (#7)

feat: Return logprobs (#8)

feat(launcher): Add integration tests (#9)

fix(server): Fix stop sequences (#11)

fix(server): Check for device type correctly when determining initial padding (#16)

AFAIK there is no torch device type called "gpu".

fix(router): Include special tokens when tokenizing (#14)

There's currently a discrepancy in the tokenization between the router
and python server code. The latter includes special tokens but former
does not.

This results in a token count mismatch for seq2seq models such as mt0
where the tokenizer emits an EOS token at the end.

This in turn results in some unexpected/incorrect output, in particular
when batch concatenation is involved, because the python code uses the
input length passed from the router for each row.

As far as I can tell, it is better to include this token in the encoder
`input_ids`, so I guess it's best to just adjust on the router side.

feat(router): Add const parameters to validation logic  (#15)

I noticed some opportunity to collapse some of the logic, in case you
are interested.

fix(server): Use cleanup_tokenization_spaces=False for lossless decoding (#13)

Fixes #12 in the easiest way I could think of.

feat(launcher): Log server stdout (#19)

Co-authored-by: Nick Hill <[email protected]>

fix(server): Minor refactorization using new_zeros (#24)

- Fix some type hints, in particular base tokenizer class
- Make use of `tensor.new_zero/empty` methods
- Simplify env var string parsing in launcher

fix(router): Obey max batch size (#23)

feat(server): Support SantaCoder (#26)

fix(server): Fix position ids (#28)

feat(docker): Make the image compatible with api-inference (#29)

fix(docker): fix api-inference deployment (#30)

fix(router): fix api-inference deployment (#31)

fix(dockerfile): fix docker build (#32)

feat(bloom): use torch.nn.Linear and torch.nn.GELU (#33)

feat(router): Remove second lock from batcher hot path (#27)

@njhill

feat: Support sampling seeding (#37)

Co-authored-by: Yannic Kilcher <[email protected]>

feat: Add token streaming using ServerSideEvents support (#36)

Add token streaming using ServerSideEvents (SSE).

The signature of the SSE events is:

```rust
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
}

struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
    details: Option<Details>,
}

struct ErrorResponse {
    error: String,
}
```

Revert "feat: Add token streaming using ServerSideEvents support" (#40)

Reverts huggingface/text-generation-inference#36

fix(server): fix seeding on gpu (#42)

fix(server): fix seeding with multiple shards (#44)

feat: Add token streaming using ServerSideEvents support (#41)

fix(server): fix quantization for sharded models (#45)

feat(server): Support GPT-Neox (#39)

feat(ci): Docker build and push (#46)

feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)

feat(server): support repetition penalty (#47)

feat(server): allow the server to use a local weight cache (#49)

fix(server): allow greedy repetition penalty (#51)

feat(router): use background task to manage request queue (#52)

Co-authored-by: Nick Hill <[email protected]>

breaking(router): modify /generate API to only return generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reason. We
want to remove this behaviour as we don't think it is useful and even
detrimonial to usability.

We also remove the unused Vec.

feat(router): refactor API and add openAPI schemas (#53)

feat(docs): Clarify installation steps (#54)

Adds some bits for first-time users (like me 😄 )

feat(ci): push to AML registry (#56)

fix(server): better handling of inference mode (#57)

V0.2.1 (#58)

feat(server): support t5 (#59)

fix(docker): increase shm size (#60)

fixed SSE naming (#61)

https://en.wikipedia.org/wiki/Server-sent_events

feat: add distributed tracing (#62)

feat: add safetensors conversion (#63)

feat(server): improve download logging (#66)

feat(launcher): add disable_custom_kernels arg (#67)

feat(router): add max_total_tokens and empty_input validation (#68)

closes #65

fix(launcher): copy current env vars to subprocesses (#70)

closes #69

feat(router): add prometheus metrics scrape endpoint (#71)

v0.3.0 (#72)

feat(router): add cors allow origin options (#73)

feat(server): enable hf-transfer (#76)

fix(server): remove position_ids from galactica forward (#82)

closes #80

feat(server): pre-allocate max attention mask (#75)

v0.3.1 (#84)

feat(server): add special token bool (#85)

fix(docs): fix openapi schema (#86)

fix(server): fix token_is_special (#87)

feat(router): add legacy route for api-inference support (#88)

feat(router): ask hf.co for pipelinetag to decide on compat_return_full_text (#89)

feat(router): add api-inference headers (#91)

feat(server): add logits watermark (#90)

feat(server): update to hf_transfer==0.1.2 (#93)

feat(ci): improve CI speed (#94)

fix(launcher): add router parameters to launcher (#95)

feat(server): fix transformers commit (#96)

v0.3.2 (#97)

fix(server): fix generate_stream by forcing tokens to be decoded correctly (#100)

feat: allow local models (#101)

closes #99

feat: add supported models (#102)

feat(clients): Python client (#103)

fix(server): fix galactica batch (#106)

closes #105

feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107)

feat(launcher): default num_shard to CUDA_VISIBLE_DEVICES if possible (#108)

fix(python-client): stream not set on the sync client (#109)

fix(server): fix index out of range for watermarking (#110)

feat: support typical sampling (#114)

closes #112

fix(server): do not warp prefill logits (#116)

feat(router): support left truncation (#115)

closes #111

feat(router): add best_of parameter (#117)

feat(python-client): add new parameters (#118)

v0.4.0 (#119)

feat: add OpenAssistant/oasst-sft-1-pythia-12b to the list of supported models (#122)

…ed models

fix(server): revert gpt-neox optims (#123)

fix(server): add position ids to neox (#126)

fix(server): use server tokenizer as gt (#128)

fix(python-client): relax dependencies (#129)

feat(python-client): add cookies to Client constructors and requests (#132)

I have a use case where we need to pass cookies (for auth reasons) to an
internally hosted server.

Note: I couldn't get the client tests to pass - do you need to have an
HF token?

```python
FAILED tests/test_client.py::test_generate - text_generation.errors.BadRequestError: Authorization header is correct, but the token seems invalid
```

feat(ci): add ci paths (#134)

feat: Add note about NVIDIA drivers (#64)

Co-authored-by: OlivierDehaene <[email protected]>

feat(python-client): release v0.4.0 (#135)

feat(python-client): add CI (#136)

feat(server): flash neoX (#133)

fix(server): fix flash-neox scores warping (#137)

feat(server): cleanup flash neox loading (#139)

v0.4.1 (#140)

fix(server): Avoid using try/except to determine kind of AutoModel (#142)

feat(server): Add mypy-protobuf (#141)

Generates .pyi files for protobuf stubs which provide strong typing
information. Very helpful for IDE auto-completion, etc.

feat(server): clear cache on error (#143)

feat(server): reduce mlp and attn in one op for flash neox (#145)

feat: aws sagemaker compatible image (#147)

The only difference is that now it pushes to
registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker:...
instead of
registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sagemaker-...

---------

Co-authored-by: Philipp Schmid <[email protected]>

fix(ci): fix sagemaker action (#148)

feat(benchmark): tui based benchmarking tool (#149)

fix(server): fix flash neox rotary embeddings (#150)

v0.4.2 (#151)

v0.4.3 (#152)

feat(server): flash santacoder (#153)

docs(readme): provide link Logits Warper README (#154)

fix(server): fix escape characters in stop sequence (#155)

feat(docker): improve flash_attention caching (#160)

feat(launcher): allow disabling hf_transfer (#161)

fix(rust-client): use join_all instead of select_all to hopefully fix nccl issues (#162)

fix(router): use buckets for metrics histograms (#163)

feat(router): make router input validation optional (#164)

feat(server): add flash attention llama (#144)

feat(server): support OPT models (#55)

OPT models do not all have a `tokenizer.json` file on the hub at the
moment. Can't merge for now.

v0.5.0 (#168)

feat(server): optimize decode for sane tokenizers (#170)

feat(server): support sharded santacoder (#167)

fix(launcher): revert change on shard errors (#173)

fix(ci): fix CVE in github-slug-action (#174)

feat(ci): add image signing with cosign (#175)

feat(ci): add Trivy and scan docker image (#178)

feat(ci): use large runners (#179)

feat(ci): faster scanning (#180)

fix(ci): fix ci permissions (#181)

fea(dockerfile): better layer caching (#159)

fix(ci): fix cosign error (#183)

fix(docker): fix docker image (#184)

fix(docker): fix image (#185)

fix(docker): revert dockerfile changes (#186)

fix(docker): fix docker image dependencies (#187)

fix(router): fix truncation (#190)

closes #189

feat(python-client): get list of currently deployed tgi models using the inference API (#191)

feat(router): add info route (#196)

close #125

feat(server): support quantization for flash models (#200)

closes #197

feat(server): check cuda capability when importing flash models (#201)

close #198

fix(server): fix hf_transfer issue with private repos (#203)

fix(docker): remove unused dependencies (#205)

fix(router): add auth token to get model info (#207)

feat(router): add git sha to info route (#208)

feat(router): drop requests when client closes the channel (#202)

fix(ci): fix sha in docker image (#212)

feat(server): flash attention past key value optimizations (#213)

feat(router): add device and dtype info (#215)

fix(server): fix past key values logic (#216)

@njhill fyi

fix(server): cleanup new flash past_key_values logic (#217)

fix(server): fix flash causal (#218)

fix(server): fix flash causal (#219)

fix(server): fix flash batch filtering (#220)

misc: update to rust 1.69 (#221)

v0.6.0 (#222)

feat(server): reduce memory requirement (#214)

chore(server): update huggingface-hub (#227)

feat(router): use number of tokens in batch as input for dynamic batching (#226)

Co-authored-by: Nick Hill <[email protected]>

feat(router): add endpoint info to /info route (#228)

chore(server): update safetensors version (#235)

fix(python-client): add auth headers to is supported requests (#234)

Starting some routing tests. (#233)

fix(benchmarking): fix benchmarking tool

chore(launcher): refactor logic (#242)

Hopefully it's cleaner

feat(router): add tests to validation (#237)

feat(router): new healthcheck that skips the queue (#244)

Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

fix(server): fix reshaping of bloom past_key_values in concatenate() (#252)

Introduced in #214

Fixes #249

fix(server): Small tidy of code from recent changes (#251)

remaining_decode_tokens was calculated twice in Seq2SeqLMBatch.filter()

chore(server): update transformers (#250)

feat(server): add watermarking tests (#248)

feat(docker): add nvidia env vars (#255)

doc(launcher): add more docs to the `launcher` itself and link in the README (#257)

feat(benchmark): add support for private tokenizers (#262)

Adding docs on how dynamic batching works. (#258)

This PR starts the minimal possible amount of explanation I could think
of. It tries to explain how dynamic batching occurs, the interactions
with past key values and ignores the padding problem.

Maybe some drawings could help too but I kept it to text for now.

chore(github): add templates (#264)

fix(server): fix typo in tokenizers decode (#269)

closes #268

feat(server): support hf endpoint weight layout (#266)

fix(launcher): pass weights cache override to the download process (#274)

closes #273

fix(launcher): handle hub branches (#278)

fix(server): Removes the parallelism in file convertion (during download) (#275)

feat(launcher): Improve error message when download process fails. (#276)

fix(server): fix convert (#284)

chore: add `flash-attention` to docker ignore (#287)

included when building docker locally.
(Where the local dirs might have the flash-attention folder.)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

fea(server): decrease convert RAM requirements (#286)

fix(dockerfile): fix nvidia env vars (#297)

Fixes #291

feat(router): Adding response schema for compat_generate (#292)

feat(docker): add benchmarking tool to docker image (#298)

fix(docker): fix docker build (#299)

feat(server): optim flash causal lm decode_token (#285)

fix(docker): fix nvidia env vars (#305)

fix(docker): remove nvidia require cuda env (#310)

feat(server): shard token decode (#303)

feat(server): use float16 (#304)

fix(docker): remove CUDA_VERSION

feat(server): use cuda graph in logits warping (#302)

fix(server): fix multinomial implem in Sampling

feat(server): GPTQ quantization (step1) (#277)

Changes only the type from `bool` to `Option<Enum>` pretty much
everywhere.
- Use `Optional[str]` in Python (easier to manage than importing type
everywhere). Except for the cli to get proper validation
- Updated all models to handle gracefully new values. (Error out if
unknown value, or gptq since not implemented).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

chore(docker): use nvidia base image (#318)

fix(docker): remove quantize default

fix(docker): use ubuntu20.04

Hotfixes for santacoder/bigcode. (#294)

Hotfixes:

- Uses `model_type`=`gpt_bigcode` for more general usage.
- Hotfixes linked lm_head vs wte_embedding (safetensors file do not
contain the key, correctly when the file is sharded, where as pytorch
copies the tensor)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

Lifting check_unitialized. (#325)

Lifting check_unitialized.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

Removing dead variables. (#327)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat(ci): custom gpu runners (#328)

Single place for TP layers + Dropout Layer Norm + FastLinear (#329)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat: add snapshot testing (#282)

feat(integration-tests): improve comparison and health checks (#336)

fix(server): fix decode token (#334)

Fixes #333

---------

Co-authored-by: Nicolas Patry <[email protected]>

fix: set MODEL_ID in sagemaker-entrypoint script (#343)

feat(server): Support BLOOMChat-176B (#348) (#351)

@njhill,
temporary workaround to be able to run our CI as secrets are not
available to runners run by external contributors. I will ask around to
see if there is a better way.

Co-authored-by: Nick Hill <[email protected]>

fix(server): fix init for flash causal lm (#352)

Fixes #347

fix(server): t5 cannot run in f16 (#356)

Fix #349

fix(ci): fix security group (#359)

Switch security group used for ci
(open outbound rules)

Signed-off-by: Raphael <[email protected]>
Co-authored-by: Raphael <[email protected]>

feat: add nightly load testing (#358)

chore(sever): update requirements (#357)

Fixes #338

feat(server): support fp16 for t5 (#360)

Fixes #349

feat(server): do not use device_map auto on single GPU (#362)

feat(server): support trust_remote_code (#363)

feat(router): log input/ouput at debug level (#364)

@njhill FYI

v0.7.0 (#353)

feat: decrease IPC proto size (#367)

Closes #307 #308

feat(benchmarker): add summary tables (#368)

feat(server): support vectorized warpers in flash causal lm (#317)

Co-authored-by: Joel Lamy-Poirier <[email protected]>

Fix issue when load AutoModelForSeq2SeqLM model (#370)

fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

fix(server): fix quantization

feat(server): support RefinedWeb models (#379)

v0.8.0

increase health checks

feat(server): add retry on download (#384)

fix(server): fix bnb quantization for CausalLM models (#385)

v0.8.1

fix(server): fix has_position_ids (#395)

Fix #389

feat(server): remove trust_remote_code requirement for falcon models (#396)

feat(server): load santacoder/starcoder models with safetensors (#393)

Fix #366

v0.8.2

feat(sagemaker): add trust remote code to entrypoint (#394)

feat(launcher): parse oom signal (#404)

feat(server): only compute prefill logprobs when asked (#406)

Close #288

feat(server): batch tokenization for flash causal lm (#411)

chore: update openapi schema

feat(server): Rework model loading (#344)

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

feat(server): optimize dist ops (#434)

docs(launcher): fix CUDA_VISIBLE_DEVICES helper comment (#441)

It solves a typo in the comment sections referencing the environment
variable `CUDA_VISIBLE_DEVICES`. No misspelling references to this
variable have been found in code logic leading to undefined behaviour or
bugs. This PR is not expected to perform any code logic modification.

fix(makefile): Fix typo and use POSIX comparison in the makefile (#443)

This PR fixes:
- The usage of non posix comparison which may fail depending on the
shell used (`=` will always work, `==` only with bash)
- Typo in the env variable name displayed in the error message
`BUILD_EXTENSION` instead of `BUILD_EXTENSIONS`

<!-- Remove if not applicable -->

Fixes #422

feat(server): pre-allocate past key values for flash causal LM (#412)

feat(router): add ngrok integration (#453)

feat(server): improve flash attention import errors (#465)

@lewtun, is this enough?

Closes #458
Closes #456

fix(server): fix warpers on CPU (#472)

Closes #471

fix(server): Fixing T5 in case the names are mixed up. (#475)

feat(server): Update convert logic. (#483)

Should be more robust to shared tensors (ok when using
      `from_pretrained). But forcing us to add new checks in our loading
      code (since the chosen key to keep might be different from
      `transformers`).

---------

Co-authored-by: Ubuntu <[email protected]>

feat(server): Adding new ignore_rule for conversion. (#485)

fix(router): add timeout on flume sends (#488)

feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)

Let's start discussing implementation.

- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).

Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.

My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

fix(server): Do not init process group if already initialized (#388)

feat(router): add header option to disable buffering for the generate_stream response (#498)

generate_stream endpoint response stream.

Problem: If a model is run behind a proxy server such as nginx that has
buffering enabled then the response stream from generate_stream gets
aggregated into a single response which basically disables streaming.
Instead of getting a chunked response where each token is presented over
time the response presents everything all at once.

Solution: This change adds the `X-Accel-Buffering` http header which
disables buffering for the generate_stream response, allowing the
response to stream properly.

feat(server): add paged attention to flash models (#516)

Closes #478

feat(router): arg validation (#519)

feat: Add the option to force another dtype than `f16`. (#513)

fix(launcher): fix issue where launcher does not properly report shard failures (#522)

v0.9.0 (#525)

feat(server): Add Non flash MPT. (#514)

This adds a non flash version of MPT.
Flash is harder because we need to create a bias ready cuda kernel of
flash attention.

Fixes
https://github.com/huggingface/text-generation-inference/issues/361
Fixes
https://github.com/huggingface/text-generation-inference/issues/491
Fixes
https://github.com/huggingface/text-generation-inference/issues/290

fix: Update server/Makefile to include Makefile-vllm (#520)

For consistency and ease of use (you can just run `make` to install vllm
without any extra steps).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs(benchmarker): Adding some help for the options in `text-generation-benchmark`. (#462)

fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, an exception will be thrown by `hf_hub_download` function if
`model_id` is a local path.

fix(server): avoid errors for very small top_p values (#544)

See https://github.com/huggingface/transformers/pull/24111

I didn't add validation to the `__init__` method since it's not done for
other values/warpers.

feat(server): use latest flash attention commit (#543)

@njhill FYI

feat(router): add argument for hostname in router (#545) (#550)

In title. Adds argument `--hostname` in router to support something like
`--hostname ::`. Tested with

```commandline
cargo run -- --port 8080 --hostname ::
curl -I -X GET 'http://[::1]:8080/health'  # failed before this commit
```

Trigger CI

---------

Co-authored-by: Phil Chen <[email protected]>

fix(server): decrease memory fragmentation (#557)

v0.9.1 (#558)

fix(server): harden the weights choice to save on disk. (#561)

- Look at `transformers` base class to check for
  `_key_to_ignore_on_load_missing` or `_tied_weights` which are the
  standard attributes to select the keys to NOT save on disk (since they
  are ignored)

- Modified safetensors code (to be reflected in safetensors even if it's
  an internal function).

- Will not work for trust_remote_code=True repos (like santacoder).

Should help with :
https://github.com/huggingface/text-generation-inference/issues/555
and : https://github.com/huggingface/text-generation-inference/pull/501
and https://github.com/huggingface/text-generation-inference/issues/556
and
https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593

feat: better errors for warmup and TP (#575)

Close #571

fix(server): Fixing RW code (it's remote code so the Arch checking doesn't work to see which weights to keep). (#579)

Fixes #555

feat(server): Support for env value for GPTQ_BITS and GPTQ_GROUPSIZE. (#580)

Some models are already converted, and do not have those values in the
file, this enables users to use them with less friction.

Went for pure env based because adding flags would end up (imo) very
tedious to maintain. There's a lot of sanitation to do: those flags
would be errors if not used in conjuction with `--quantize gptq`.
Then the flags need to exist in the launcher and the server passing them
all throughout all function calls.

This PR is intended as an easy escape hatch, not the defacto method to
use gptq in TGI.

Fixes #500

chore: migrate ci region for more availability. (#581)

fix(server): T5 weights names. (#582)

Fixes #541

fix(server): Adding logger import to t5_modeling.py (#585)

Logger is referenced during the apex importing but is not imported,
causing a NameError

fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590)

This fixes a typo and extends the GPTP_BITS environment variables
through to the second method which requires the same logic. Please let
me know if there's anything I've misunderstood in this change.

Thanks @Narsil for the original fix.

feat(server): Implements sharding for non divisible `vocab_size`. (#583)

- The code is relatively easy (just disable the checks on Embedding and
Head)

This cannot be done in the same easy fashion for hidden_dim/head_dim.
It's relatively easy on some models (classic MHA) but it would make the
other
models (MQA) much more complex, and GPTQ quantization another quite
hairy piece
of code.

feat(server): empty cache on errors

GPTQ Env vars: catch correct type of error (#596)

When passing in environment variables like gptq_bits, we still get
errors thrown from TGI because the try/catch block is catching the wrong
type of error. This PR aims to fix that.

@Narsil - let me know if this is how you want this formatted. My Python
is a little shaky, so I hope this syntax is correct.

feat(launcher): add arg validation and drop subprocess (#595)

feat(router): explicit warning if revision is not set (#608)

docs: README: Add logo + baseline (#611)

![image](https://github.com/huggingface/text-generation-inference/assets/3841370/58177321-479f-4ad1-b3bc-cec027423984)

fix(server): blacklist local files (#609)

Close #589 #602

v0.9.2 (#616)

fix(server): empty_cache when stopped

fix(launcher): Rename `b-float16` to `bfloat16` in the launcher arg (#621)

fea(launcher): debug logs (#623)

feat(server): Reworking the quantization script so it's still universal (not llama specific) (#587)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Reworking the quantization script so it's still universal (not llama
specific)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Still need to investigate the potential differences in quantization
results.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat(server): flash attention v2 (#624)

feat(server): add support for llamav2 (#633)

v0.9.3 (#634)

fix(server): fix llamav2 config (#635)

feat(server): auto max_batch_total_tokens for flash att models (#630)

feat(router): ngrok edge (#642)

docs: Update README.md (#639)

docs: Update README.md (#643)

Add trust_remote_code to quantize script (#647)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes a bug appeared with MR #587 fixing issue #552.
See the discussion in #552.

With MR #587 the trust_remote_code variable is not passed to
AutoModelForCausalLM, but is found in the function signature. This
prevents models like falcon to be quantized, because trust_remote_code
is required. This MR fixes the issue.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [X] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
@Narsil
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

 -->

fix(server): llama v2 GPTQ (#648)

As per title & reported
https://github.com/huggingface/text-generation-inference/issues/601#issuecomment-1641435956
https://huggingface.co/TheBloke/Llama-2-70B-chat-GPTQ/discussions/5

Test it:

```
GPTQ_BITS=4 GPTQ_GROUPSIZE=1 text-generation-launcher --model-id TheBloke/Llama-2-70B-chat-GPTQ --port 8080 --num-shard 4 --quantize gptq
```
&
```
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"hey llama","parameters":{"max_new_tokens":256}}' \
    -H 'Content-Type: application/json'
```

fix(server): Fixing non parameters in quantize script `bigcode/starcoder` was an example. (#661)

fix(server): use mem_get_info to get kv cache size (#664)

Close
https://github.com/huggingface/text-generation-inference/issues/649
Close
https://github.com/huggingface/text-generation-inference/issues/651
Close
https://github.com/huggingface/text-generation-inference/issues/653
Close #636

feat(server): Add exllama GPTQ CUDA kernel support #553 (#666)

Just trying to get the integration tests to pass.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Felix Marty <[email protected]>

Directly load GPTBigCode to specified device (#618)

This PR directly load GPTBigCode to specified device, avoiding moving
model between devices.

This PR directly load GPTBigCode to specified device, avoiding moving
model between devices.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil

feat(server): add local prom and health routes if running w/ ngrok

feat: add cuda memory fraction (#659)

Close #673

fix(server): fix exllama buffers (#689)

Close #683

feat(server): Using `quantize_config.json` instead of GPTQ_BITS env variables. (#671)

- Current PR is not great because we're side stepping the
  `Weights.__init__` but Weights shouldn't requires anything related
  to the config or the model_id as it aims to be a simple Wrapper
  over multi file loading.
- Ideal solution would be to use something like Rust enum
  ```
  enum Quantize{
    Bitandbytes(Bitsandbytes),
    GPTQ(bits: usize, groupsize: usize)
  ```
  And passing that around during load. Unfortunately we don't
  have access to this, so for now, side-stepping seems easier.

- Re-enabling groupsize<0 with exllama (confirmed it works.)

Helps #601

In next steps we should make sure our quantization script uses that
format and make it standard.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs(README): update readme

fix(server): fix quantization python requirements (#708)

fix(server): fix missing datasets in quantize

feat(server): support new falcon config (#712)

v0.9.4 (#713)

Add section about TGI on other AI hardware accelerators in README (#715)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

As per title.

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs: Add hardware section to TOC in README (#721)

feat(server): update vllm version (#723)

chore: update license to HFOIL (#725)

v1.0.0 (#727)

Local gptq support. (#738)

Redoes #719

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

Fix typing in `Model.generate_token` (#733)

This PR fixes a minor type annotation issue in the signature of
`Model.generate_token`.

All existing overrides of `Model.generate_token` return
`Tuple[List[Generation], Optional[B]]`:

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/causal_lm.py#L535-L537

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/flash_causal_lm.py#L802-L804

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/seq2seq_lm.py#L589-L591

I suspect that back in 017a2a8c when `GeneratedText` and `Generation`
were separated, the function signature was not updated.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

CC @OlivierDehaene

Adding Rope scaling. (#741)

- Adds Rope NTK scaling.

Done because
https://github.com/huggingface/text-generation-inference/pull/529 was
closed
Took some code from
https://github.com/huggingface/transformers/pull/24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).

Fixes #512

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

chore: fix typo in mpt_modeling.py (#737)

Fixed typo.
<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

implemetation -> implementation

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update…
tjluyao added a commit to mlsys-io/kv.run that referenced this pull request Jul 7, 2024
Init

fix: cleanup

Add load testing

Refactored gRPC interface
Added validation logic

ValidationError was not correctly handled

Use axum

feat: Docker image

feat: Add AML deployment

Update aml deployment

feat: Improve error handling

feat: Add arguments to CLI

v0.1.0

fix(validation): Fix error messages

feat(router): Add max_waiting_tokens

Create LICENSE (#2)

feat(server): Use safetensors

Co-authored-by: OlivierDehaene <[email protected]>

feat(client): Simplify sharded logic

feat(server): Support bitsandbytes

feat(server): Support all AutoModelForCausalLM on a best effort basis

feat: Use json formatter by default in docker image

fix(models): Revert buggy support for AutoModel

feat(server): Support generic AutoModelForCausalLM

feat(server): Support AutoModelForSeq2SeqLM

feat(launcher): Pass CUDA_VISIBLE_DEVICES to the shard

feat(server): Improved doc

fix(server): Fix Transformers fork version

feat(server): Clarify CausalLMBatch concatenate method

feat(rust): Update to 1.65

fix(router): Fix HTTP status codes

fix(readme): Typo

fix(router): Handle tokenizer errors

feat(server): Support Galactica (#4)

fix(batching): Avoid theoretical hang in batcher loop (#5)

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the router generate method
- Keep attention mask tensors as integers
- Remove num_heads attribute

Co-authored-by: OlivierDehaene <[email protected]>

feat(server): Add model tests (#6)

fix(server): Only pad to multiple of 8 on GPUs

feat: Support stop sequences (#7)

feat: Return logprobs (#8)

feat(launcher): Add integration tests (#9)

fix(server): Fix stop sequences (#11)

fix(server): Check for device type correctly when determining initial padding (#16)

AFAIK there is no torch device type called "gpu".

fix(router): Include special tokens when tokenizing (#14)

There's currently a discrepancy in the tokenization between the router
and python server code. The latter includes special tokens but former
does not.

This results in a token count mismatch for seq2seq models such as mt0
where the tokenizer emits an EOS token at the end.

This in turn results in some unexpected/incorrect output, in particular
when batch concatenation is involved, because the python code uses the
input length passed from the router for each row.

As far as I can tell, it is better to include this token in the encoder
`input_ids`, so I guess it's best to just adjust on the router side.

feat(router): Add const parameters to validation logic  (#15)

I noticed some opportunity to collapse some of the logic, in case you
are interested.

fix(server): Use cleanup_tokenization_spaces=False for lossless decoding (#13)

Fixes #12 in the easiest way I could think of.

feat(launcher): Log server stdout (#19)

Co-authored-by: Nick Hill <[email protected]>

fix(server): Minor refactorization using new_zeros (#24)

- Fix some type hints, in particular base tokenizer class
- Make use of `tensor.new_zero/empty` methods
- Simplify env var string parsing in launcher

fix(router): Obey max batch size (#23)

feat(server): Support SantaCoder (#26)

fix(server): Fix position ids (#28)

feat(docker): Make the image compatible with api-inference (#29)

fix(docker): fix api-inference deployment (#30)

fix(router): fix api-inference deployment (#31)

fix(dockerfile): fix docker build (#32)

feat(bloom): use torch.nn.Linear and torch.nn.GELU (#33)

feat(router): Remove second lock from batcher hot path (#27)

@njhill

feat: Support sampling seeding (#37)

Co-authored-by: Yannic Kilcher <[email protected]>

feat: Add token streaming using ServerSideEvents support (#36)

Add token streaming using ServerSideEvents (SSE).

The signature of the SSE events is:

```rust
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
}

struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
    details: Option<Details>,
}

struct ErrorResponse {
    error: String,
}
```

Revert "feat: Add token streaming using ServerSideEvents support" (#40)

Reverts huggingface/text-generation-inference#36

fix(server): fix seeding on gpu (#42)

fix(server): fix seeding with multiple shards (#44)

feat: Add token streaming using ServerSideEvents support (#41)

fix(server): fix quantization for sharded models (#45)

feat(server): Support GPT-Neox (#39)

feat(ci): Docker build and push (#46)

feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)

feat(server): support repetition penalty (#47)

feat(server): allow the server to use a local weight cache (#49)

fix(server): allow greedy repetition penalty (#51)

feat(router): use background task to manage request queue (#52)

Co-authored-by: Nick Hill <[email protected]>

breaking(router): modify /generate API to only return generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reason. We
want to remove this behaviour as we don't think it is useful and even
detrimonial to usability.

We also remove the unused Vec.

feat(router): refactor API and add openAPI schemas (#53)

feat(docs): Clarify installation steps (#54)

Adds some bits for first-time users (like me 😄 )

feat(ci): push to AML registry (#56)

fix(server): better handling of inference mode (#57)

V0.2.1 (#58)

feat(server): support t5 (#59)

fix(docker): increase shm size (#60)

fixed SSE naming (#61)

https://en.wikipedia.org/wiki/Server-sent_events

feat: add distributed tracing (#62)

feat: add safetensors conversion (#63)

feat(server): improve download logging (#66)

feat(launcher): add disable_custom_kernels arg (#67)

feat(router): add max_total_tokens and empty_input validation (#68)

closes #65

fix(launcher): copy current env vars to subprocesses (#70)

closes #69

feat(router): add prometheus metrics scrape endpoint (#71)

v0.3.0 (#72)

feat(router): add cors allow origin options (#73)

feat(server): enable hf-transfer (#76)

fix(server): remove position_ids from galactica forward (#82)

closes #80

feat(server): pre-allocate max attention mask (#75)

v0.3.1 (#84)

feat(server): add special token bool (#85)

fix(docs): fix openapi schema (#86)

fix(server): fix token_is_special (#87)

feat(router): add legacy route for api-inference support (#88)

feat(router): ask hf.co for pipelinetag to decide on compat_return_full_text (#89)

feat(router): add api-inference headers (#91)

feat(server): add logits watermark (#90)

feat(server): update to hf_transfer==0.1.2 (#93)

feat(ci): improve CI speed (#94)

fix(launcher): add router parameters to launcher (#95)

feat(server): fix transformers commit (#96)

v0.3.2 (#97)

fix(server): fix generate_stream by forcing tokens to be decoded correctly (#100)

feat: allow local models (#101)

closes #99

feat: add supported models (#102)

feat(clients): Python client (#103)

fix(server): fix galactica batch (#106)

closes #105

feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107)

feat(launcher): default num_shard to CUDA_VISIBLE_DEVICES if possible (#108)

fix(python-client): stream not set on the sync client (#109)

fix(server): fix index out of range for watermarking (#110)

feat: support typical sampling (#114)

closes #112

fix(server): do not warp prefill logits (#116)

feat(router): support left truncation (#115)

closes #111

feat(router): add best_of parameter (#117)

feat(python-client): add new parameters (#118)

v0.4.0 (#119)

feat: add OpenAssistant/oasst-sft-1-pythia-12b to the list of supported models (#122)

…ed models

fix(server): revert gpt-neox optims (#123)

fix(server): add position ids to neox (#126)

fix(server): use server tokenizer as gt (#128)

fix(python-client): relax dependencies (#129)

feat(python-client): add cookies to Client constructors and requests (#132)

I have a use case where we need to pass cookies (for auth reasons) to an
internally hosted server.

Note: I couldn't get the client tests to pass - do you need to have an
HF token?

```python
FAILED tests/test_client.py::test_generate - text_generation.errors.BadRequestError: Authorization header is correct, but the token seems invalid
```

feat(ci): add ci paths (#134)

feat: Add note about NVIDIA drivers (#64)

Co-authored-by: OlivierDehaene <[email protected]>

feat(python-client): release v0.4.0 (#135)

feat(python-client): add CI (#136)

feat(server): flash neoX (#133)

fix(server): fix flash-neox scores warping (#137)

feat(server): cleanup flash neox loading (#139)

v0.4.1 (#140)

fix(server): Avoid using try/except to determine kind of AutoModel (#142)

feat(server): Add mypy-protobuf (#141)

Generates .pyi files for protobuf stubs which provide strong typing
information. Very helpful for IDE auto-completion, etc.

feat(server): clear cache on error (#143)

feat(server): reduce mlp and attn in one op for flash neox (#145)

feat: aws sagemaker compatible image (#147)

The only difference is that now it pushes to
registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker:...
instead of
registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sagemaker-...

---------

Co-authored-by: Philipp Schmid <[email protected]>

fix(ci): fix sagemaker action (#148)

feat(benchmark): tui based benchmarking tool (#149)

fix(server): fix flash neox rotary embeddings (#150)

v0.4.2 (#151)

v0.4.3 (#152)

feat(server): flash santacoder (#153)

docs(readme): provide link Logits Warper README (#154)

fix(server): fix escape characters in stop sequence (#155)

feat(docker): improve flash_attention caching (#160)

feat(launcher): allow disabling hf_transfer (#161)

fix(rust-client): use join_all instead of select_all to hopefully fix nccl issues (#162)

fix(router): use buckets for metrics histograms (#163)

feat(router): make router input validation optional (#164)

feat(server): add flash attention llama (#144)

feat(server): support OPT models (#55)

OPT models do not all have a `tokenizer.json` file on the hub at the
moment. Can't merge for now.

v0.5.0 (#168)

feat(server): optimize decode for sane tokenizers (#170)

feat(server): support sharded santacoder (#167)

fix(launcher): revert change on shard errors (#173)

fix(ci): fix CVE in github-slug-action (#174)

feat(ci): add image signing with cosign (#175)

feat(ci): add Trivy and scan docker image (#178)

feat(ci): use large runners (#179)

feat(ci): faster scanning (#180)

fix(ci): fix ci permissions (#181)

fea(dockerfile): better layer caching (#159)

fix(ci): fix cosign error (#183)

fix(docker): fix docker image (#184)

fix(docker): fix image (#185)

fix(docker): revert dockerfile changes (#186)

fix(docker): fix docker image dependencies (#187)

fix(router): fix truncation (#190)

closes #189

feat(python-client): get list of currently deployed tgi models using the inference API (#191)

feat(router): add info route (#196)

close #125

feat(server): support quantization for flash models (#200)

closes #197

feat(server): check cuda capability when importing flash models (#201)

close #198

fix(server): fix hf_transfer issue with private repos (#203)

fix(docker): remove unused dependencies (#205)

fix(router): add auth token to get model info (#207)

feat(router): add git sha to info route (#208)

feat(router): drop requests when client closes the channel (#202)

fix(ci): fix sha in docker image (#212)

feat(server): flash attention past key value optimizations (#213)

feat(router): add device and dtype info (#215)

fix(server): fix past key values logic (#216)

@njhill fyi

fix(server): cleanup new flash past_key_values logic (#217)

fix(server): fix flash causal (#218)

fix(server): fix flash causal (#219)

fix(server): fix flash batch filtering (#220)

misc: update to rust 1.69 (#221)

v0.6.0 (#222)

feat(server): reduce memory requirement (#214)

chore(server): update huggingface-hub (#227)

feat(router): use number of tokens in batch as input for dynamic batching (#226)

Co-authored-by: Nick Hill <[email protected]>

feat(router): add endpoint info to /info route (#228)

chore(server): update safetensors version (#235)

fix(python-client): add auth headers to is supported requests (#234)

Starting some routing tests. (#233)

fix(benchmarking): fix benchmarking tool

chore(launcher): refactor logic (#242)

Hopefully it's cleaner

feat(router): add tests to validation (#237)

feat(router): new healthcheck that skips the queue (#244)

Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

fix(server): fix reshaping of bloom past_key_values in concatenate() (#252)

Introduced in #214

Fixes #249

fix(server): Small tidy of code from recent changes (#251)

remaining_decode_tokens was calculated twice in Seq2SeqLMBatch.filter()

chore(server): update transformers (#250)

feat(server): add watermarking tests (#248)

feat(docker): add nvidia env vars (#255)

doc(launcher): add more docs to the `launcher` itself and link in the README (#257)

feat(benchmark): add support for private tokenizers (#262)

Adding docs on how dynamic batching works. (#258)

This PR starts the minimal possible amount of explanation I could think
of. It tries to explain how dynamic batching occurs, the interactions
with past key values and ignores the padding problem.

Maybe some drawings could help too but I kept it to text for now.

chore(github): add templates (#264)

fix(server): fix typo in tokenizers decode (#269)

closes #268

feat(server): support hf endpoint weight layout (#266)

fix(launcher): pass weights cache override to the download process (#274)

closes #273

fix(launcher): handle hub branches (#278)

fix(server): Removes the parallelism in file convertion (during download) (#275)

feat(launcher): Improve error message when download process fails. (#276)

fix(server): fix convert (#284)

chore: add `flash-attention` to docker ignore (#287)

included when building docker locally.
(Where the local dirs might have the flash-attention folder.)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

fea(server): decrease convert RAM requirements (#286)

fix(dockerfile): fix nvidia env vars (#297)

Fixes #291

feat(router): Adding response schema for compat_generate (#292)

feat(docker): add benchmarking tool to docker image (#298)

fix(docker): fix docker build (#299)

feat(server): optim flash causal lm decode_token (#285)

fix(docker): fix nvidia env vars (#305)

fix(docker): remove nvidia require cuda env (#310)

feat(server): shard token decode (#303)

feat(server): use float16 (#304)

fix(docker): remove CUDA_VERSION

feat(server): use cuda graph in logits warping (#302)

fix(server): fix multinomial implem in Sampling

feat(server): GPTQ quantization (step1) (#277)

Changes only the type from `bool` to `Option<Enum>` pretty much
everywhere.
- Use `Optional[str]` in Python (easier to manage than importing type
everywhere). Except for the cli to get proper validation
- Updated all models to handle gracefully new values. (Error out if
unknown value, or gptq since not implemented).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

chore(docker): use nvidia base image (#318)

fix(docker): remove quantize default

fix(docker): use ubuntu20.04

Hotfixes for santacoder/bigcode. (#294)

Hotfixes:

- Uses `model_type`=`gpt_bigcode` for more general usage.
- Hotfixes linked lm_head vs wte_embedding (safetensors file do not
contain the key, correctly when the file is sharded, where as pytorch
copies the tensor)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

Lifting check_unitialized. (#325)

Lifting check_unitialized.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

Removing dead variables. (#327)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat(ci): custom gpu runners (#328)

Single place for TP layers + Dropout Layer Norm + FastLinear (#329)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat: add snapshot testing (#282)

feat(integration-tests): improve comparison and health checks (#336)

fix(server): fix decode token (#334)

Fixes #333

---------

Co-authored-by: Nicolas Patry <[email protected]>

fix: set MODEL_ID in sagemaker-entrypoint script (#343)

feat(server): Support BLOOMChat-176B (#348) (#351)

@njhill,
temporary workaround to be able to run our CI as secrets are not
available to runners run by external contributors. I will ask around to
see if there is a better way.

Co-authored-by: Nick Hill <[email protected]>

fix(server): fix init for flash causal lm (#352)

Fixes #347

fix(server): t5 cannot run in f16 (#356)

Fix #349

fix(ci): fix security group (#359)

Switch security group used for ci
(open outbound rules)

Signed-off-by: Raphael <[email protected]>
Co-authored-by: Raphael <[email protected]>

feat: add nightly load testing (#358)

chore(sever): update requirements (#357)

Fixes #338

feat(server): support fp16 for t5 (#360)

Fixes #349

feat(server): do not use device_map auto on single GPU (#362)

feat(server): support trust_remote_code (#363)

feat(router): log input/ouput at debug level (#364)

@njhill FYI

v0.7.0 (#353)

feat: decrease IPC proto size (#367)

Closes #307 #308

feat(benchmarker): add summary tables (#368)

feat(server): support vectorized warpers in flash causal lm (#317)

Co-authored-by: Joel Lamy-Poirier <[email protected]>

Fix issue when load AutoModelForSeq2SeqLM model (#370)

fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

fix(server): fix quantization

feat(server): support RefinedWeb models (#379)

v0.8.0

increase health checks

feat(server): add retry on download (#384)

fix(server): fix bnb quantization for CausalLM models (#385)

v0.8.1

fix(server): fix has_position_ids (#395)

Fix #389

feat(server): remove trust_remote_code requirement for falcon models (#396)

feat(server): load santacoder/starcoder models with safetensors (#393)

Fix #366

v0.8.2

feat(sagemaker): add trust remote code to entrypoint (#394)

feat(launcher): parse oom signal (#404)

feat(server): only compute prefill logprobs when asked (#406)

Close #288

feat(server): batch tokenization for flash causal lm (#411)

chore: update openapi schema

feat(server): Rework model loading (#344)

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

feat(server): optimize dist ops (#434)

docs(launcher): fix CUDA_VISIBLE_DEVICES helper comment (#441)

It solves a typo in the comment sections referencing the environment
variable `CUDA_VISIBLE_DEVICES`. No misspelling references to this
variable have been found in code logic leading to undefined behaviour or
bugs. This PR is not expected to perform any code logic modification.

fix(makefile): Fix typo and use POSIX comparison in the makefile (#443)

This PR fixes:
- The usage of non posix comparison which may fail depending on the
shell used (`=` will always work, `==` only with bash)
- Typo in the env variable name displayed in the error message
`BUILD_EXTENSION` instead of `BUILD_EXTENSIONS`

<!-- Remove if not applicable -->

Fixes #422

feat(server): pre-allocate past key values for flash causal LM (#412)

feat(router): add ngrok integration (#453)

feat(server): improve flash attention import errors (#465)

@lewtun, is this enough?

Closes #458
Closes #456

fix(server): fix warpers on CPU (#472)

Closes #471

fix(server): Fixing T5 in case the names are mixed up. (#475)

feat(server): Update convert logic. (#483)

Should be more robust to shared tensors (ok when using
      `from_pretrained). But forcing us to add new checks in our loading
      code (since the chosen key to keep might be different from
      `transformers`).

---------

Co-authored-by: Ubuntu <[email protected]>

feat(server): Adding new ignore_rule for conversion. (#485)

fix(router): add timeout on flume sends (#488)

feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)

Let's start discussing implementation.

- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).

Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.

My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

fix(server): Do not init process group if already initialized (#388)

feat(router): add header option to disable buffering for the generate_stream response (#498)

generate_stream endpoint response stream.

Problem: If a model is run behind a proxy server such as nginx that has
buffering enabled then the response stream from generate_stream gets
aggregated into a single response which basically disables streaming.
Instead of getting a chunked response where each token is presented over
time the response presents everything all at once.

Solution: This change adds the `X-Accel-Buffering` http header which
disables buffering for the generate_stream response, allowing the
response to stream properly.

feat(server): add paged attention to flash models (#516)

Closes #478

feat(router): arg validation (#519)

feat: Add the option to force another dtype than `f16`. (#513)

fix(launcher): fix issue where launcher does not properly report shard failures (#522)

v0.9.0 (#525)

feat(server): Add Non flash MPT. (#514)

This adds a non flash version of MPT.
Flash is harder because we need to create a bias ready cuda kernel of
flash attention.

Fixes
https://github.com/huggingface/text-generation-inference/issues/361
Fixes
https://github.com/huggingface/text-generation-inference/issues/491
Fixes
https://github.com/huggingface/text-generation-inference/issues/290

fix: Update server/Makefile to include Makefile-vllm (#520)

For consistency and ease of use (you can just run `make` to install vllm
without any extra steps).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs(benchmarker): Adding some help for the options in `text-generation-benchmark`. (#462)

fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, an exception will be thrown by `hf_hub_download` function if
`model_id` is a local path.

fix(server): avoid errors for very small top_p values (#544)

See https://github.com/huggingface/transformers/pull/24111

I didn't add validation to the `__init__` method since it's not done for
other values/warpers.

feat(server): use latest flash attention commit (#543)

@njhill FYI

feat(router): add argument for hostname in router (#545) (#550)

In title. Adds argument `--hostname` in router to support something like
`--hostname ::`. Tested with

```commandline
cargo run -- --port 8080 --hostname ::
curl -I -X GET 'http://[::1]:8080/health'  # failed before this commit
```

Trigger CI

---------

Co-authored-by: Phil Chen <[email protected]>

fix(server): decrease memory fragmentation (#557)

v0.9.1 (#558)

fix(server): harden the weights choice to save on disk. (#561)

- Look at `transformers` base class to check for
  `_key_to_ignore_on_load_missing` or `_tied_weights` which are the
  standard attributes to select the keys to NOT save on disk (since they
  are ignored)

- Modified safetensors code (to be reflected in safetensors even if it's
  an internal function).

- Will not work for trust_remote_code=True repos (like santacoder).

Should help with :
https://github.com/huggingface/text-generation-inference/issues/555
and : https://github.com/huggingface/text-generation-inference/pull/501
and https://github.com/huggingface/text-generation-inference/issues/556
and
https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593

feat: better errors for warmup and TP (#575)

Close #571

fix(server): Fixing RW code (it's remote code so the Arch checking doesn't work to see which weights to keep). (#579)

Fixes #555

feat(server): Support for env value for GPTQ_BITS and GPTQ_GROUPSIZE. (#580)

Some models are already converted, and do not have those values in the
file, this enables users to use them with less friction.

Went for pure env based because adding flags would end up (imo) very
tedious to maintain. There's a lot of sanitation to do: those flags
would be errors if not used in conjuction with `--quantize gptq`.
Then the flags need to exist in the launcher and the server passing them
all throughout all function calls.

This PR is intended as an easy escape hatch, not the defacto method to
use gptq in TGI.

Fixes #500

chore: migrate ci region for more availability. (#581)

fix(server): T5 weights names. (#582)

Fixes #541

fix(server): Adding logger import to t5_modeling.py (#585)

Logger is referenced during the apex importing but is not imported,
causing a NameError

fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590)

This fixes a typo and extends the GPTP_BITS environment variables
through to the second method which requires the same logic. Please let
me know if there's anything I've misunderstood in this change.

Thanks @Narsil for the original fix.

feat(server): Implements sharding for non divisible `vocab_size`. (#583)

- The code is relatively easy (just disable the checks on Embedding and
Head)

This cannot be done in the same easy fashion for hidden_dim/head_dim.
It's relatively easy on some models (classic MHA) but it would make the
other
models (MQA) much more complex, and GPTQ quantization another quite
hairy piece
of code.

feat(server): empty cache on errors

GPTQ Env vars: catch correct type of error (#596)

When passing in environment variables like gptq_bits, we still get
errors thrown from TGI because the try/catch block is catching the wrong
type of error. This PR aims to fix that.

@Narsil - let me know if this is how you want this formatted. My Python
is a little shaky, so I hope this syntax is correct.

feat(launcher): add arg validation and drop subprocess (#595)

feat(router): explicit warning if revision is not set (#608)

docs: README: Add logo + baseline (#611)

![image](https://github.com/huggingface/text-generation-inference/assets/3841370/58177321-479f-4ad1-b3bc-cec027423984)

fix(server): blacklist local files (#609)

Close #589 #602

v0.9.2 (#616)

fix(server): empty_cache when stopped

fix(launcher): Rename `b-float16` to `bfloat16` in the launcher arg (#621)

fea(launcher): debug logs (#623)

feat(server): Reworking the quantization script so it's still universal (not llama specific) (#587)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Reworking the quantization script so it's still universal (not llama
specific)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Still need to investigate the potential differences in quantization
results.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat(server): flash attention v2 (#624)

feat(server): add support for llamav2 (#633)

v0.9.3 (#634)

fix(server): fix llamav2 config (#635)

feat(server): auto max_batch_total_tokens for flash att models (#630)

feat(router): ngrok edge (#642)

docs: Update README.md (#639)

docs: Update README.md (#643)

Add trust_remote_code to quantize script (#647)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes a bug appeared with MR #587 fixing issue #552.
See the discussion in #552.

With MR #587 the trust_remote_code variable is not passed to
AutoModelForCausalLM, but is found in the function signature. This
prevents models like falcon to be quantized, because trust_remote_code
is required. This MR fixes the issue.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [X] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
@Narsil
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

 -->

fix(server): llama v2 GPTQ (#648)

As per title & reported
https://github.com/huggingface/text-generation-inference/issues/601#issuecomment-1641435956
https://huggingface.co/TheBloke/Llama-2-70B-chat-GPTQ/discussions/5

Test it:

```
GPTQ_BITS=4 GPTQ_GROUPSIZE=1 text-generation-launcher --model-id TheBloke/Llama-2-70B-chat-GPTQ --port 8080 --num-shard 4 --quantize gptq
```
&
```
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"hey llama","parameters":{"max_new_tokens":256}}' \
    -H 'Content-Type: application/json'
```

fix(server): Fixing non parameters in quantize script `bigcode/starcoder` was an example. (#661)

fix(server): use mem_get_info to get kv cache size (#664)

Close
https://github.com/huggingface/text-generation-inference/issues/649
Close
https://github.com/huggingface/text-generation-inference/issues/651
Close
https://github.com/huggingface/text-generation-inference/issues/653
Close #636

feat(server): Add exllama GPTQ CUDA kernel support #553 (#666)

Just trying to get the integration tests to pass.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Felix Marty <[email protected]>

Directly load GPTBigCode to specified device (#618)

This PR directly load GPTBigCode to specified device, avoiding moving
model between devices.

This PR directly load GPTBigCode to specified device, avoiding moving
model between devices.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil

feat(server): add local prom and health routes if running w/ ngrok

feat: add cuda memory fraction (#659)

Close #673

fix(server): fix exllama buffers (#689)

Close #683

feat(server): Using `quantize_config.json` instead of GPTQ_BITS env variables. (#671)

- Current PR is not great because we're side stepping the
  `Weights.__init__` but Weights shouldn't requires anything related
  to the config or the model_id as it aims to be a simple Wrapper
  over multi file loading.
- Ideal solution would be to use something like Rust enum
  ```
  enum Quantize{
    Bitandbytes(Bitsandbytes),
    GPTQ(bits: usize, groupsize: usize)
  ```
  And passing that around during load. Unfortunately we don't
  have access to this, so for now, side-stepping seems easier.

- Re-enabling groupsize<0 with exllama (confirmed it works.)

Helps #601

In next steps we should make sure our quantization script uses that
format and make it standard.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs(README): update readme

fix(server): fix quantization python requirements (#708)

fix(server): fix missing datasets in quantize

feat(server): support new falcon config (#712)

v0.9.4 (#713)

Add section about TGI on other AI hardware accelerators in README (#715)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

As per title.

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs: Add hardware section to TOC in README (#721)

feat(server): update vllm version (#723)

chore: update license to HFOIL (#725)

v1.0.0 (#727)

Local gptq support. (#738)

Redoes #719

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

Fix typing in `Model.generate_token` (#733)

This PR fixes a minor type annotation issue in the signature of
`Model.generate_token`.

All existing overrides of `Model.generate_token` return
`Tuple[List[Generation], Optional[B]]`:

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/causal_lm.py#L535-L537

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/flash_causal_lm.py#L802-L804

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/seq2seq_lm.py#L589-L591

I suspect that back in 017a2a8c when `GeneratedText` and `Generation`
were separated, the function signature was not updated.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

CC @OlivierDehaene

Adding Rope scaling. (#741)

- Adds Rope NTK scaling.

Done because
https://github.com/huggingface/text-generation-inference/pull/529 was
closed
Took some code from
https://github.com/huggingface/transformers/pull/24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).

Fixes #512

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

chore: fix typo in mpt_modeling.py (#737)

Fixed typo.
<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

implemetation -> implementation

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update…
tjluyao added a commit to mlsys-io/kv.run that referenced this pull request Jul 7, 2024
Init

fix: cleanup

Add load testing

Refactored gRPC interface
Added validation logic

ValidationError was not correctly handled

Use axum

feat: Docker image

feat: Add AML deployment

Update aml deployment

feat: Improve error handling

feat: Add arguments to CLI

v0.1.0

fix(validation): Fix error messages

feat(router): Add max_waiting_tokens

Create LICENSE (#2)

feat(server): Use safetensors

Co-authored-by: OlivierDehaene <[email protected]>

feat(client): Simplify sharded logic

feat(server): Support bitsandbytes

feat(server): Support all AutoModelForCausalLM on a best effort basis

feat: Use json formatter by default in docker image

fix(models): Revert buggy support for AutoModel

feat(server): Support generic AutoModelForCausalLM

feat(server): Support AutoModelForSeq2SeqLM

feat(launcher): Pass CUDA_VISIBLE_DEVICES to the shard

feat(server): Improved doc

fix(server): Fix Transformers fork version

feat(server): Clarify CausalLMBatch concatenate method

feat(rust): Update to 1.65

fix(router): Fix HTTP status codes

fix(readme): Typo

fix(router): Handle tokenizer errors

feat(server): Support Galactica (#4)

fix(batching): Avoid theoretical hang in batcher loop (#5)

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the router generate method
- Keep attention mask tensors as integers
- Remove num_heads attribute

Co-authored-by: OlivierDehaene <[email protected]>

feat(server): Add model tests (#6)

fix(server): Only pad to multiple of 8 on GPUs

feat: Support stop sequences (#7)

feat: Return logprobs (#8)

feat(launcher): Add integration tests (#9)

fix(server): Fix stop sequences (#11)

fix(server): Check for device type correctly when determining initial padding (#16)

AFAIK there is no torch device type called "gpu".

fix(router): Include special tokens when tokenizing (#14)

There's currently a discrepancy in the tokenization between the router
and python server code. The latter includes special tokens but former
does not.

This results in a token count mismatch for seq2seq models such as mt0
where the tokenizer emits an EOS token at the end.

This in turn results in some unexpected/incorrect output, in particular
when batch concatenation is involved, because the python code uses the
input length passed from the router for each row.

As far as I can tell, it is better to include this token in the encoder
`input_ids`, so I guess it's best to just adjust on the router side.

feat(router): Add const parameters to validation logic  (#15)

I noticed some opportunity to collapse some of the logic, in case you
are interested.

fix(server): Use cleanup_tokenization_spaces=False for lossless decoding (#13)

Fixes #12 in the easiest way I could think of.

feat(launcher): Log server stdout (#19)

Co-authored-by: Nick Hill <[email protected]>

fix(server): Minor refactorization using new_zeros (#24)

- Fix some type hints, in particular base tokenizer class
- Make use of `tensor.new_zero/empty` methods
- Simplify env var string parsing in launcher

fix(router): Obey max batch size (#23)

feat(server): Support SantaCoder (#26)

fix(server): Fix position ids (#28)

feat(docker): Make the image compatible with api-inference (#29)

fix(docker): fix api-inference deployment (#30)

fix(router): fix api-inference deployment (#31)

fix(dockerfile): fix docker build (#32)

feat(bloom): use torch.nn.Linear and torch.nn.GELU (#33)

feat(router): Remove second lock from batcher hot path (#27)

@njhill

feat: Support sampling seeding (#37)

Co-authored-by: Yannic Kilcher <[email protected]>

feat: Add token streaming using ServerSideEvents support (#36)

Add token streaming using ServerSideEvents (SSE).

The signature of the SSE events is:

```rust
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
}

struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
    details: Option<Details>,
}

struct ErrorResponse {
    error: String,
}
```

Revert "feat: Add token streaming using ServerSideEvents support" (#40)

Reverts huggingface/text-generation-inference#36

fix(server): fix seeding on gpu (#42)

fix(server): fix seeding with multiple shards (#44)

feat: Add token streaming using ServerSideEvents support (#41)

fix(server): fix quantization for sharded models (#45)

feat(server): Support GPT-Neox (#39)

feat(ci): Docker build and push (#46)

feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)

feat(server): support repetition penalty (#47)

feat(server): allow the server to use a local weight cache (#49)

fix(server): allow greedy repetition penalty (#51)

feat(router): use background task to manage request queue (#52)

Co-authored-by: Nick Hill <[email protected]>

breaking(router): modify /generate API to only return generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reason. We
want to remove this behaviour as we don't think it is useful and even
detrimonial to usability.

We also remove the unused Vec.

feat(router): refactor API and add openAPI schemas (#53)

feat(docs): Clarify installation steps (#54)

Adds some bits for first-time users (like me 😄 )

feat(ci): push to AML registry (#56)

fix(server): better handling of inference mode (#57)

V0.2.1 (#58)

feat(server): support t5 (#59)

fix(docker): increase shm size (#60)

fixed SSE naming (#61)

https://en.wikipedia.org/wiki/Server-sent_events

feat: add distributed tracing (#62)

feat: add safetensors conversion (#63)

feat(server): improve download logging (#66)

feat(launcher): add disable_custom_kernels arg (#67)

feat(router): add max_total_tokens and empty_input validation (#68)

closes #65

fix(launcher): copy current env vars to subprocesses (#70)

closes #69

feat(router): add prometheus metrics scrape endpoint (#71)

v0.3.0 (#72)

feat(router): add cors allow origin options (#73)

feat(server): enable hf-transfer (#76)

fix(server): remove position_ids from galactica forward (#82)

closes #80

feat(server): pre-allocate max attention mask (#75)

v0.3.1 (#84)

feat(server): add special token bool (#85)

fix(docs): fix openapi schema (#86)

fix(server): fix token_is_special (#87)

feat(router): add legacy route for api-inference support (#88)

feat(router): ask hf.co for pipelinetag to decide on compat_return_full_text (#89)

feat(router): add api-inference headers (#91)

feat(server): add logits watermark (#90)

feat(server): update to hf_transfer==0.1.2 (#93)

feat(ci): improve CI speed (#94)

fix(launcher): add router parameters to launcher (#95)

feat(server): fix transformers commit (#96)

v0.3.2 (#97)

fix(server): fix generate_stream by forcing tokens to be decoded correctly (#100)

feat: allow local models (#101)

closes #99

feat: add supported models (#102)

feat(clients): Python client (#103)

fix(server): fix galactica batch (#106)

closes #105

feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107)

feat(launcher): default num_shard to CUDA_VISIBLE_DEVICES if possible (#108)

fix(python-client): stream not set on the sync client (#109)

fix(server): fix index out of range for watermarking (#110)

feat: support typical sampling (#114)

closes #112

fix(server): do not warp prefill logits (#116)

feat(router): support left truncation (#115)

closes #111

feat(router): add best_of parameter (#117)

feat(python-client): add new parameters (#118)

v0.4.0 (#119)

feat: add OpenAssistant/oasst-sft-1-pythia-12b to the list of supported models (#122)

…ed models

fix(server): revert gpt-neox optims (#123)

fix(server): add position ids to neox (#126)

fix(server): use server tokenizer as gt (#128)

fix(python-client): relax dependencies (#129)

feat(python-client): add cookies to Client constructors and requests (#132)

I have a use case where we need to pass cookies (for auth reasons) to an
internally hosted server.

Note: I couldn't get the client tests to pass - do you need to have an
HF token?

```python
FAILED tests/test_client.py::test_generate - text_generation.errors.BadRequestError: Authorization header is correct, but the token seems invalid
```

feat(ci): add ci paths (#134)

feat: Add note about NVIDIA drivers (#64)

Co-authored-by: OlivierDehaene <[email protected]>

feat(python-client): release v0.4.0 (#135)

feat(python-client): add CI (#136)

feat(server): flash neoX (#133)

fix(server): fix flash-neox scores warping (#137)

feat(server): cleanup flash neox loading (#139)

v0.4.1 (#140)

fix(server): Avoid using try/except to determine kind of AutoModel (#142)

feat(server): Add mypy-protobuf (#141)

Generates .pyi files for protobuf stubs which provide strong typing
information. Very helpful for IDE auto-completion, etc.

feat(server): clear cache on error (#143)

feat(server): reduce mlp and attn in one op for flash neox (#145)

feat: aws sagemaker compatible image (#147)

The only difference is that now it pushes to
registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker:...
instead of
registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sagemaker-...

---------

Co-authored-by: Philipp Schmid <[email protected]>

fix(ci): fix sagemaker action (#148)

feat(benchmark): tui based benchmarking tool (#149)

fix(server): fix flash neox rotary embeddings (#150)

v0.4.2 (#151)

v0.4.3 (#152)

feat(server): flash santacoder (#153)

docs(readme): provide link Logits Warper README (#154)

fix(server): fix escape characters in stop sequence (#155)

feat(docker): improve flash_attention caching (#160)

feat(launcher): allow disabling hf_transfer (#161)

fix(rust-client): use join_all instead of select_all to hopefully fix nccl issues (#162)

fix(router): use buckets for metrics histograms (#163)

feat(router): make router input validation optional (#164)

feat(server): add flash attention llama (#144)

feat(server): support OPT models (#55)

OPT models do not all have a `tokenizer.json` file on the hub at the
moment. Can't merge for now.

v0.5.0 (#168)

feat(server): optimize decode for sane tokenizers (#170)

feat(server): support sharded santacoder (#167)

fix(launcher): revert change on shard errors (#173)

fix(ci): fix CVE in github-slug-action (#174)

feat(ci): add image signing with cosign (#175)

feat(ci): add Trivy and scan docker image (#178)

feat(ci): use large runners (#179)

feat(ci): faster scanning (#180)

fix(ci): fix ci permissions (#181)

fea(dockerfile): better layer caching (#159)

fix(ci): fix cosign error (#183)

fix(docker): fix docker image (#184)

fix(docker): fix image (#185)

fix(docker): revert dockerfile changes (#186)

fix(docker): fix docker image dependencies (#187)

fix(router): fix truncation (#190)

closes #189

feat(python-client): get list of currently deployed tgi models using the inference API (#191)

feat(router): add info route (#196)

close #125

feat(server): support quantization for flash models (#200)

closes #197

feat(server): check cuda capability when importing flash models (#201)

close #198

fix(server): fix hf_transfer issue with private repos (#203)

fix(docker): remove unused dependencies (#205)

fix(router): add auth token to get model info (#207)

feat(router): add git sha to info route (#208)

feat(router): drop requests when client closes the channel (#202)

fix(ci): fix sha in docker image (#212)

feat(server): flash attention past key value optimizations (#213)

feat(router): add device and dtype info (#215)

fix(server): fix past key values logic (#216)

@njhill fyi

fix(server): cleanup new flash past_key_values logic (#217)

fix(server): fix flash causal (#218)

fix(server): fix flash causal (#219)

fix(server): fix flash batch filtering (#220)

misc: update to rust 1.69 (#221)

v0.6.0 (#222)

feat(server): reduce memory requirement (#214)

chore(server): update huggingface-hub (#227)

feat(router): use number of tokens in batch as input for dynamic batching (#226)

Co-authored-by: Nick Hill <[email protected]>

feat(router): add endpoint info to /info route (#228)

chore(server): update safetensors version (#235)

fix(python-client): add auth headers to is supported requests (#234)

Starting some routing tests. (#233)

fix(benchmarking): fix benchmarking tool

chore(launcher): refactor logic (#242)

Hopefully it's cleaner

feat(router): add tests to validation (#237)

feat(router): new healthcheck that skips the queue (#244)

Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

fix(server): fix reshaping of bloom past_key_values in concatenate() (#252)

Introduced in #214

Fixes #249

fix(server): Small tidy of code from recent changes (#251)

remaining_decode_tokens was calculated twice in Seq2SeqLMBatch.filter()

chore(server): update transformers (#250)

feat(server): add watermarking tests (#248)

feat(docker): add nvidia env vars (#255)

doc(launcher): add more docs to the `launcher` itself and link in the README (#257)

feat(benchmark): add support for private tokenizers (#262)

Adding docs on how dynamic batching works. (#258)

This PR starts the minimal possible amount of explanation I could think
of. It tries to explain how dynamic batching occurs, the interactions
with past key values and ignores the padding problem.

Maybe some drawings could help too but I kept it to text for now.

chore(github): add templates (#264)

fix(server): fix typo in tokenizers decode (#269)

closes #268

feat(server): support hf endpoint weight layout (#266)

fix(launcher): pass weights cache override to the download process (#274)

closes #273

fix(launcher): handle hub branches (#278)

fix(server): Removes the parallelism in file convertion (during download) (#275)

feat(launcher): Improve error message when download process fails. (#276)

fix(server): fix convert (#284)

chore: add `flash-attention` to docker ignore (#287)

included when building docker locally.
(Where the local dirs might have the flash-attention folder.)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

fea(server): decrease convert RAM requirements (#286)

fix(dockerfile): fix nvidia env vars (#297)

Fixes #291

feat(router): Adding response schema for compat_generate (#292)

feat(docker): add benchmarking tool to docker image (#298)

fix(docker): fix docker build (#299)

feat(server): optim flash causal lm decode_token (#285)

fix(docker): fix nvidia env vars (#305)

fix(docker): remove nvidia require cuda env (#310)

feat(server): shard token decode (#303)

feat(server): use float16 (#304)

fix(docker): remove CUDA_VERSION

feat(server): use cuda graph in logits warping (#302)

fix(server): fix multinomial implem in Sampling

feat(server): GPTQ quantization (step1) (#277)

Changes only the type from `bool` to `Option<Enum>` pretty much
everywhere.
- Use `Optional[str]` in Python (easier to manage than importing type
everywhere). Except for the cli to get proper validation
- Updated all models to handle gracefully new values. (Error out if
unknown value, or gptq since not implemented).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

chore(docker): use nvidia base image (#318)

fix(docker): remove quantize default

fix(docker): use ubuntu20.04

Hotfixes for santacoder/bigcode. (#294)

Hotfixes:

- Uses `model_type`=`gpt_bigcode` for more general usage.
- Hotfixes linked lm_head vs wte_embedding (safetensors file do not
contain the key, correctly when the file is sharded, where as pytorch
copies the tensor)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

Lifting check_unitialized. (#325)

Lifting check_unitialized.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

Removing dead variables. (#327)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat(ci): custom gpu runners (#328)

Single place for TP layers + Dropout Layer Norm + FastLinear (#329)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat: add snapshot testing (#282)

feat(integration-tests): improve comparison and health checks (#336)

fix(server): fix decode token (#334)

Fixes #333

---------

Co-authored-by: Nicolas Patry <[email protected]>

fix: set MODEL_ID in sagemaker-entrypoint script (#343)

feat(server): Support BLOOMChat-176B (#348) (#351)

@njhill,
temporary workaround to be able to run our CI as secrets are not
available to runners run by external contributors. I will ask around to
see if there is a better way.

Co-authored-by: Nick Hill <[email protected]>

fix(server): fix init for flash causal lm (#352)

Fixes #347

fix(server): t5 cannot run in f16 (#356)

Fix #349

fix(ci): fix security group (#359)

Switch security group used for ci
(open outbound rules)

Signed-off-by: Raphael <[email protected]>
Co-authored-by: Raphael <[email protected]>

feat: add nightly load testing (#358)

chore(sever): update requirements (#357)

Fixes #338

feat(server): support fp16 for t5 (#360)

Fixes #349

feat(server): do not use device_map auto on single GPU (#362)

feat(server): support trust_remote_code (#363)

feat(router): log input/ouput at debug level (#364)

@njhill FYI

v0.7.0 (#353)

feat: decrease IPC proto size (#367)

Closes #307 #308

feat(benchmarker): add summary tables (#368)

feat(server): support vectorized warpers in flash causal lm (#317)

Co-authored-by: Joel Lamy-Poirier <[email protected]>

Fix issue when load AutoModelForSeq2SeqLM model (#370)

fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

fix(server): fix quantization

feat(server): support RefinedWeb models (#379)

v0.8.0

increase health checks

feat(server): add retry on download (#384)

fix(server): fix bnb quantization for CausalLM models (#385)

v0.8.1

fix(server): fix has_position_ids (#395)

Fix #389

feat(server): remove trust_remote_code requirement for falcon models (#396)

feat(server): load santacoder/starcoder models with safetensors (#393)

Fix #366

v0.8.2

feat(sagemaker): add trust remote code to entrypoint (#394)

feat(launcher): parse oom signal (#404)

feat(server): only compute prefill logprobs when asked (#406)

Close #288

feat(server): batch tokenization for flash causal lm (#411)

chore: update openapi schema

feat(server): Rework model loading (#344)

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

feat(server): optimize dist ops (#434)

docs(launcher): fix CUDA_VISIBLE_DEVICES helper comment (#441)

It solves a typo in the comment sections referencing the environment
variable `CUDA_VISIBLE_DEVICES`. No misspelling references to this
variable have been found in code logic leading to undefined behaviour or
bugs. This PR is not expected to perform any code logic modification.

fix(makefile): Fix typo and use POSIX comparison in the makefile (#443)

This PR fixes:
- The usage of non posix comparison which may fail depending on the
shell used (`=` will always work, `==` only with bash)
- Typo in the env variable name displayed in the error message
`BUILD_EXTENSION` instead of `BUILD_EXTENSIONS`

<!-- Remove if not applicable -->

Fixes #422

feat(server): pre-allocate past key values for flash causal LM (#412)

feat(router): add ngrok integration (#453)

feat(server): improve flash attention import errors (#465)

@lewtun, is this enough?

Closes #458
Closes #456

fix(server): fix warpers on CPU (#472)

Closes #471

fix(server): Fixing T5 in case the names are mixed up. (#475)

feat(server): Update convert logic. (#483)

Should be more robust to shared tensors (ok when using
      `from_pretrained). But forcing us to add new checks in our loading
      code (since the chosen key to keep might be different from
      `transformers`).

---------

Co-authored-by: Ubuntu <[email protected]>

feat(server): Adding new ignore_rule for conversion. (#485)

fix(router): add timeout on flume sends (#488)

feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)

Let's start discussing implementation.

- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).

Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.

My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

fix(server): Do not init process group if already initialized (#388)

feat(router): add header option to disable buffering for the generate_stream response (#498)

generate_stream endpoint response stream.

Problem: If a model is run behind a proxy server such as nginx that has
buffering enabled then the response stream from generate_stream gets
aggregated into a single response which basically disables streaming.
Instead of getting a chunked response where each token is presented over
time the response presents everything all at once.

Solution: This change adds the `X-Accel-Buffering` http header which
disables buffering for the generate_stream response, allowing the
response to stream properly.

feat(server): add paged attention to flash models (#516)

Closes #478

feat(router): arg validation (#519)

feat: Add the option to force another dtype than `f16`. (#513)

fix(launcher): fix issue where launcher does not properly report shard failures (#522)

v0.9.0 (#525)

feat(server): Add Non flash MPT. (#514)

This adds a non flash version of MPT.
Flash is harder because we need to create a bias ready cuda kernel of
flash attention.

Fixes
https://github.com/huggingface/text-generation-inference/issues/361
Fixes
https://github.com/huggingface/text-generation-inference/issues/491
Fixes
https://github.com/huggingface/text-generation-inference/issues/290

fix: Update server/Makefile to include Makefile-vllm (#520)

For consistency and ease of use (you can just run `make` to install vllm
without any extra steps).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs(benchmarker): Adding some help for the options in `text-generation-benchmark`. (#462)

fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, an exception will be thrown by `hf_hub_download` function if
`model_id` is a local path.

fix(server): avoid errors for very small top_p values (#544)

See https://github.com/huggingface/transformers/pull/24111

I didn't add validation to the `__init__` method since it's not done for
other values/warpers.

feat(server): use latest flash attention commit (#543)

@njhill FYI

feat(router): add argument for hostname in router (#545) (#550)

In title. Adds argument `--hostname` in router to support something like
`--hostname ::`. Tested with

```commandline
cargo run -- --port 8080 --hostname ::
curl -I -X GET 'http://[::1]:8080/health'  # failed before this commit
```

Trigger CI

---------

Co-authored-by: Phil Chen <[email protected]>

fix(server): decrease memory fragmentation (#557)

v0.9.1 (#558)

fix(server): harden the weights choice to save on disk. (#561)

- Look at `transformers` base class to check for
  `_key_to_ignore_on_load_missing` or `_tied_weights` which are the
  standard attributes to select the keys to NOT save on disk (since they
  are ignored)

- Modified safetensors code (to be reflected in safetensors even if it's
  an internal function).

- Will not work for trust_remote_code=True repos (like santacoder).

Should help with :
https://github.com/huggingface/text-generation-inference/issues/555
and : https://github.com/huggingface/text-generation-inference/pull/501
and https://github.com/huggingface/text-generation-inference/issues/556
and
https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593

feat: better errors for warmup and TP (#575)

Close #571

fix(server): Fixing RW code (it's remote code so the Arch checking doesn't work to see which weights to keep). (#579)

Fixes #555

feat(server): Support for env value for GPTQ_BITS and GPTQ_GROUPSIZE. (#580)

Some models are already converted, and do not have those values in the
file, this enables users to use them with less friction.

Went for pure env based because adding flags would end up (imo) very
tedious to maintain. There's a lot of sanitation to do: those flags
would be errors if not used in conjuction with `--quantize gptq`.
Then the flags need to exist in the launcher and the server passing them
all throughout all function calls.

This PR is intended as an easy escape hatch, not the defacto method to
use gptq in TGI.

Fixes #500

chore: migrate ci region for more availability. (#581)

fix(server): T5 weights names. (#582)

Fixes #541

fix(server): Adding logger import to t5_modeling.py (#585)

Logger is referenced during the apex importing but is not imported,
causing a NameError

fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590)

This fixes a typo and extends the GPTP_BITS environment variables
through to the second method which requires the same logic. Please let
me know if there's anything I've misunderstood in this change.

Thanks @Narsil for the original fix.

feat(server): Implements sharding for non divisible `vocab_size`. (#583)

- The code is relatively easy (just disable the checks on Embedding and
Head)

This cannot be done in the same easy fashion for hidden_dim/head_dim.
It's relatively easy on some models (classic MHA) but it would make the
other
models (MQA) much more complex, and GPTQ quantization another quite
hairy piece
of code.

feat(server): empty cache on errors

GPTQ Env vars: catch correct type of error (#596)

When passing in environment variables like gptq_bits, we still get
errors thrown from TGI because the try/catch block is catching the wrong
type of error. This PR aims to fix that.

@Narsil - let me know if this is how you want this formatted. My Python
is a little shaky, so I hope this syntax is correct.

feat(launcher): add arg validation and drop subprocess (#595)

feat(router): explicit warning if revision is not set (#608)

docs: README: Add logo + baseline (#611)

![image](https://github.com/huggingface/text-generation-inference/assets/3841370/58177321-479f-4ad1-b3bc-cec027423984)

fix(server): blacklist local files (#609)

Close #589 #602

v0.9.2 (#616)

fix(server): empty_cache when stopped

fix(launcher): Rename `b-float16` to `bfloat16` in the launcher arg (#621)

fea(launcher): debug logs (#623)

feat(server): Reworking the quantization script so it's still universal (not llama specific) (#587)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Reworking the quantization script so it's still universal (not llama
specific)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Still need to investigate the potential differences in quantization
results.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat(server): flash attention v2 (#624)

feat(server): add support for llamav2 (#633)

v0.9.3 (#634)

fix(server): fix llamav2 config (#635)

feat(server): auto max_batch_total_tokens for flash att models (#630)

feat(router): ngrok edge (#642)

docs: Update README.md (#639)

docs: Update README.md (#643)

Add trust_remote_code to quantize script (#647)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes a bug appeared with MR #587 fixing issue #552.
See the discussion in #552.

With MR #587 the trust_remote_code variable is not passed to
AutoModelForCausalLM, but is found in the function signature. This
prevents models like falcon to be quantized, because trust_remote_code
is required. This MR fixes the issue.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [X] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
@Narsil
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

 -->

fix(server): llama v2 GPTQ (#648)

As per title & reported
https://github.com/huggingface/text-generation-inference/issues/601#issuecomment-1641435956
https://huggingface.co/TheBloke/Llama-2-70B-chat-GPTQ/discussions/5

Test it:

```
GPTQ_BITS=4 GPTQ_GROUPSIZE=1 text-generation-launcher --model-id TheBloke/Llama-2-70B-chat-GPTQ --port 8080 --num-shard 4 --quantize gptq
```
&
```
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"hey llama","parameters":{"max_new_tokens":256}}' \
    -H 'Content-Type: application/json'
```

fix(server): Fixing non parameters in quantize script `bigcode/starcoder` was an example. (#661)

fix(server): use mem_get_info to get kv cache size (#664)

Close
https://github.com/huggingface/text-generation-inference/issues/649
Close
https://github.com/huggingface/text-generation-inference/issues/651
Close
https://github.com/huggingface/text-generation-inference/issues/653
Close #636

feat(server): Add exllama GPTQ CUDA kernel support #553 (#666)

Just trying to get the integration tests to pass.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Felix Marty <[email protected]>

Directly load GPTBigCode to specified device (#618)

This PR directly load GPTBigCode to specified device, avoiding moving
model between devices.

This PR directly load GPTBigCode to specified device, avoiding moving
model between devices.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil

feat(server): add local prom and health routes if running w/ ngrok

feat: add cuda memory fraction (#659)

Close #673

fix(server): fix exllama buffers (#689)

Close #683

feat(server): Using `quantize_config.json` instead of GPTQ_BITS env variables. (#671)

- Current PR is not great because we're side stepping the
  `Weights.__init__` but Weights shouldn't requires anything related
  to the config or the model_id as it aims to be a simple Wrapper
  over multi file loading.
- Ideal solution would be to use something like Rust enum
  ```
  enum Quantize{
    Bitandbytes(Bitsandbytes),
    GPTQ(bits: usize, groupsize: usize)
  ```
  And passing that around during load. Unfortunately we don't
  have access to this, so for now, side-stepping seems easier.

- Re-enabling groupsize<0 with exllama (confirmed it works.)

Helps #601

In next steps we should make sure our quantization script uses that
format and make it standard.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs(README): update readme

fix(server): fix quantization python requirements (#708)

fix(server): fix missing datasets in quantize

feat(server): support new falcon config (#712)

v0.9.4 (#713)

Add section about TGI on other AI hardware accelerators in README (#715)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

As per title.

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs: Add hardware section to TOC in README (#721)

feat(server): update vllm version (#723)

chore: update license to HFOIL (#725)

v1.0.0 (#727)

Local gptq support. (#738)

Redoes #719

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

Fix typing in `Model.generate_token` (#733)

This PR fixes a minor type annotation issue in the signature of
`Model.generate_token`.

All existing overrides of `Model.generate_token` return
`Tuple[List[Generation], Optional[B]]`:

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/causal_lm.py#L535-L537

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/flash_causal_lm.py#L802-L804

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/seq2seq_lm.py#L589-L591

I suspect that back in 017a2a8c when `GeneratedText` and `Generation`
were separated, the function signature was not updated.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

CC @OlivierDehaene

Adding Rope scaling. (#741)

- Adds Rope NTK scaling.

Done because
https://github.com/huggingface/text-generation-inference/pull/529 was
closed
Took some code from
https://github.com/huggingface/transformers/pull/24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).

Fixes #512

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

chore: fix typo in mpt_modeling.py (#737)

Fixed typo.
<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

implemetation -> implementation

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update…
tjluyao added a commit to mlsys-io/kv.run that referenced this pull request Jul 7, 2024
Init

fix: cleanup

Add load testing

Refactored gRPC interface
Added validation logic

ValidationError was not correctly handled

Use axum

feat: Docker image

feat: Add AML deployment

Update aml deployment

feat: Improve error handling

feat: Add arguments to CLI

v0.1.0

fix(validation): Fix error messages

feat(router): Add max_waiting_tokens

Create LICENSE (#2)

feat(server): Use safetensors

Co-authored-by: OlivierDehaene <[email protected]>

feat(client): Simplify sharded logic

feat(server): Support bitsandbytes

feat(server): Support all AutoModelForCausalLM on a best effort basis

feat: Use json formatter by default in docker image

fix(models): Revert buggy support for AutoModel

feat(server): Support generic AutoModelForCausalLM

feat(server): Support AutoModelForSeq2SeqLM

feat(launcher): Pass CUDA_VISIBLE_DEVICES to the shard

feat(server): Improved doc

fix(server): Fix Transformers fork version

feat(server): Clarify CausalLMBatch concatenate method

feat(rust): Update to 1.65

fix(router): Fix HTTP status codes

fix(readme): Typo

fix(router): Handle tokenizer errors

feat(server): Support Galactica (#4)

fix(batching): Avoid theoretical hang in batcher loop (#5)

- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the router generate method
- Keep attention mask tensors as integers
- Remove num_heads attribute

Co-authored-by: OlivierDehaene <[email protected]>

feat(server): Add model tests (#6)

fix(server): Only pad to multiple of 8 on GPUs

feat: Support stop sequences (#7)

feat: Return logprobs (#8)

feat(launcher): Add integration tests (#9)

fix(server): Fix stop sequences (#11)

fix(server): Check for device type correctly when determining initial padding (#16)

AFAIK there is no torch device type called "gpu".

fix(router): Include special tokens when tokenizing (#14)

There's currently a discrepancy in the tokenization between the router
and python server code. The latter includes special tokens but former
does not.

This results in a token count mismatch for seq2seq models such as mt0
where the tokenizer emits an EOS token at the end.

This in turn results in some unexpected/incorrect output, in particular
when batch concatenation is involved, because the python code uses the
input length passed from the router for each row.

As far as I can tell, it is better to include this token in the encoder
`input_ids`, so I guess it's best to just adjust on the router side.

feat(router): Add const parameters to validation logic  (#15)

I noticed some opportunity to collapse some of the logic, in case you
are interested.

fix(server): Use cleanup_tokenization_spaces=False for lossless decoding (#13)

Fixes #12 in the easiest way I could think of.

feat(launcher): Log server stdout (#19)

Co-authored-by: Nick Hill <[email protected]>

fix(server): Minor refactorization using new_zeros (#24)

- Fix some type hints, in particular base tokenizer class
- Make use of `tensor.new_zero/empty` methods
- Simplify env var string parsing in launcher

fix(router): Obey max batch size (#23)

feat(server): Support SantaCoder (#26)

fix(server): Fix position ids (#28)

feat(docker): Make the image compatible with api-inference (#29)

fix(docker): fix api-inference deployment (#30)

fix(router): fix api-inference deployment (#31)

fix(dockerfile): fix docker build (#32)

feat(bloom): use torch.nn.Linear and torch.nn.GELU (#33)

feat(router): Remove second lock from batcher hot path (#27)

@njhill

feat: Support sampling seeding (#37)

Co-authored-by: Yannic Kilcher <[email protected]>

feat: Add token streaming using ServerSideEvents support (#36)

Add token streaming using ServerSideEvents (SSE).

The signature of the SSE events is:

```rust
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
}

struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
    details: Option<Details>,
}

struct ErrorResponse {
    error: String,
}
```

Revert "feat: Add token streaming using ServerSideEvents support" (#40)

Reverts huggingface/text-generation-inference#36

fix(server): fix seeding on gpu (#42)

fix(server): fix seeding with multiple shards (#44)

feat: Add token streaming using ServerSideEvents support (#41)

fix(server): fix quantization for sharded models (#45)

feat(server): Support GPT-Neox (#39)

feat(ci): Docker build and push (#46)

feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)

feat(server): support repetition penalty (#47)

feat(server): allow the server to use a local weight cache (#49)

fix(server): allow greedy repetition penalty (#51)

feat(router): use background task to manage request queue (#52)

Co-authored-by: Nick Hill <[email protected]>

breaking(router): modify /generate API to only return generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reason. We
want to remove this behaviour as we don't think it is useful and even
detrimonial to usability.

We also remove the unused Vec.

feat(router): refactor API and add openAPI schemas (#53)

feat(docs): Clarify installation steps (#54)

Adds some bits for first-time users (like me 😄 )

feat(ci): push to AML registry (#56)

fix(server): better handling of inference mode (#57)

V0.2.1 (#58)

feat(server): support t5 (#59)

fix(docker): increase shm size (#60)

fixed SSE naming (#61)

https://en.wikipedia.org/wiki/Server-sent_events

feat: add distributed tracing (#62)

feat: add safetensors conversion (#63)

feat(server): improve download logging (#66)

feat(launcher): add disable_custom_kernels arg (#67)

feat(router): add max_total_tokens and empty_input validation (#68)

closes #65

fix(launcher): copy current env vars to subprocesses (#70)

closes #69

feat(router): add prometheus metrics scrape endpoint (#71)

v0.3.0 (#72)

feat(router): add cors allow origin options (#73)

feat(server): enable hf-transfer (#76)

fix(server): remove position_ids from galactica forward (#82)

closes #80

feat(server): pre-allocate max attention mask (#75)

v0.3.1 (#84)

feat(server): add special token bool (#85)

fix(docs): fix openapi schema (#86)

fix(server): fix token_is_special (#87)

feat(router): add legacy route for api-inference support (#88)

feat(router): ask hf.co for pipelinetag to decide on compat_return_full_text (#89)

feat(router): add api-inference headers (#91)

feat(server): add logits watermark (#90)

feat(server): update to hf_transfer==0.1.2 (#93)

feat(ci): improve CI speed (#94)

fix(launcher): add router parameters to launcher (#95)

feat(server): fix transformers commit (#96)

v0.3.2 (#97)

fix(server): fix generate_stream by forcing tokens to be decoded correctly (#100)

feat: allow local models (#101)

closes #99

feat: add supported models (#102)

feat(clients): Python client (#103)

fix(server): fix galactica batch (#106)

closes #105

feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107)

feat(launcher): default num_shard to CUDA_VISIBLE_DEVICES if possible (#108)

fix(python-client): stream not set on the sync client (#109)

fix(server): fix index out of range for watermarking (#110)

feat: support typical sampling (#114)

closes #112

fix(server): do not warp prefill logits (#116)

feat(router): support left truncation (#115)

closes #111

feat(router): add best_of parameter (#117)

feat(python-client): add new parameters (#118)

v0.4.0 (#119)

feat: add OpenAssistant/oasst-sft-1-pythia-12b to the list of supported models (#122)

…ed models

fix(server): revert gpt-neox optims (#123)

fix(server): add position ids to neox (#126)

fix(server): use server tokenizer as gt (#128)

fix(python-client): relax dependencies (#129)

feat(python-client): add cookies to Client constructors and requests (#132)

I have a use case where we need to pass cookies (for auth reasons) to an
internally hosted server.

Note: I couldn't get the client tests to pass - do you need to have an
HF token?

```python
FAILED tests/test_client.py::test_generate - text_generation.errors.BadRequestError: Authorization header is correct, but the token seems invalid
```

feat(ci): add ci paths (#134)

feat: Add note about NVIDIA drivers (#64)

Co-authored-by: OlivierDehaene <[email protected]>

feat(python-client): release v0.4.0 (#135)

feat(python-client): add CI (#136)

feat(server): flash neoX (#133)

fix(server): fix flash-neox scores warping (#137)

feat(server): cleanup flash neox loading (#139)

v0.4.1 (#140)

fix(server): Avoid using try/except to determine kind of AutoModel (#142)

feat(server): Add mypy-protobuf (#141)

Generates .pyi files for protobuf stubs which provide strong typing
information. Very helpful for IDE auto-completion, etc.

feat(server): clear cache on error (#143)

feat(server): reduce mlp and attn in one op for flash neox (#145)

feat: aws sagemaker compatible image (#147)

The only difference is that now it pushes to
registry.internal.huggingface.tech/api-inference/community/text-generation-inference/sagemaker:...
instead of
registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sagemaker-...

---------

Co-authored-by: Philipp Schmid <[email protected]>

fix(ci): fix sagemaker action (#148)

feat(benchmark): tui based benchmarking tool (#149)

fix(server): fix flash neox rotary embeddings (#150)

v0.4.2 (#151)

v0.4.3 (#152)

feat(server): flash santacoder (#153)

docs(readme): provide link Logits Warper README (#154)

fix(server): fix escape characters in stop sequence (#155)

feat(docker): improve flash_attention caching (#160)

feat(launcher): allow disabling hf_transfer (#161)

fix(rust-client): use join_all instead of select_all to hopefully fix nccl issues (#162)

fix(router): use buckets for metrics histograms (#163)

feat(router): make router input validation optional (#164)

feat(server): add flash attention llama (#144)

feat(server): support OPT models (#55)

OPT models do not all have a `tokenizer.json` file on the hub at the
moment. Can't merge for now.

v0.5.0 (#168)

feat(server): optimize decode for sane tokenizers (#170)

feat(server): support sharded santacoder (#167)

fix(launcher): revert change on shard errors (#173)

fix(ci): fix CVE in github-slug-action (#174)

feat(ci): add image signing with cosign (#175)

feat(ci): add Trivy and scan docker image (#178)

feat(ci): use large runners (#179)

feat(ci): faster scanning (#180)

fix(ci): fix ci permissions (#181)

fea(dockerfile): better layer caching (#159)

fix(ci): fix cosign error (#183)

fix(docker): fix docker image (#184)

fix(docker): fix image (#185)

fix(docker): revert dockerfile changes (#186)

fix(docker): fix docker image dependencies (#187)

fix(router): fix truncation (#190)

closes #189

feat(python-client): get list of currently deployed tgi models using the inference API (#191)

feat(router): add info route (#196)

close #125

feat(server): support quantization for flash models (#200)

closes #197

feat(server): check cuda capability when importing flash models (#201)

close #198

fix(server): fix hf_transfer issue with private repos (#203)

fix(docker): remove unused dependencies (#205)

fix(router): add auth token to get model info (#207)

feat(router): add git sha to info route (#208)

feat(router): drop requests when client closes the channel (#202)

fix(ci): fix sha in docker image (#212)

feat(server): flash attention past key value optimizations (#213)

feat(router): add device and dtype info (#215)

fix(server): fix past key values logic (#216)

@njhill fyi

fix(server): cleanup new flash past_key_values logic (#217)

fix(server): fix flash causal (#218)

fix(server): fix flash causal (#219)

fix(server): fix flash batch filtering (#220)

misc: update to rust 1.69 (#221)

v0.6.0 (#222)

feat(server): reduce memory requirement (#214)

chore(server): update huggingface-hub (#227)

feat(router): use number of tokens in batch as input for dynamic batching (#226)

Co-authored-by: Nick Hill <[email protected]>

feat(router): add endpoint info to /info route (#228)

chore(server): update safetensors version (#235)

fix(python-client): add auth headers to is supported requests (#234)

Starting some routing tests. (#233)

fix(benchmarking): fix benchmarking tool

chore(launcher): refactor logic (#242)

Hopefully it's cleaner

feat(router): add tests to validation (#237)

feat(router): new healthcheck that skips the queue (#244)

Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

fix(server): fix reshaping of bloom past_key_values in concatenate() (#252)

Introduced in #214

Fixes #249

fix(server): Small tidy of code from recent changes (#251)

remaining_decode_tokens was calculated twice in Seq2SeqLMBatch.filter()

chore(server): update transformers (#250)

feat(server): add watermarking tests (#248)

feat(docker): add nvidia env vars (#255)

doc(launcher): add more docs to the `launcher` itself and link in the README (#257)

feat(benchmark): add support for private tokenizers (#262)

Adding docs on how dynamic batching works. (#258)

This PR starts the minimal possible amount of explanation I could think
of. It tries to explain how dynamic batching occurs, the interactions
with past key values and ignores the padding problem.

Maybe some drawings could help too but I kept it to text for now.

chore(github): add templates (#264)

fix(server): fix typo in tokenizers decode (#269)

closes #268

feat(server): support hf endpoint weight layout (#266)

fix(launcher): pass weights cache override to the download process (#274)

closes #273

fix(launcher): handle hub branches (#278)

fix(server): Removes the parallelism in file convertion (during download) (#275)

feat(launcher): Improve error message when download process fails. (#276)

fix(server): fix convert (#284)

chore: add `flash-attention` to docker ignore (#287)

included when building docker locally.
(Where the local dirs might have the flash-attention folder.)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

fea(server): decrease convert RAM requirements (#286)

fix(dockerfile): fix nvidia env vars (#297)

Fixes #291

feat(router): Adding response schema for compat_generate (#292)

feat(docker): add benchmarking tool to docker image (#298)

fix(docker): fix docker build (#299)

feat(server): optim flash causal lm decode_token (#285)

fix(docker): fix nvidia env vars (#305)

fix(docker): remove nvidia require cuda env (#310)

feat(server): shard token decode (#303)

feat(server): use float16 (#304)

fix(docker): remove CUDA_VERSION

feat(server): use cuda graph in logits warping (#302)

fix(server): fix multinomial implem in Sampling

feat(server): GPTQ quantization (step1) (#277)

Changes only the type from `bool` to `Option<Enum>` pretty much
everywhere.
- Use `Optional[str]` in Python (easier to manage than importing type
everywhere). Except for the cli to get proper validation
- Updated all models to handle gracefully new values. (Error out if
unknown value, or gptq since not implemented).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

chore(docker): use nvidia base image (#318)

fix(docker): remove quantize default

fix(docker): use ubuntu20.04

Hotfixes for santacoder/bigcode. (#294)

Hotfixes:

- Uses `model_type`=`gpt_bigcode` for more general usage.
- Hotfixes linked lm_head vs wte_embedding (safetensors file do not
contain the key, correctly when the file is sharded, where as pytorch
copies the tensor)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

Lifting check_unitialized. (#325)

Lifting check_unitialized.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

Removing dead variables. (#327)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat(ci): custom gpu runners (#328)

Single place for TP layers + Dropout Layer Norm + FastLinear (#329)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat: add snapshot testing (#282)

feat(integration-tests): improve comparison and health checks (#336)

fix(server): fix decode token (#334)

Fixes #333

---------

Co-authored-by: Nicolas Patry <[email protected]>

fix: set MODEL_ID in sagemaker-entrypoint script (#343)

feat(server): Support BLOOMChat-176B (#348) (#351)

@njhill,
temporary workaround to be able to run our CI as secrets are not
available to runners run by external contributors. I will ask around to
see if there is a better way.

Co-authored-by: Nick Hill <[email protected]>

fix(server): fix init for flash causal lm (#352)

Fixes #347

fix(server): t5 cannot run in f16 (#356)

Fix #349

fix(ci): fix security group (#359)

Switch security group used for ci
(open outbound rules)

Signed-off-by: Raphael <[email protected]>
Co-authored-by: Raphael <[email protected]>

feat: add nightly load testing (#358)

chore(sever): update requirements (#357)

Fixes #338

feat(server): support fp16 for t5 (#360)

Fixes #349

feat(server): do not use device_map auto on single GPU (#362)

feat(server): support trust_remote_code (#363)

feat(router): log input/ouput at debug level (#364)

@njhill FYI

v0.7.0 (#353)

feat: decrease IPC proto size (#367)

Closes #307 #308

feat(benchmarker): add summary tables (#368)

feat(server): support vectorized warpers in flash causal lm (#317)

Co-authored-by: Joel Lamy-Poirier <[email protected]>

Fix issue when load AutoModelForSeq2SeqLM model (#370)

fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

fix(server): fix quantization

feat(server): support RefinedWeb models (#379)

v0.8.0

increase health checks

feat(server): add retry on download (#384)

fix(server): fix bnb quantization for CausalLM models (#385)

v0.8.1

fix(server): fix has_position_ids (#395)

Fix #389

feat(server): remove trust_remote_code requirement for falcon models (#396)

feat(server): load santacoder/starcoder models with safetensors (#393)

Fix #366

v0.8.2

feat(sagemaker): add trust remote code to entrypoint (#394)

feat(launcher): parse oom signal (#404)

feat(server): only compute prefill logprobs when asked (#406)

Close #288

feat(server): batch tokenization for flash causal lm (#411)

chore: update openapi schema

feat(server): Rework model loading (#344)

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

feat(server): optimize dist ops (#434)

docs(launcher): fix CUDA_VISIBLE_DEVICES helper comment (#441)

It solves a typo in the comment sections referencing the environment
variable `CUDA_VISIBLE_DEVICES`. No misspelling references to this
variable have been found in code logic leading to undefined behaviour or
bugs. This PR is not expected to perform any code logic modification.

fix(makefile): Fix typo and use POSIX comparison in the makefile (#443)

This PR fixes:
- The usage of non posix comparison which may fail depending on the
shell used (`=` will always work, `==` only with bash)
- Typo in the env variable name displayed in the error message
`BUILD_EXTENSION` instead of `BUILD_EXTENSIONS`

<!-- Remove if not applicable -->

Fixes #422

feat(server): pre-allocate past key values for flash causal LM (#412)

feat(router): add ngrok integration (#453)

feat(server): improve flash attention import errors (#465)

@lewtun, is this enough?

Closes #458
Closes #456

fix(server): fix warpers on CPU (#472)

Closes #471

fix(server): Fixing T5 in case the names are mixed up. (#475)

feat(server): Update convert logic. (#483)

Should be more robust to shared tensors (ok when using
      `from_pretrained). But forcing us to add new checks in our loading
      code (since the chosen key to keep might be different from
      `transformers`).

---------

Co-authored-by: Ubuntu <[email protected]>

feat(server): Adding new ignore_rule for conversion. (#485)

fix(router): add timeout on flume sends (#488)

feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)

Let's start discussing implementation.

- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).

Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.

My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>

fix(server): Do not init process group if already initialized (#388)

feat(router): add header option to disable buffering for the generate_stream response (#498)

generate_stream endpoint response stream.

Problem: If a model is run behind a proxy server such as nginx that has
buffering enabled then the response stream from generate_stream gets
aggregated into a single response which basically disables streaming.
Instead of getting a chunked response where each token is presented over
time the response presents everything all at once.

Solution: This change adds the `X-Accel-Buffering` http header which
disables buffering for the generate_stream response, allowing the
response to stream properly.

feat(server): add paged attention to flash models (#516)

Closes #478

feat(router): arg validation (#519)

feat: Add the option to force another dtype than `f16`. (#513)

fix(launcher): fix issue where launcher does not properly report shard failures (#522)

v0.9.0 (#525)

feat(server): Add Non flash MPT. (#514)

This adds a non flash version of MPT.
Flash is harder because we need to create a bias ready cuda kernel of
flash attention.

Fixes
https://github.com/huggingface/text-generation-inference/issues/361
Fixes
https://github.com/huggingface/text-generation-inference/issues/491
Fixes
https://github.com/huggingface/text-generation-inference/issues/290

fix: Update server/Makefile to include Makefile-vllm (#520)

For consistency and ease of use (you can just run `make` to install vllm
without any extra steps).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs(benchmarker): Adding some help for the options in `text-generation-benchmark`. (#462)

fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, an exception will be thrown by `hf_hub_download` function if
`model_id` is a local path.

fix(server): avoid errors for very small top_p values (#544)

See https://github.com/huggingface/transformers/pull/24111

I didn't add validation to the `__init__` method since it's not done for
other values/warpers.

feat(server): use latest flash attention commit (#543)

@njhill FYI

feat(router): add argument for hostname in router (#545) (#550)

In title. Adds argument `--hostname` in router to support something like
`--hostname ::`. Tested with

```commandline
cargo run -- --port 8080 --hostname ::
curl -I -X GET 'http://[::1]:8080/health'  # failed before this commit
```

Trigger CI

---------

Co-authored-by: Phil Chen <[email protected]>

fix(server): decrease memory fragmentation (#557)

v0.9.1 (#558)

fix(server): harden the weights choice to save on disk. (#561)

- Look at `transformers` base class to check for
  `_key_to_ignore_on_load_missing` or `_tied_weights` which are the
  standard attributes to select the keys to NOT save on disk (since they
  are ignored)

- Modified safetensors code (to be reflected in safetensors even if it's
  an internal function).

- Will not work for trust_remote_code=True repos (like santacoder).

Should help with :
https://github.com/huggingface/text-generation-inference/issues/555
and : https://github.com/huggingface/text-generation-inference/pull/501
and https://github.com/huggingface/text-generation-inference/issues/556
and
https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593

feat: better errors for warmup and TP (#575)

Close #571

fix(server): Fixing RW code (it's remote code so the Arch checking doesn't work to see which weights to keep). (#579)

Fixes #555

feat(server): Support for env value for GPTQ_BITS and GPTQ_GROUPSIZE. (#580)

Some models are already converted, and do not have those values in the
file, this enables users to use them with less friction.

Went for pure env based because adding flags would end up (imo) very
tedious to maintain. There's a lot of sanitation to do: those flags
would be errors if not used in conjuction with `--quantize gptq`.
Then the flags need to exist in the launcher and the server passing them
all throughout all function calls.

This PR is intended as an easy escape hatch, not the defacto method to
use gptq in TGI.

Fixes #500

chore: migrate ci region for more availability. (#581)

fix(server): T5 weights names. (#582)

Fixes #541

fix(server): Adding logger import to t5_modeling.py (#585)

Logger is referenced during the apex importing but is not imported,
causing a NameError

fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590)

This fixes a typo and extends the GPTP_BITS environment variables
through to the second method which requires the same logic. Please let
me know if there's anything I've misunderstood in this change.

Thanks @Narsil for the original fix.

feat(server): Implements sharding for non divisible `vocab_size`. (#583)

- The code is relatively easy (just disable the checks on Embedding and
Head)

This cannot be done in the same easy fashion for hidden_dim/head_dim.
It's relatively easy on some models (classic MHA) but it would make the
other
models (MQA) much more complex, and GPTQ quantization another quite
hairy piece
of code.

feat(server): empty cache on errors

GPTQ Env vars: catch correct type of error (#596)

When passing in environment variables like gptq_bits, we still get
errors thrown from TGI because the try/catch block is catching the wrong
type of error. This PR aims to fix that.

@Narsil - let me know if this is how you want this formatted. My Python
is a little shaky, so I hope this syntax is correct.

feat(launcher): add arg validation and drop subprocess (#595)

feat(router): explicit warning if revision is not set (#608)

docs: README: Add logo + baseline (#611)

![image](https://github.com/huggingface/text-generation-inference/assets/3841370/58177321-479f-4ad1-b3bc-cec027423984)

fix(server): blacklist local files (#609)

Close #589 #602

v0.9.2 (#616)

fix(server): empty_cache when stopped

fix(launcher): Rename `b-float16` to `bfloat16` in the launcher arg (#621)

fea(launcher): debug logs (#623)

feat(server): Reworking the quantization script so it's still universal (not llama specific) (#587)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Reworking the quantization script so it's still universal (not llama
specific)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Still need to investigate the potential differences in quantization
results.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

feat(server): flash attention v2 (#624)

feat(server): add support for llamav2 (#633)

v0.9.3 (#634)

fix(server): fix llamav2 config (#635)

feat(server): auto max_batch_total_tokens for flash att models (#630)

feat(router): ngrok edge (#642)

docs: Update README.md (#639)

docs: Update README.md (#643)

Add trust_remote_code to quantize script (#647)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes a bug appeared with MR #587 fixing issue #552.
See the discussion in #552.

With MR #587 the trust_remote_code variable is not passed to
AutoModelForCausalLM, but is found in the function signature. This
prevents models like falcon to be quantized, because trust_remote_code
is required. This MR fixes the issue.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [X] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
@Narsil
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

 -->

fix(server): llama v2 GPTQ (#648)

As per title & reported
https://github.com/huggingface/text-generation-inference/issues/601#issuecomment-1641435956
https://huggingface.co/TheBloke/Llama-2-70B-chat-GPTQ/discussions/5

Test it:

```
GPTQ_BITS=4 GPTQ_GROUPSIZE=1 text-generation-launcher --model-id TheBloke/Llama-2-70B-chat-GPTQ --port 8080 --num-shard 4 --quantize gptq
```
&
```
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"hey llama","parameters":{"max_new_tokens":256}}' \
    -H 'Content-Type: application/json'
```

fix(server): Fixing non parameters in quantize script `bigcode/starcoder` was an example. (#661)

fix(server): use mem_get_info to get kv cache size (#664)

Close
https://github.com/huggingface/text-generation-inference/issues/649
Close
https://github.com/huggingface/text-generation-inference/issues/651
Close
https://github.com/huggingface/text-generation-inference/issues/653
Close #636

feat(server): Add exllama GPTQ CUDA kernel support #553 (#666)

Just trying to get the integration tests to pass.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Felix Marty <[email protected]>

Directly load GPTBigCode to specified device (#618)

This PR directly load GPTBigCode to specified device, avoiding moving
model between devices.

This PR directly load GPTBigCode to specified device, avoiding moving
model between devices.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil

feat(server): add local prom and health routes if running w/ ngrok

feat: add cuda memory fraction (#659)

Close #673

fix(server): fix exllama buffers (#689)

Close #683

feat(server): Using `quantize_config.json` instead of GPTQ_BITS env variables. (#671)

- Current PR is not great because we're side stepping the
  `Weights.__init__` but Weights shouldn't requires anything related
  to the config or the model_id as it aims to be a simple Wrapper
  over multi file loading.
- Ideal solution would be to use something like Rust enum
  ```
  enum Quantize{
    Bitandbytes(Bitsandbytes),
    GPTQ(bits: usize, groupsize: usize)
  ```
  And passing that around during load. Unfortunately we don't
  have access to this, so for now, side-stepping seems easier.

- Re-enabling groupsize<0 with exllama (confirmed it works.)

Helps #601

In next steps we should make sure our quantization script uses that
format and make it standard.

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs(README): update readme

fix(server): fix quantization python requirements (#708)

fix(server): fix missing datasets in quantize

feat(server): support new falcon config (#712)

v0.9.4 (#713)

Add section about TGI on other AI hardware accelerators in README (#715)

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

As per title.

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

docs: Add hardware section to TOC in README (#721)

feat(server): update vllm version (#723)

chore: update license to HFOIL (#725)

v1.0.0 (#727)

Local gptq support. (#738)

Redoes #719

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

Fix typing in `Model.generate_token` (#733)

This PR fixes a minor type annotation issue in the signature of
`Model.generate_token`.

All existing overrides of `Model.generate_token` return
`Tuple[List[Generation], Optional[B]]`:

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/causal_lm.py#L535-L537

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/flash_causal_lm.py#L802-L804

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/seq2seq_lm.py#L589-L591

I suspect that back in 017a2a8c when `GeneratedText` and `Generation`
were separated, the function signature was not updated.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

CC @OlivierDehaene

Adding Rope scaling. (#741)

- Adds Rope NTK scaling.

Done because
https://github.com/huggingface/text-generation-inference/pull/529 was
closed
Took some code from
https://github.com/huggingface/transformers/pull/24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).

Fixes #512

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @

@OlivierDehaene OR @Narsil

 -->

chore: fix typo in mpt_modeling.py (#737)

Fixed typo.
<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

implemetation -> implementation

- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update…
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.