Skip to content

Commit

Permalink
CR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Feb 14, 2024
1 parent 423d28c commit 0ffec7c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/autora/doc/pipelines/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def load_data(data_file: str) -> Tuple[List[str], List[str]]:

with jsonlines.open(data_file) as reader:
items = [item for item in reader]
inputs = [f"{item['instruction']}" for item in items]
inputs = [item["instruction"] for item in items]
labels = [item["output"] for item in items]
return inputs, labels

Expand Down
9 changes: 5 additions & 4 deletions src/autora/doc/pipelines/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ def gen() -> Iterable[Dict[str, str]]:

def fine_tune(base_model: str, new_model_name: str, dataset: Dataset) -> None:
cuda_available = torch.cuda.is_available()
config = {}

# train using 4 bit quantization for lower GPU memory usage
kwargs = (
{"device_map": "auto", "quantization_config": get_quantization_config()} if cuda_available else {}
)
if cuda_available:
config.update({"device_map": "auto", "quantization_config": get_quantization_config()})

model = AutoModelForCausalLM.from_pretrained(
base_model,
**kwargs,
**config,
)
model.config.use_cache = False
model.config.pretraining_tp = 1
Expand Down

0 comments on commit 0ffec7c

Please sign in to comment.