-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcraftmd_opensource.py
43 lines (33 loc) · 1.45 KB
/
craftmd_opensource.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import sys
import pandas as pd
# Set OpenAI API Key
# Configure the Azure OpenAI endpoint and API key used by this script.
# Values are taken from the environment when present, so credentials do not
# have to be edited into (and risk being committed with) this file; when the
# variables are unset, the original placeholders are used unchanged and must
# be replaced by hand.
# NOTE(review): the host label of an Azure OpenAI endpoint is the Azure
# *resource* name; confirm this value is that rather than a model deployment
# name. No `openai.api_type` / `openai.api_version` is set here -- confirm
# they are configured elsewhere (e.g. in src.craftmd) if Azure requires them.
import os
import openai
deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT", "<insert deployment name>")
openai.api_base = f"https://{deployment_name}.openai.azure.com/"
openai.api_key = os.environ.get("AZURE_OPENAI_API_KEY", "<insert API key>")
from src.utils import get_choices
from src.craftmd import craftmd_opensource, craftmd_opensource_system
from src.models import get_model_and_tokenizer
# # To download open-source models, if not already installed in your conda environment
# Log in to the Hugging Face Hub. The token is read from the HF_TOKEN
# environment variable when set, so it need not be hard-coded here; otherwise
# the placeholder below is passed unchanged and must be replaced by hand.
import os
from huggingface_hub import login
login(token=os.environ.get("HF_TOKEN", "<insert huggingface token>"))
if __name__ == "__main__":
    # Run `craftmd_opensource` on every selected dataset case for each
    # open-source model in `model_names`.
    #
    # Usage: python craftmd_opensource.py [start] [end]
    #   start, end -- optional positional row bounds (Python slice semantics,
    #                 end exclusive) selecting which dataset rows to run.
    #                 Omitted bounds mean "from the first row" / "to the last row".
    import gc

    # `start`/`end` were previously referenced without ever being defined
    # (NameError); take them from the command line instead.
    start = int(sys.argv[1]) if len(sys.argv) > 1 else None
    end = int(sys.argv[2]) if len(sys.argv) > 2 else None

    model_names = ["llama2-7b", "mistral-v1", "mistral-v2"]
    dataset = pd.read_csv("./data/usmle_and_derm_dataset.csv", index_col=0)

    # One tuple per case: (case_id, case_vignette, category, get_choices(...)).
    cases = [
        (
            dataset.loc[idx, "case_id"],
            dataset.loc[idx, "case_vignette"],
            dataset.loc[idx, "category"],
            get_choices(dataset, idx),
        )
        for idx in dataset.index[start:end]
    ]

    for model_name in model_names:
        # Per-model results directory handed to craftmd_opensource.
        path_dir = f"results/{model_name}"
        model, tokenizer = get_model_and_tokenizer(model_name)
        # Some tokenizers (e.g. Llama-2's) define no pad token; fall back to
        # EOS so padding-dependent code downstream has a valid id.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        for case in cases:
            try:
                craftmd_opensource(case, path_dir, model, tokenizer, model_name)
            except Exception as e:
                # Best effort: report the failing case and move on to the next.
                print(e)
                print(f"Error in run : {case[0]}")
        # Release this model before the next one is loaded; otherwise it stays
        # referenced while get_model_and_tokenizer builds its successor, so two
        # models are resident at once.
        # NOTE(review): if models live on GPU, torch.cuda.empty_cache() may also
        # be needed to return cached device memory -- confirm.
        del model, tokenizer
        gc.collect()