From 4c9bb6043ce90fb7e07defb52ea9cc886a7e9fd6 Mon Sep 17 00:00:00 2001 From: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> Date: Mon, 6 Jan 2025 15:11:26 -0800 Subject: [PATCH] Check transformers version in BLOOM for inference v1 (#6766) This PR checks that the `transformers` version is `<= 4.43.4` in the BLOOM container for inference v1, due to breaking changes in `transformers > 4.43.4`. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/module_inject/containers/bloom.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py index a78ac8120346..7a9b9ca2065b 100644 --- a/deepspeed/module_inject/containers/bloom.py +++ b/deepspeed/module_inject/containers/bloom.py @@ -19,6 +19,18 @@ class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer): def __init__(self, **kwargs): + # Check transformers version, error if > 4.43.4 (breaks at 4.44.0) + from importlib.metadata import version + v_transformers = version('transformers') + vers = v_transformers.split('.') + major = int(vers[0]) + minor = int(vers[1]) + if major > 4 or (major == 4 and minor > 43): + import sys + sys.exit( + f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported." + ) + super().__init__(**kwargs) # All model specific things should be defined here instead of the base class.