Additional fixes #4

Merged · 3 commits · Nov 7, 2023
10 changes: 5 additions & 5 deletions data-preprocess.py
@@ -15,9 +15,9 @@ def __init__(self, args):
         self.data_path = args.data_path or valohai.inputs('dataset').path()
         self.model_max_length = args.model_max_length
         self.tokenizer = args.tokenizer
-        self.train_dataset = load_dataset(self.data_path, split='train')
-        self.eval_dataset = load_dataset(self.data_path, split='validation')
-        self.test_dataset = load_dataset(self.data_path, split='test')
+        self.train_dataset = load_dataset('csv', data_files=os.path.join(self.data_path, 'train.csv'))
+        self.eval_dataset = load_dataset('csv', data_files=os.path.join(self.data_path, 'validation.csv'))
+        self.test_dataset = load_dataset('csv', data_files=os.path.join(self.data_path, 'test.csv'))
 
     def prepare_datasets(self, generate_and_tokenize_prompt):
         tknzd_train_dataset = self.train_dataset.map(generate_and_tokenize_prompt)
@@ -30,10 +30,10 @@ def generate_and_tokenize_prompt(self, data_point, tokenizer):
 The attributes must be one of the following: ['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating', 'genres', 'player_perspective', 'has_multiplayer', 'platforms', 'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier']
 
 ### Target sentence:
-{data_point["target"]}
+{data_point["ref"]}
 
 ### Meaning representation:
-{data_point["meaning_representation"]}
+{data_point["mr"]}
 """
         return tokenizer(full_prompt, truncation=True, max_length=self.model_max_length, padding='max_length')
 
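A note on the new loading calls (a minimal sketch, not part of the diff): load_dataset('csv', data_files=...) returns a DatasetDict whose only split is named 'train' regardless of which file it was given, so each of the three attributes above now holds a dict rather than a bare Dataset. An illustration, assuming the CSVs exist locally:

# Minimal sketch; 'validation.csv' is a stand-in local file.
from datasets import load_dataset

ds = load_dataset('csv', data_files='validation.csv')
print(ds)                # DatasetDict({'train': Dataset({...})}) -- the key is 'train' by default
eval_rows = ds['train']  # the actual rows, even though they came from validation.csv
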
2 changes: 1 addition & 1 deletion finetune-mistral.py
@@ -128,7 +128,7 @@ def train(self):
             optim=self.optimizer,
             logging_dir='./logs',  # Directory for storing logs
             save_strategy='steps',  # Save the model checkpoint every logging step
-            save_steps=10,  # Save checkpoints every 50 steps
+            save_steps=50,  # Save checkpoints every 50 steps
             evaluation_strategy='steps',  # Evaluate the model every logging step
             eval_steps=50,  # Evaluate and save checkpoints every 50 steps
             do_eval=self.do_eval,  # Perform evaluation at the end of training
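
The old value contradicted its own comment (save_steps=10 next to "every 50 steps"); the fix also lines the save schedule up with eval_steps=50, so every evaluation point has a matching checkpoint. A hedged sketch of the relevant arguments (output_dir is a placeholder, the rest mirrors the diff):

# Sketch of the aligned schedules; output_dir is assumed, not shown in this hunk.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./outputs',       # placeholder
    logging_dir='./logs',
    save_strategy='steps',
    save_steps=50,                # now matches eval_steps and the comment
    evaluation_strategy='steps',
    eval_steps=50,
)
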
13 changes: 12 additions & 1 deletion inference-mistral.py
@@ -27,13 +27,24 @@ def __init__(self, model_path: str, checkpoint_path: str) -> None:
         self.ft_model = PeftModel.from_pretrained(model, checkpoint_path).eval()
 
     def generate_response(self, prompt: str, max_tokens: int = 50) -> str:
-        inputs = self.tokenizer(prompt, return_tensors='pt')
+        inputs = self.prepare_prompt(prompt)
         with torch.no_grad():
             logger.info('Generating up to %d tokens...', max_tokens)
             outputs = self.ft_model.generate(**inputs, max_length=max_tokens, pad_token_id=2)
         logger.info('Decoding...')
         return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
+    def prepare_prompt(self, prompt):
+        test_prompt = f"""Given a target sentence construct the underlying meaning representation of the input sentence as a single function with attributes and attribute values.
+This function should describe the target string accurately and the function must be one of the following ['inform', 'request', 'give_opinion', 'confirm', 'verify_attribute', 'suggest', 'request_explanation', 'recommend', 'request_attribute'].
+The attributes must be one of the following: ['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating', 'genres', 'player_perspective', 'has_multiplayer', 'platforms', 'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier']
+
+### Target sentence:
+{prompt}
+### Meaning representation:
+"""
+        return self.tokenizer(test_prompt, return_tensors='pt')
 
 
 def run(args):
     inference = ModelInference(
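
With prepare_prompt() in place, callers pass a plain target sentence and the template wrapping happens inside the class. A hypothetical usage sketch (both constructor values are placeholders; the prompt and max_tokens mirror the new valohai.yaml defaults below):

# Hypothetical usage; model_path and checkpoint_path values are placeholders.
inference = ModelInference(
    model_path='mistralai/Mistral-7B-v0.1',
    checkpoint_path='/path/to/best_mistral_checkpoint',
)
# The caller now supplies a raw target sentence; prepare_prompt() wraps it
# in the instruction template before tokenizing.
print(inference.generate_response(
    "You mean Tony Hawk's Pro Skater 3, the 2001 sports game?",
    max_tokens=150,
))
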
15 changes: 6 additions & 9 deletions valohai.yaml
@@ -1,6 +1,6 @@
 - step:
     name: data-preprocess
-    image: sofiiavalohai/llm-toolkit:v1.0
+    image: valohai/llm-toolkit:0.1
     environment: staging-aws-eu-west-1-p3-2xlarge
     command:
       - pip install -r requirements.in
@@ -15,14 +15,13 @@
     inputs:
       - name: dataset
        default:
-          - s3://dd-sample-bucket/mistral/gem-viggo-dataset/viggo.py
           - s3://dd-sample-bucket/mistral/gem-viggo-dataset/test.csv
           - s3://dd-sample-bucket/mistral/gem-viggo-dataset/train.csv
           - s3://dd-sample-bucket/mistral/gem-viggo-dataset/validation.csv
 
 - step:
     name: finetune
-    image: sofiiavalohai/llm-toolkit:v1.0
+    image: valohai/llm-toolkit:0.1
     environment: staging-aws-eu-west-1-p3-2xlarge
     command:
       - pip install -r requirements.in
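
Dropping viggo.py (the old dataset loader script) leaves the step consuming the three CSVs directly. The files of a single Valohai input are expected to land in one directory, which is what lets data-preprocess.py join the split filenames onto data_path. A sketch under that assumption (the mount path is illustrative, not confirmed by the diff):

# Sketch; /valohai/inputs/dataset is an assumed mount point for the input above.
import os

data_path = '/valohai/inputs/dataset'
for split in ('train', 'validation', 'test'):
    print(os.path.join(data_path, f'{split}.csv'))  # e.g. /valohai/inputs/dataset/train.csv
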
@@ -41,7 +40,7 @@
         default: 5
       - name: max_steps
         type: integer
-        default: 15
+        default: 128
       - name: learning_rate
         type: float
         default: 2.5e-5
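
The bump matters because of the checkpoint schedule fixed above: with saves every 50 steps, a 15-step run would finish before the first checkpoint was ever written, while 128 steps yields two. A quick sketch of that arithmetic:

# Checkpoint steps for a given budget (save_steps=50 as in finetune-mistral.py).
max_steps, save_steps = 128, 50
print([s for s in range(save_steps, max_steps + 1, save_steps)])  # [50, 100]
# with the old max_steps=15 this list was empty -- no checkpoint at all
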
@@ -58,7 +57,7 @@
 
 - step:
     name: inference
-    image: sofiiavalohai/llm-toolkit:v1.0
+    image: valohai/llm-toolkit:0.1
     environment: staging-aws-eu-west-1-p3-2xlarge
     command:
       - pip install -r requirements.in
@@ -68,15 +67,13 @@
         default: "mistralai/Mistral-7B-v0.1"
       - name: prompt
         type: string
-        default: "give_opinion(name[SpellForce 3], rating[poor], genres[real-time strategy, role-playing], player_perspective[bird view])"
+        default: "You mean Tony Hawk's Pro Skater 3, the 2001 sports game?"
       - name: max_tokens
         type: integer
-        default: 305
+        default: 150
     inputs:
       - name: finetuned-checkpoint
         default: dataset://mistral-models/best_mistral_checkpoint
-      - name: test_data
-        default: dataset://viggo/dev_test
 
 - pipeline:
     name: training-pipeline