Based fork #2

Open · wants to merge 9 commits into main
111 changes: 111 additions & 0 deletions launch.py
@@ -0,0 +1,111 @@
import os
from datetime import datetime
from typing import List, Optional

import click


MAX_WORKERS_PER_GPU = 1


def execute_config(
model: str,
task: str,
batch_size: int,
limit: int,
output_dir: str
):
    # Run the harness CLI in a subprocess so each (model, task) config gets its own process
import subprocess

output_dir = os.path.join(output_dir, model, task)

args = [
"lm_eval",
"--model", "based_lm",
"--model_args", f"checkpoint_name={model}",
"--tasks", task,
"--device", "cuda:0",
"--batch_size", str(batch_size),
"--log_samples",
"--output_path", output_dir
]

if limit is not None:
args.extend(["--limit", str(limit)])

subprocess.run(args)



@click.command()
@click.option("-m", "--model", type=str, multiple=True)
@click.option("-t", "--task", type=str, multiple=True)
@click.option("-p", "--parallelize", is_flag=True)
@click.option("--gpus", default=None, type=str)
@click.option("--batch-size", default=8, type=int)
@click.option("--limit", default=None, type=int)
def main(
model: List[str],
task: List[str],
batch_size: int,
limit: Optional[int],
parallelize: bool,
gpus: str
):

if gpus is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

    # Build one config per (model, task) pair
configs = [
{"model": m, "task": t} for m in model for t in task
]

use_ray = parallelize and len(configs) > 0
if use_ray:
import ray
        # Disable Ray's memory monitor: it was killing workers for OOM when that did not appear necessary
os.environ["RAY_memory_monitor_refresh_ms"] = "0"
ray.init(ignore_reinit_error=True, log_to_driver=True)

print(f"Running sweep with {len(configs)} configs")

output_dir = f"output/{datetime.now().strftime('%y-%m-%d_%H-%M')}"

    # Run the configs serially, or fan them out with Ray when --parallelize is set
if not use_ray:
for config in configs:
execute_config(
**config,
batch_size=batch_size,
limit=limit,
output_dir=output_dir
)
else:
completed = 0
total = len(configs)
print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")

        # Reserve a share of a GPU per worker; use true division so the request stays
        # fractional (integer division would request 0 GPUs whenever MAX_WORKERS_PER_GPU > 1)
        remote = ray.remote(num_gpus=(1 / MAX_WORKERS_PER_GPU))(execute_config)
futures = [remote.remote(**config, batch_size=batch_size, limit=limit, output_dir=output_dir) for config in configs]

while futures:
complete, futures = ray.wait(futures)
completed += len(complete)
print(f"Completed: {completed} ({completed / total:0.1%}) | Total: {total}")

ray.shutdown()



if __name__ == "__main__":
main()
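For reference, a minimal invocation of this launcher might look like the following (model and task names are illustrative; the flags map one-to-one onto the click options above):

python launch.py --gpus 0,1 --batch-size 16 --limit 100 -p \
    -m hazyresearch/based-360m -m hazyresearch/mamba-360m \
    -t swde -t squad_completion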
3 changes: 2 additions & 1 deletion lm_eval/api/task.py
@@ -260,6 +260,7 @@ def download(
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
@@ -430,7 +431,7 @@ def build_all_requests(
if not isinstance(inst, list):
inst = [inst]

instances.append(inst)
instances.extend(inst)

# now flatten, this is to allow slicing to work with pickles

1 change: 0 additions & 1 deletion lm_eval/models/__init__.py
@@ -13,7 +13,6 @@
mamba_open_lm
)


# TODO: implement __all__


65 changes: 65 additions & 0 deletions lm_eval/models/based_lm.py
@@ -0,0 +1,65 @@
import re
from typing import Optional

import torch
from transformers import AutoTokenizer

from based.utils.hf import load_config_hf

from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("based_lm")
class BasedLMWrapper(HFLM):
def __init__(
self,
        checkpoint_name: str = 'hazyresearch/based-1.3b',
        arch: Optional[str] = None,
device: str = "cuda",
**kwargs
) -> None:

if arch is None:
arch = checkpoint_name.split("/")[1].split("-")[0]

        assert arch in ['based', 'mamba', 'attn'], "`arch` must be one of 'based', 'mamba', or 'attn'"

if "backend" in kwargs:
# based currently only supports causal models
assert kwargs["backend"] == "causal"

self.checkpoint_name = checkpoint_name

if arch == "based":
from based.models.gpt import GPTLMHeadModel
model = GPTLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device) #.to(dtype=torch.float16)
elif arch == "mamba":
from based.models.mamba import MambaLMHeadModel
model = MambaLMHeadModel.from_pretrained_hf(pretrained_model_name=self.checkpoint_name, device=device)
elif arch == "attn":
            from based.models.transformer.gpt import GPTLMHeadModel, GPT2Config, state_dict_from_pretrained  # TODO: construct a loading function
config_data = load_config_hf(self.checkpoint_name)
config = GPT2Config(**config_data)
model = GPTLMHeadModel(config=config, device=device, dtype=torch.float16)
state_dict = state_dict_from_pretrained(self.checkpoint_name, dtype=torch.float16)
# remove the 'model.' prefix from the keys
            state_dict = {re.sub(r"^model\.", "", k): v for k, v in state_dict.items()}
            # drop metric-tracking keys (e.g. "train_metrics.num-tokens.count") that are not model weights
            state_dict = {k: v for k, v in state_dict.items() if "metrics" not in k}
model.load_state_dict(state_dict)
else:
raise ValueError(f"Unsupported model {arch}")

tokenizer_name = kwargs.get("tokenizer", "gpt2")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

model.device = device

super().__init__(
pretrained=model,
# set appropriate defaults for tokenizer, max length, etc
backend=kwargs.get("backend", "causal"),
max_length=kwargs.get("max_length", 2048),
tokenizer=tokenizer,
device=device,
**kwargs,
)
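Once registered under "based_lm", the wrapper can be driven from the CLI (see run_harness.sh below) or from Python. A minimal sketch, assuming the lm_eval.simple_evaluate API of harness v0.4 (checkpoint and task names are illustrative):

import lm_eval

# "based_lm" resolves to BasedLMWrapper via the registry decorator above
results = lm_eval.simple_evaluate(
    model="based_lm",
    model_args="checkpoint_name=hazyresearch/based-360m",
    tasks=["swde"],
    batch_size=8,
    limit=100,
)
print(results["results"])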
2 changes: 1 addition & 1 deletion lm_eval/tasks/fda/swde.yaml
@@ -1,2 +1,2 @@
task: fda
class: !function task.FDA
class: !function task.FDA
5 changes: 2 additions & 3 deletions lm_eval/tasks/fda/task.py
@@ -34,7 +34,7 @@ def doc_to_text(self, doc):

def doc_to_target(self, doc):
return doc["value"]

def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
@@ -92,9 +92,8 @@ def higher_is_better(self):
"contains": True, # Exact match (the normalized answer exactly match the gold answer
}


def contains_score(prediction: str, labels: List[str]):
return max(
int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction)))
for label in labels
)
)
Empty file.
2 changes: 2 additions & 0 deletions lm_eval/tasks/squad_completion/squad_completion.yaml
@@ -0,0 +1,2 @@
task: squad_completion
class: !function task.SQUADCompletion
102 changes: 102 additions & 0 deletions lm_eval/tasks/squad_completion/task.py
@@ -0,0 +1,102 @@
"""
"""
from typing import List

import re
import numpy as np

from lm_eval.api.task import ConfigurableTask
from lm_eval.api.instance import Instance


class SQUADCompletion(ConfigurableTask):
VERSION = 0
DATASET_PATH = "hazyresearch/based-squad"
DATASET_NAME = "default"

def __init__(self):
super().__init__(config={'metadata': {'version': self.VERSION}})

def has_training_docs(self):
return False

def has_validation_docs(self):
return True

def has_test_docs(self):
return False

def validation_docs(self):
return self.dataset["validation"]

def doc_to_text(self, doc):
return doc["text"]

def doc_to_target(self, doc):
return doc["value"]

def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.

:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""

return [
Instance(
request_type="generate_until",
doc=doc,
arguments=(ctx, {"until": ["\n"], "max_gen_toks": 48}),
idx=0,
**kwargs,
)
]

def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document

:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# continuation, (logprob_unanswerable, _) = results
continuation = results

return {
"contains": contains_score(continuation[0], [doc["value"]])
}

def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"contains": np.mean, # Exact match (the normalized answer exactly match the gold answer)
}

def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"contains": True, # Exact match (the normalized answer exactly match the gold answer
}


def contains_score(prediction: str, labels: List[str]):
return max(
int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction)))
for label in labels
)
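As a quick sanity check, contains_score performs a case-insensitive substring match of each gold label against the generation (illustrative values):

# 1: "paris" occurs (case-insensitively) in the generation
contains_score("The answer is Paris.", ["paris"])
# 0: neither label occurs in the generation
contains_score("I don't know.", ["Paris", "London"])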
2 changes: 1 addition & 1 deletion lm_eval/tasks/swde/swde.yaml
@@ -1,2 +1,2 @@
task: swde
class: !function task.SWDE
class: !function task.SWDE
1 change: 0 additions & 1 deletion lm_eval/tasks/swde/task.py
@@ -96,7 +96,6 @@ def higher_is_better(self):
"contains": True, # Exact match (the normalized answer exactly match the gold answer
}


def contains_score(prediction: str, labels: List[str]):
return max(
int(bool(re.search(re.compile(re.escape(label), re.IGNORECASE), prediction)))
38 changes: 38 additions & 0 deletions run_harness.sh
@@ -0,0 +1,38 @@
#!/bin/bash

# TASKS=hellaswag,lambada_openai,piqa,arc_easy,arc_challenge,winogrande
TASKS=swde
DEVICE=cuda:0
BATCH_SIZE=32

MODELS=("based-360m" "mamba-360m" "attn-360m")

# for MODEL in "${MODELS[@]}"; do

# lm_eval \
# --model based_lm \
# --model_args checkpoint_name=hazyresearch/$MODEL \
# --tasks $TASKS \
# --device $DEVICE \
# --batch_size $BATCH_SIZE \
# --limit 100 \
# --log_samples \
# --output_path output/$MODEL

# done


python launch.py \
    --batch-size 32 \
    -m "hazyresearch/based-360m" \
    -m "hazyresearch/mamba-360m" \
    -m "hazyresearch/attn-360m" \
    -t "swde" \
    -t "hellaswag" \
    -t "lambada_openai" \
    -t "piqa" \
    -t "arc_easy" \
    -t "arc_challenge" \
    -t "winogrande" \
    -p

# To cap the number of examples per task, add a "--limit 1000 \" line before "-p"
# (a commented line in the middle of the continuation would end the command and drop "-p")