diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs
index bf5176525..123728aa5 100644
--- a/mistralrs-core/src/lib.rs
+++ b/mistralrs-core/src/lib.rs
@@ -78,6 +78,7 @@ pub struct MistralRsBuilder {
     no_prefix_cache: Option<bool>,
     prefix_cache_n: Option<usize>,
     disable_eos_stop: Option<bool>,
+    gemm_full_precision_f16: Option<bool>,
 }
 
 impl MistralRsBuilder {
@@ -91,9 +92,9 @@ impl MistralRsBuilder {
             no_prefix_cache: None,
             prefix_cache_n: None,
             disable_eos_stop: None,
+            gemm_full_precision_f16: None,
         }
     }
-
     pub fn with_log(mut self, log: String) -> Self {
         self.log = Some(log);
         self
@@ -122,12 +123,25 @@ impl MistralRsBuilder {
         self.disable_eos_stop = Some(disable_eos_stop);
         self
     }
+    pub fn with_gemm_full_precision_f16(mut self, gemm_full_precision: bool) -> Self {
+        self.gemm_full_precision_f16 = Some(gemm_full_precision);
+        self
+    }
 
     pub fn build(self) -> Arc<MistralRs> {
         MistralRs::new(self)
     }
 }
 
+#[cfg(feature = "cuda")]
+fn set_gemm_reduced_precision_f16() {
+    candle_core::cuda::set_gemm_reduced_precision_f16(true);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
+}
+
+#[cfg(not(feature = "cuda"))]
+fn set_gemm_reduced_precision_f16() {}
+
 impl MistralRs {
     fn new(config: MistralRsBuilder) -> Arc<Self> {
         let MistralRsBuilder {
@@ -139,8 +153,13 @@ impl MistralRs {
             no_prefix_cache,
             prefix_cache_n,
             disable_eos_stop,
+            gemm_full_precision_f16,
         } = config;
 
+        if !gemm_full_precision_f16.unwrap_or(false) {
+            set_gemm_reduced_precision_f16();
+        }
+
         let truncate_sequence = truncate_sequence.unwrap_or(false);
         let no_kv_cache = no_kv_cache.unwrap_or(false);
         let no_prefix_cache = no_prefix_cache.unwrap_or(false);
diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs
index a8b6a664c..d9154449a 100644
--- a/mistralrs-core/src/models/quantized_llama.rs
+++ b/mistralrs-core/src/models/quantized_llama.rs
@@ -12,6 +12,14 @@ use crate::DeviceMapMetadata;
 
 const MAX_SEQ_LEN: u32 = 4096;
+fn quantized_mat_mul(xs: &Tensor, w: &QMatMul, via_f16: bool) -> Result<Tensor> {
+    if via_f16 {
+        w.forward_via_f16(xs)
+    } else {
+        w.forward(xs)
+    }
+}
+
 #[derive(Debug, Clone)]
 struct Mlp {
     feed_forward_w1: QMatMul,
@@ -19,12 +27,12 @@ struct Mlp {
     feed_forward_w3: QMatMul,
 }
 
-impl Module for Mlp {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let w1 = self.feed_forward_w1.forward(xs)?;
-        let w3 = self.feed_forward_w3.forward(xs)?;
-        self.feed_forward_w2
-            .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
+impl Mlp {
+    fn forward(&self, xs: &Tensor, via_f16: bool) -> Result<Tensor> {
+        let w1 = quantized_mat_mul(xs, &self.feed_forward_w1, via_f16)?;
+        let w3 = quantized_mat_mul(xs, &self.feed_forward_w3, via_f16)?;
+        let y = &(candle_nn::ops::silu(&w1)? * w3)?;
+        quantized_mat_mul(y, &self.feed_forward_w2, via_f16)
     }
 }
 
@@ -38,8 +46,8 @@ enum MlpOrMoe {
     },
 }
 
-impl Module for MlpOrMoe {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl MlpOrMoe {
+    fn forward(&self, xs: &Tensor, via_f16: bool) -> Result<Tensor> {
         match self {
             Self::MoE {
                 feed_forward_gate_inp,
@@ -94,7 +102,7 @@ impl Module for MlpOrMoe {
                     // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                     let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                     // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
-                    let current_hidden_states = expert_layer.forward(&current_state)?;
+                    let current_hidden_states = expert_layer.forward(&current_state, via_f16)?;
                     let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
                     ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
@@ -103,7 +111,7 @@ impl Module for MlpOrMoe {
                 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                 Ok(ys)
             }
-            Self::Mlp(mlp) => mlp.forward(xs),
+            Self::Mlp(mlp) => mlp.forward(xs, via_f16),
         }
     }
 }
@@ -132,11 +140,13 @@ impl LayerWeights {
         start_offsets: &[usize],
         start_offsets_kernel: Tensor,
         kv_cache: &mut Option<(Tensor, Tensor)>,
+        via_f16: bool,
     ) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
-        let q = self.attention_wq.forward(x)?;
-        let k = self.attention_wk.forward(x)?;
-        let v = self.attention_wv.forward(x)?;
+
+        let q = quantized_mat_mul(x, &self.attention_wq, via_f16)?;
+        let k = quantized_mat_mul(x, &self.attention_wk, via_f16)?;
+        let v = quantized_mat_mul(x, &self.attention_wv, via_f16)?;
 
         let mut q = q.reshape((b_sz * seq_len, self.n_head, self.head_dim))?;
         let mut k = k.reshape((b_sz * seq_len, self.n_kv_head, self.head_dim))?;
@@ -159,16 +169,32 @@ impl LayerWeights {
 
         let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
-        let k = repeat_kv(k, self.n_head / self.n_kv_head)?.contiguous()?;
-        let v = repeat_kv(v, self.n_head / self.n_kv_head)?.contiguous()?;
+        let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
+        let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
+        let att = if via_f16 {
+            let mm = q
+                .to_dtype(DType::F16)?
+                .matmul(&k.to_dtype(DType::F16)?.t()?)?;
+
+            ((mm / (self.head_dim as f64).sqrt())?).to_dtype(DType::F32)?
+        } else {
+            let k = k.contiguous()?;
+            (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())?
+        };
-        let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())?;
         let att = CausalMasker.apply_mask(mask, att, &self.neg_inf)?;
         let att = candle_nn::ops::softmax_last_dim(&att)?;
 
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
+        let y = if via_f16 {
+            att.to_dtype(DType::F16)?
+                .matmul(&v.to_dtype(DType::F16)?)?
+                .to_dtype(DType::F32)?
+        } else {
+            att.matmul(&v.contiguous()?)?
+        };
+
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = self.attention_wo.forward(&y)?;
+        let y = quantized_mat_mul(&y, &self.attention_wo, via_f16)?;
         Ok(y)
     }
 }
@@ -386,6 +412,9 @@ impl ModelWeights {
         start_offsets_kernel: Tensor,
         context_lens: Vec<(usize, usize)>,
     ) -> Result<Tensor> {
+        let (_bz, seq_len, _) = x.dims3()?;
+        let via_f16 = seq_len > 32;
+
         let mut layer_in = self.tok_embeddings.forward(x)?;
         let mut cache = self.cache.lock();
         let mask = CausalMasker.make_causal_mask(x, &cache)?;
@@ -402,18 +431,22 @@ impl ModelWeights {
                 start_offsets,
                 start_offsets_kernel.clone(),
                 &mut cache[i],
+                via_f16,
             )?;
             let x = (attn + residual)?;
 
            // MLP
             let residual = &x;
             let x = layer.ffn_norm.forward(&x)?;
-            let x = layer.mlp_or_moe.forward(&x)?;
+            let x = layer.mlp_or_moe.forward(&x, via_f16)?;
             let x = (x + residual)?;
             layer_in = x;
         }
         let layer_in = layer_in.to_device(&self.device)?;
         let x = self.norm.forward(&layer_in)?;
-        extract_logits(&self.output.forward(&x.contiguous()?)?, context_lens)
+        extract_logits(
+            &quantized_mat_mul(&x.contiguous()?, &self.output, via_f16)?,
+            context_lens,
+        )
     }
 }
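
For reference, a minimal sketch of how a caller might opt back out of the reduced-precision path that this patch enables by default on CUDA. The builder's constructor arguments are not part of this diff, so pipeline and scheduler_method below are placeholders for whatever the existing examples already construct; only with_gemm_full_precision_f16 comes from this change.

    // Hypothetical caller; `pipeline` and `scheduler_method` are assumed to be
    // built exactly as before this patch.
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_method)
        // Keep full-precision F16/BF16 GEMM accumulation. Without this call,
        // MistralRs::new() invokes set_gemm_reduced_precision_f16() on CUDA.
        .with_gemm_full_precision_f16(true)
        .build();

Note that the quantized_llama.rs changes pick the F16 matmul path per forward pass (via_f16 = seq_len > 32), so prompt processing runs attention and the quantized projections through F16 while single-token decode keeps the existing F32 path.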