Commit 4c79762

chore: Update dependencies, logging (#39)
carlosgjs authored Feb 8, 2024
1 parent 5615549 commit 4c79762
Showing 5 changed files with 15 additions and 11 deletions.
4 changes: 2 additions & 2 deletions azureml/conda.yml
@@ -10,8 +10,8 @@ dependencies:
   - typer
   - jsonlines
   - accelerate>=0.24.1
-  - bitsandbytes>=0.41.2.post2
-  - transformers>=4.35.2
+  - bitsandbytes>=0.42.0
+  - transformers>=4.37.2
   - xformers
   - scipy
   - nltk
1 change: 1 addition & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ dev = [
     "hf_transfer",
 ]
 pipelines = ["jsonlines", "mlflow", "nltk", "sentence-transformers>=2.3.1"]
+# NOTE: When updating dependencies, in particular cuda/azure ml, make sure to update azureml/conda.yml too
 azure = ["azureml-core", "azureml-mlflow"]
 cuda = ["bitsandbytes>=0.42.0", "accelerate>=0.24.1", "xformers"]

11 changes: 5 additions & 6 deletions src/autora/doc/pipelines/main.py
@@ -12,12 +12,13 @@
 from autora.doc.runtime.prompts import PROMPTS, PromptIds
 from autora.doc.util import get_prompts_from_file

-app = typer.Typer()
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
 )
 logger = logging.getLogger(__name__)
+logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")
+app = typer.Typer()


 @app.command(help="Evaluate a model for code-to-documentation generation for all prompts in the prompts_file")
@@ -83,9 +84,9 @@ def eval(
     mlflow.log_param("prompt_id", prompt_id)
     mlflow.log_param("model_path", model_path)
     mlflow.log_param("data_file", data_file)
-    prompt = PROMPTS[prompt_id]
-    pred = Predictor(model_path)
-    return eval_prompt(data_file, pred, prompt, param_dict)
+    prompt = PROMPTS[prompt_id]
+    pred = Predictor(model_path)
+    return eval_prompt(data_file, pred, prompt, param_dict)


 def load_data(data_file: str) -> Tuple[List[str], List[str]]:
@@ -175,6 +176,4 @@ def read_text(file: str) -> str:


 if __name__ == "__main__":
-    logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")
-
     app()
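
With this change the Torch/CUDA banner is emitted once at import time, using the format configured just above, instead of only when the script runs as __main__. A minimal standalone sketch of the same pattern (illustrative only, assuming torch is installed; the example log line is hypothetical):

# Sketch of the module-level logging pattern used above (not the project's code).
import logging

import torch

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
)
logger = logging.getLogger(__name__)

# Emitted once on import, e.g.:
# 2024-02-08 10:15:00,123 INFO example.<module>(): Torch version: 2.1.2 , Cuda available: False
logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")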
6 changes: 5 additions & 1 deletion src/autora/doc/runtime/predict_hf.py
@@ -10,6 +10,7 @@
 logger = logging.getLogger(__name__)

 quantized_models = {"meta-llama/Llama-2-7b-chat-hf": "autora-doc/Llama-2-7b-chat-hf-nf4"}
+non_quantized_models = {"meta-llama/Llama-2-7b-chat-hf": "autora-doc/Llama-2-7b-chat-hf"}


 def preprocess_code(code: str) -> str:
@@ -91,6 +92,7 @@ def tokenize(self, input: List[str]) -> Dict[str, List[List[int]]]:
     @staticmethod
     def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
         if torch.cuda.is_available():
+            logger.info("CUDA is available, attempting to load quantized model")
             from transformers import BitsAndBytesConfig

             config = {"device_map": "auto"}
@@ -108,4 +110,6 @@ def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
             )
             return model_path, config
         else:
-            return model_path, {}
+            logger.info("CUDA is not available, loading non-quantized model")
+            mapped_path = non_quantized_models.get(model_path, model_path)
+            return mapped_path, {}
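
On machines without CUDA, get_config now resolves a quantization-capable model id to its plain checkpoint before loading, instead of returning the requested path unchanged. A minimal sketch of just that fallback (names copied from the diff above; this is not the project's module, and the CUDA branch is elided):

# Standalone sketch of the new CPU fallback in Predictor.get_config (illustrative only).
from typing import Dict, Tuple

non_quantized_models = {"meta-llama/Llama-2-7b-chat-hf": "autora-doc/Llama-2-7b-chat-hf"}


def resolve_cpu_model(model_path: str) -> Tuple[str, Dict[str, str]]:
    # Fall back to a registered non-quantized checkpoint when one exists;
    # otherwise pass the requested path through unchanged. No extra config on CPU.
    mapped_path = non_quantized_models.get(model_path, model_path)
    return mapped_path, {}


print(resolve_cpu_model("meta-llama/Llama-2-7b-chat-hf"))  # ('autora-doc/Llama-2-7b-chat-hf', {})
print(resolve_cpu_model("some/other-model"))               # ('some/other-model', {})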
4 changes: 2 additions & 2 deletions tests/test_predict_hf.py
@@ -36,6 +36,6 @@ def test_get_config_cuda(mock: mock.Mock) -> None:

 @mock.patch("torch.cuda.is_available", return_value=False)
 def test_get_config_nocuda(mock: mock.Mock) -> None:
-    model, config = Predictor.get_config(MODEL_WITH_QUANTIZED)
-    assert model == MODEL_WITH_QUANTIZED
+    model, config = Predictor.get_config(MODEL_NO_QUANTIZED)
+    assert model == MODEL_NO_QUANTIZED
     assert len(config) == 0
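
The MODEL_WITH_QUANTIZED and MODEL_NO_QUANTIZED constants are defined elsewhere in the test module and are not visible in this diff. A self-contained, hypothetical variant of the same CPU-path check using a literal model id (it assumes the package and its torch/transformers dependencies are installed):

# Hypothetical, self-contained sketch of the CPU-path test; not the repository's test file.
from unittest import mock

from autora.doc.runtime.predict_hf import Predictor


@mock.patch("torch.cuda.is_available", return_value=False)
def test_get_config_nocuda_literal(_is_available: mock.Mock) -> None:
    model, config = Predictor.get_config("autora-doc/Llama-2-7b-chat-hf")
    # An id with no entry in the quantized/non-quantized mappings is passed through
    # unchanged, and the CPU path returns an empty config.
    assert model == "autora-doc/Llama-2-7b-chat-hf"
    assert config == {}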
