From 98f31cd40b48116cd635ffc11609361ae775c463 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Tue, 6 Aug 2024 14:22:10 -0500 Subject: [PATCH] add empty_cache() after each padding (#120) --- vllm/model_executor/models/mixtral.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c34077fa2bfaf..ee9db7048f1f6 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -187,9 +187,11 @@ def process_weights_after_loading(self): self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data, (0, 128), "constant", 0), requires_grad=False) + torch.cuda.empty_cache() self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data, (0, 128), "constant", 0), requires_grad=False) + torch.cuda.empty_cache() return # If checkpoint is fp16, quantize here.