limit deps of cmd line aici; fixes #32
mmoskal committed Jan 19, 2024
1 parent e60a2f2 commit d57dec8
Showing 8 changed files with 71 additions and 63 deletions.
5 changes: 3 additions & 2 deletions harness/run_hf.py
@@ -8,6 +8,7 @@
 import torch
 import time
 import pyaici
+import pyaici.comms
 
 from transformers import (
     AutoTokenizer,
@@ -25,7 +26,7 @@
 
 class AsyncLogitProcessor(LogitsProcessor, BaseStreamer):
     def __init__(
-        self, runner: pyaici.AiciRunner, module_id: str, module_arg: str
+        self, runner: pyaici.comms.AiciRunner, module_id: str, module_arg: str
     ) -> None:
         super().__init__()
         self.runner = runner
@@ -72,7 +73,7 @@ def main(args):
     )
     model = cast(PreTrainedModel, model)
 
-    runner = pyaici.AiciRunner.from_cli(args)
+    runner = pyaici.runner_from_cli(args)
 
     arg = ""
     if args.aici_module_arg:
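For context, a minimal sketch of how a harness might wire AsyncLogitProcessor into HF generate(); this wiring is not part of the commit, and the model, tokenizer, args, and module id are assumed from the harness's surrounding setup:

# Hypothetical usage: AsyncLogitProcessor subclasses both LogitsProcessor and
# BaseStreamer, so it is passed to generate() in both roles.
from transformers import LogitsProcessorList

runner = pyaici.runner_from_cli(args)
proc = AsyncLogitProcessor(runner, "<module-id>", "")
inputs = tokenizer("Hello", return_tensors="pt")
model.generate(
    **inputs,
    max_new_tokens=32,
    logits_processor=LogitsProcessorList([proc]),
    streamer=proc,
)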
2 changes: 1 addition & 1 deletion harness/run_vllm.py
@@ -12,7 +12,7 @@ def main(args: argparse.Namespace):
     engine_args = EngineArgs.from_cli_args(args)
 
     # build it first, so it fails fast
-    aici = pyaici.AiciRunner.from_cli(args)
+    aici = pyaici.runner_from_cli(args)
 
     engine = LLMEngine.from_engine_args(engine_args)
     pyaici.vllm.install(aici)
2 changes: 1 addition & 1 deletion harness/vllm_server.py
@@ -455,7 +455,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     served_model = args.model
 
     # build it first, so it fails fast
-    aici = pyaici.AiciRunner.from_cli(args)
+    aici = pyaici.runner_from_cli(args)
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
1 change: 1 addition & 0 deletions llama-cpp-low/llama.cpp
Submodule llama.cpp added at 381ee1
58 changes: 53 additions & 5 deletions pyaici/__init__.py
@@ -1,6 +1,54 @@
-from pyaici.comms import AiciRunner, add_cli_args
+import argparse
 
-__all__ = [
-    "AiciRunner",
-    "add_cli_args",
-]
+
+def runner_from_cli(args):
+    from pyaici.comms import AiciRunner
+
+    aici = AiciRunner(
+        rtpath=args.aici_rt,
+        tokenizer=args.aici_tokenizer,
+        trace_file=args.aici_trace,
+        rtargs=args.aici_rtarg,
+    )
+    return aici
+
+
+def add_cli_args(parser: argparse.ArgumentParser, single=False):
+    parser.add_argument(
+        "--aici-rt",
+        type=str,
+        required=True,
+        help="path to aicirt",
+    )
+    parser.add_argument(
+        "--aici-tokenizer",
+        type=str,
+        default="llama",
+        help="tokenizer to use; llama, gpt4, ...",
+    )
+    parser.add_argument(
+        "--aici-trace",
+        type=str,
+        help="save trace of aicirt interaction to a JSONL file",
+    )
+    parser.add_argument(
+        "--aici-rtarg",
+        "-A",
+        type=str,
+        default=[],
+        action="append",
+        help="pass argument to aicirt process",
+    )
+
+    if single:
+        parser.add_argument(
+            "--aici-module",
+            type=str,
+            required=True,
+            help="id of the module to run",
+        )
+        parser.add_argument(
+            "--aici-module-arg",
+            type=str,
+            default="",
+            help="arg passed to module (filename)",
+        )
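This is the core of the dependency cut: runner_from_cli defers the pyaici.comms import into the function body, so a bare `import pyaici` no longer pulls in the runtime's dependencies. A minimal sketch of the intended wiring (the aicirt path and module id are placeholders):

import argparse
import pyaici

parser = argparse.ArgumentParser()
pyaici.add_cli_args(parser, single=True)
# --aici-rt and --aici-module are required; repeated -A flags accumulate
# into a list via action="append".
args = parser.parse_args(["--aici-rt", "./aicirt", "--aici-module", "<module-id>"])
runner = pyaici.runner_from_cli(args)  # pyaici.comms is imported only here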
10 changes: 5 additions & 5 deletions pyaici/cli.py
@@ -1,11 +1,11 @@
 import subprocess
-import ujson
+import json
 import sys
 import os
 import argparse
 
 from . import rest, jssrc
-from . import add_cli_args, AiciRunner
+from . import add_cli_args, runner_from_cli
 
 
 def cli_error(msg: str):
@@ -31,7 +31,7 @@ def build_rust(folder: str):
         stdout=-1,
         check=True,
     )
-    info = ujson.decode(r.stdout)
+    info = json.loads(r.stdout)
     if len(info["workspace_default_members"]) != 1:
         cli_error("please run from project, not workspace, folder")
     pkg_id = info["workspace_default_members"][0]
@@ -90,7 +90,7 @@ def ask_completion(cmd_args, *args, **kwargs):
     os.makedirs("tmp", exist_ok=True)
     path = "tmp/response.json"
     with open(path, "w") as f:
-        ujson.dump(res, f, indent=1)
+        json.dump(res, f, indent=1)
     print(f"response saved to {path}")
     print("Usage:", res["usage"])
     print("Storage:", res["storage"])
@@ -239,7 +239,7 @@ def main_inner():
         sys.exit(0)
 
     if args.subcommand == "benchrt":
-        AiciRunner.from_cli(args).bench()
+        runner_from_cli(args).bench()
        sys.exit(0)
 
     if args.subcommand == "tags":
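The workspace_default_members field consumed in build_rust comes from cargo's JSON metadata; a sketch of the kind of invocation that plausibly produces r.stdout (the exact cargo flags are an assumption, not shown in this hunk):

import json
import subprocess

# Assumed invocation; note stdout=-1 in the diff is the same constant
# as subprocess.PIPE.
r = subprocess.run(
    ["cargo", "metadata", "--format-version=1", "--no-deps"],
    stdout=subprocess.PIPE,
    check=True,
)
info = json.loads(r.stdout)
pkg_id = info["workspace_default_members"][0]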
42 changes: 0 additions & 42 deletions pyaici/comms.py
@@ -615,45 +615,3 @@ def response_by_seq_id(self, seq_id: int) -> Dict[str, Any]:
         Get the response for a given batch entry ID.
         """
         return self.last_mid_response.get(str(seq_id), None)
-
-
-def add_cli_args(parser: argparse.ArgumentParser, single=False):
-    parser.add_argument(
-        "--aici-rt",
-        type=str,
-        required=True,
-        help="path to aicirt",
-    )
-    parser.add_argument(
-        "--aici-tokenizer",
-        type=str,
-        default="llama",
-        help="tokenizer to use; llama, gpt4, ...",
-    )
-    parser.add_argument(
-        "--aici-trace",
-        type=str,
-        help="save trace of aicirt interaction to a JSONL file",
-    )
-    parser.add_argument(
-        "--aici-rtarg",
-        "-A",
-        type=str,
-        default=[],
-        action="append",
-        help="pass argument to aicirt process",
-    )
-
-    if single:
-        parser.add_argument(
-            "--aici-module",
-            type=str,
-            required=True,
-            help="id of the module to run",
-        )
-        parser.add_argument(
-            "--aici-module-arg",
-            type=str,
-            default="",
-            help="arg passed to module (filename)",
-        )
14 changes: 7 additions & 7 deletions pyaici/rest.py
@@ -1,5 +1,5 @@
 import requests
-import ujson
+import json
 import os
 import urllib.parse
 import sys
@@ -47,7 +47,7 @@ def _mk_url(path: str) -> str:
 def response_error(kind: str, resp: requests.Response):
     text = resp.text
     try:
-        d = ujson.decode(text)
+        d = json.loads(text)
         if "message" in d:
             text = d["message"]
     except:
@@ -63,7 +63,7 @@ def req(tp: str, url: str, **kwargs):
     if log_level >= 4:
         print(f"{tp.upper()} {url} headers={headers}")
     if "json" in kwargs:
-        print(ujson.dumps(kwargs["json"]))
+        print(json.dumps(kwargs["json"]))
     return requests.request(tp, url, headers=headers, **kwargs)
@@ -126,7 +126,7 @@ def completion(
 ):
     if ignore_eos is None:
         ignore_eos = not not ast_module
-    json = {
+    data = {
         "model": "",
         "prompt": prompt,
         "max_tokens": max_tokens,
@@ -137,15 +137,15 @@
         "aici_arg": aici_arg,
         "ignore_eos": ignore_eos,
     }
-    resp = req("post", "completions", json=json, stream=True)
+    resp = req("post", "completions", json=data, stream=True)
     if resp.status_code != 200:
         raise response_error("completions", resp)
     texts = [""] * n
     logs = [""] * n
     full_resp = []
    storage = {}
     res = {
-        "request": json,
+        "request": data,
         "response": full_resp,
         "text": texts,
         "logs": logs,
@@ -161,7 +161,7 @@
             continue
         decoded_line: str = line.decode("utf-8")
         if decoded_line.startswith("data: {"):
-            d = ujson.decode(decoded_line[6:])
+            d = json.loads(decoded_line[6:])
             full_resp.append(d)
             if "usage" in d:
                 res["usage"] = d["usage"]
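The json-to-data rename above is not cosmetic: the local dict was named json, which would shadow the newly imported json module inside completion(), breaking the json.loads call later in the same function. A small illustration of the fixed pattern:

import json

data = {"model": "", "prompt": "Hello"}  # renamed from `json` to avoid shadowing
payload = json.dumps(data)               # the stdlib module still resolves here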
