Merge pull request #59 from microsoft/dev
Dev
Micheallei authored Jul 23, 2024
2 parents cc3611c + 3be1b7e commit 26cfe7b
Showing 5 changed files with 8 additions and 5 deletions.
2 changes: 2 additions & 0 deletions RecLM-emb/README.md
@@ -102,6 +102,7 @@ bash shell/test_data_pipeline.sh
```bash
bash shell/run_single_node.sh
```
- `--nproc_per_node`: the number of GPUs on your machine
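As a hedged illustration (not taken from the repo), one way to set `--nproc_per_node` automatically from the local GPU count:

```shell
# Count local GPUs; nvidia-smi -L prints one line per device.
# Fall back to 1 when nvidia-smi is unavailable (e.g. a CPU-only box).
NPROC=$(nvidia-smi -L 2>/dev/null | wc -l)
[ "$NPROC" -gt 0 ] || NPROC=1
echo "--nproc_per_node=$NPROC"
```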
#### multi node
If you have two nodes, node0 and node1 (here node0 is the master node), first run `deepspeed utils/all_reduce_bench_v2.py` to get the IP address and port number of the master node. Then run the following script on both nodes, using the same IP and port number but different node ranks.
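The two-node launch described above can be sketched as follows; the launcher script name and the IP/port values are illustrative placeholders, not taken from the repo:

```shell
# Placeholders: in practice, read these from the output of
# `deepspeed utils/all_reduce_bench_v2.py` on the master node.
MASTER_ADDR="10.0.0.1"
MASTER_PORT="29500"
# Both nodes use the same master address/port; only the node rank differs.
# The script name below is illustrative, not the repo's actual launcher.
for NODE_RANK in 0 1; do
  echo "node${NODE_RANK}: MASTER_ADDR=${MASTER_ADDR} MASTER_PORT=${MASTER_PORT} NODE_RANK=${NODE_RANK} bash shell/run_multi_node.sh"
done
```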

@@ -121,6 +122,7 @@ Note that for the [repllama](https://huggingface.co/castorini/repllama-v1-7b-lor
```bash
bash shell/infer_llm_metrics.sh
```
- `--config_file`: specifies the runtime environment. `./shell/infer_case.yaml` and `./shell/infer.yaml` are provided as references for single-GPU and multi-GPU inference, respectively.
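As a rough illustration of what a multi-GPU `accelerate` config along the lines of `./shell/infer.yaml` might contain (the values below are assumptions, not the repo's actual file):

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_processes: 4        # number of GPUs (illustrative)
mixed_precision: fp16   # illustrative
```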

#### Case study
You need to first prepare your query file with the suffix .jsonl (use `user_embedding_prompt_path` parameter to specify), an example is as follows:
5 changes: 3 additions & 2 deletions RecLM-emb/infer_metrics.py
@@ -22,7 +22,6 @@
# sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__) )))
from preprocess.utils import get_item_text
from src.huggingface_model_infer import run_model_embedding
-from src.openai_model_infer import run_api_embedding

def coverage_at_k(r, k):
    r = np.asarray(r)[:k]
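The body of `coverage_at_k` is truncated by the diff above; a self-contained sketch of one common coverage@k definition follows (an assumption, not necessarily the repo's exact formula):

```python
import numpy as np

def coverage_at_k(r, k):
    # r: binary relevance vector ordered by rank (1 = relevant hit).
    # Hedged definition: fraction of the top-k slots that are hits;
    # the repo's exact formula is cut off in the diff above.
    r = np.asarray(r)[:k]
    return float(r.sum()) / k

print(coverage_at_k([1, 0, 1, 0, 1], 2))  # 0.5
```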
@@ -410,9 +409,11 @@ def write_metrics_to_file(metrics_dict, task_name, file_name):
f.write(json.dumps(output, ensure_ascii=False) + '\n')

if __name__ == '__main__':
-    openai_model_names = ['ada_embeddings', 'text-embedding-ada-002']
+    openai_model_names = ['ada_embeddings', 'text-embedding-ada-002', 'text-embedding-3-large']
    args = parse_args()
    set_seed(args.seed)
+    if args.model_path_or_name in openai_model_names:
+        from src.openai_model_infer import run_api_embedding

    accelerator = Accelerator()

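The conditional import added in this hunk defers loading `src.openai_model_infer` until an OpenAI model name is actually requested, so non-OpenAI runs never touch that dependency. A minimal, self-contained sketch of the pattern (the backend module here is a stand-in, not the repo's code):

```python
OPENAI_MODEL_NAMES = ['ada_embeddings', 'text-embedding-ada-002', 'text-embedding-3-large']

def pick_backend(model_path_or_name):
    # Import optional/heavyweight backends lazily, mirroring the commit:
    # the OpenAI client module is only imported when it will be used.
    if model_path_or_name in OPENAI_MODEL_NAMES:
        import json as openai_backend  # stand-in for src.openai_model_infer
        return 'openai'
    return 'huggingface'

print(pick_backend('text-embedding-3-large'))  # openai
print(pick_backend('repllama-v1-7b-lora-passage'))  # huggingface
```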
2 changes: 1 addition & 1 deletion RecLM-emb/shell/infer_metrics.sh
@@ -14,7 +14,7 @@ PASSAGE_MAX_LEN=128
SENTENCE_POOLING_METHOD="mean"

cd $EXE_DIR
-if [ "$MODEL_PATH_OR_NAME" = "ada_embeddings" ] || [ "$MODEL_PATH_OR_NAME" = "text-embedding-ada-002" ]; then
+if [ "$MODEL_PATH_OR_NAME" = "ada_embeddings" ] || [ "$MODEL_PATH_OR_NAME" = "text-embedding-ada-002" ] || [ "$MODEL_PATH_OR_NAME" = "text-embedding-3-large" ]; then
    echo "using openai model"
    CONFIG_FILE=./shell/infer_case.yaml
else
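The `||` chain in the test above grows with each new model name; an equivalent `case` form (a suggested refactor, not the repo's code) keeps the list in one place:

```shell
# Illustrative value; in the script this comes from the caller's environment.
MODEL_PATH_OR_NAME="text-embedding-3-large"
case "$MODEL_PATH_OR_NAME" in
  ada_embeddings|text-embedding-ada-002|text-embedding-3-large)
    echo "using openai model"
    CONFIG_FILE=./shell/infer_case.yaml ;;
  *)
    CONFIG_FILE=./shell/infer.yaml ;;
esac
echo "$CONFIG_FILE"
```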
2 changes: 1 addition & 1 deletion RecLM-emb/shell/test_data_pipeline.sh
@@ -53,7 +53,7 @@ else

echo "generate gpt_response_file"

-python preprocess/gpt_api/query_api.py --input_file $gpt_query_file'.csv' --output_file $gpt_response_file'.csv'
+python preprocess/gpt_api/api.py --input_file $gpt_query_file'.csv' --output_file $gpt_response_file'.csv'
fi

echo "generate gpt_data_file"
2 changes: 1 addition & 1 deletion RecLM-emb/src/huggingface_model_infer.py
@@ -76,7 +76,7 @@ def run_model_embedding(model_path_or_name, max_seq_len, batch_size, prompt_path
    model = AutoModel.from_pretrained(model_path_or_name, config=model_config, torch_dtype=torch_dtype)
    if args.peft_model_name:
        model = PeftModel.from_pretrained(model, args.peft_model_name)
-        model = model.merge_and_unload()
+        # model = model.merge_and_unload()

    accelerator.print(f'loading file {prompt_path}')
    test_data = pd.read_json(prompt_path, lines=True)
