Skip to content

Commit

Permalink
Update to ipex 2.2.0 for CPU (#143)
Browse files Browse the repository at this point in the history
* update

* update
  • Loading branch information
xwu99 authored Mar 15, 2024
1 parent e2a6450 commit dcef31b
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion llm_on_ray/inference/deepspeed_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def init_model(self, local_rank: int):
ipex._C.disable_jit_linear_repack()
except Exception:
pass
pipe.model = ipex.optimize_transformers(
pipe.model = ipex.llm.optimize(
pipe.model.eval(),
dtype=torch.bfloat16
if self.infer_conf.ipex.precision == PRECISION_BF16
Expand Down
2 changes: 1 addition & 1 deletion llm_on_ray/inference/mllm_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, infer_conf: InferenceConfig):
ipex._C.disable_jit_linear_repack()
except Exception:
pass
model = ipex.optimize_transformers(
model = ipex.llm.optimize(
model.eval(),
dtype=torch.bfloat16
if infer_conf.ipex.precision == PRECISION_BF16
Expand Down
2 changes: 1 addition & 1 deletion llm_on_ray/inference/transformer_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, infer_conf: InferenceConfig):
ipex._C.disable_jit_linear_repack()
except Exception:
pass
model = ipex.optimize_transformers(
model = ipex.llm.optimize(
model.eval(),
dtype=torch.bfloat16
if infer_conf.ipex.precision == PRECISION_BF16
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ dependencies = [

[project.optional-dependencies]
cpu = [
"transformers>=4.35.0",
"intel_extension_for_pytorch==2.1.0+cpu",
"torch==2.1.0+cpu",
"oneccl_bind_pt==2.1.0+cpu"
"transformers>=4.35.0, <=4.35.2",
"intel_extension_for_pytorch>=2.2.0",
"torch>=2.2.0",
"oneccl_bind_pt>=2.2.0"
]

gpu = [
Expand Down

0 comments on commit dcef31b

Please sign in to comment.