git pull main - resolve merge conflicts
anujsinha3 committed Feb 14, 2024
2 parents fbbdd82 + aac9880 commit 9853ca2
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions src/autora/doc/pipelines/main.py
@@ -1,17 +1,24 @@
import itertools
import logging
from timeit import default_timer as timer
-from typing import Dict, List, Tuple
+from typing import Dict, List

import torch
import typer

from autora.doc.classes.EvalResult import EvalResult
+from autora.doc.pipelines.data import load_data
+from autora.doc.pipelines.metrics import eval_bleu_meteor, eval_semscore
+from autora.doc.pipelines.train import fine_tune, get_dataset
from autora.doc.runtime.predict_hf import Predictor
from autora.doc.runtime.prompts import PROMPTS, PromptIds
from autora.doc.util import get_prompts_from_file

+# For inference
+DEFAULT_INFERENCE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
+# For training
+DEFAULT_BASE_MODEL = "autora-doc/Llama-2-7b-chat-hf"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
@@ -24,7 +31,7 @@
@app.command(help="Evaluate a model for code-to-documentation generation for all prompts in the prompts_file")
def eval_prompts(
    data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
-    model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
+    model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
    prompts_file: str = typer.Argument(..., help="JSON file with a list of dictionary of prompts"),
    param: List[str] = typer.Option(
        [], help="Additional float parameters to pass to the model as name=float pairs"
@@ -62,7 +69,7 @@ def eval_prompts(
@app.command(help="Evaluate model on a data file")
def eval(
    data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
-    model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
+    model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
    prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
    param: List[str] = typer.Option(
        [], help="Additional float parameters to pass to the model as name=float pairs"
@@ -89,16 +96,6 @@ def eval(
    return eval_prompt(data_file, pred, prompt, param_dict)


-def load_data(data_file: str) -> Tuple[List[str], List[str]]:
-    import jsonlines
-
-    with jsonlines.open(data_file) as reader:
-        items = [item for item in reader]
-    inputs = [f"{item['instruction']}" for item in items]
-    labels = [item["output"] for item in items]
-    return inputs, labels
-
-
def eval_prompt(
    data_file: str, pred: Predictor, prompt: str, param_dict: Dict[str, float], prompt_index: int = 0
) -> EvalResult:
@@ -133,7 +130,7 @@ def eval_prompt(
@app.command()
def generate(
    python_file: str = typer.Argument(..., help="Python file to generate documentation for"),
-    model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
+    model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
    output: str = typer.Option("output.txt", help="Output file"),
    prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
    param: List[str] = typer.Option(
@@ -161,6 +158,18 @@ def import_model(model_name: str) -> None:
    pass


+@app.command()
+def train(
+    new_model_name: str = typer.Argument(..., help="File name for the fine-tuned model"),
+    dataset: str = typer.Argument(..., help="Path to the jsonl file with training data"),
+    base_model: str = typer.Option(
+        DEFAULT_BASE_MODEL, help="Path to the base Huggingface model to fine-tune"
+    ),
+) -> None:
+    ds = get_dataset(dataset)
+    fine_tune(base_model, new_model_name, ds)
+
+
@app.command()
def import_data(code_file: str, text_file: str, output_file: str = "data.jsonl") -> None:
    from pathlib import Path

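Usage sketch (not part of this commit's diff): one way the new train sub-command could be exercised end to end, assuming the Typer app in autora.doc.pipelines.main is installed and importable. The output model name and dataset path are placeholders; the base model is the DEFAULT_BASE_MODEL introduced in the diff.

from typer.testing import CliRunner

from autora.doc.pipelines.main import app

runner = CliRunner()
# Fine-tune the default base model on a local JSONL dataset; the resulting
# model is written under the given name (both values are placeholders).
result = runner.invoke(
    app,
    [
        "train",
        "my-finetuned-model",
        "data/train.jsonl",
        "--base-model",
        "autora-doc/Llama-2-7b-chat-hf",
    ],
)
print(result.output)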