added starcoder support #18

Merged · 4 commits · May 30, 2023
21 changes: 18 additions & 3 deletions README.md
@@ -31,9 +31,9 @@ Everyone is welcome to contribute!
As this is early-stage work, there are lots of improvements that can be made in the future, and you're welcome to contribute!

- [x] Get rid of Python 3.10 dependency
- - [ ] Clean basic code smells
+ - [x] Clean basic code smells
- [ ] Improve support for OpenAI and Anthropic
- - [ ] Add support for other LLM providers
+ - [x] Add support for other LLM providers
- [ ] Add support for locally hosted models
- [ ] Pass model-related options to templating engine to allow for model-specific prompts
- [ ] Add support for testing against expectations (elapsed_time, tokens_used)
@@ -122,6 +122,14 @@ gcloud config set project <your-project-id>
gcloud auth application-default login
```


### Code completion
Using the StarCoder model, you can get code completion for a variety of languages. Here's a quick example of how to use it (the prompt file `examples/code/completion.yaml`, reproduced below):

```sh
$ prr run ./examples/code/completion.yaml
```
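
For reference, here is the prompt file used in the command above, as added by this PR (see the full file diff further down):

```yaml
version: 1
prompt:
  content: "def prompt_text_from_template(self, template):"
services:
  models:
    - 'bigcode/starcoder/starcoder'
  options:
    temperature: 0.7
    max_tokens: 250
    top_p: 0.9
    top_k: 40
```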

### Run a prompt from a simple text file containing just a prompt

Let's create a simple text file and call it `dingo` with the following content:
@@ -181,7 +189,7 @@ If you refer to another template within your template, changes to that file will
If you save your prompt often and you're worried about running it too frequently, you can use the `-c` option specific to the `watch` command, which enforces a cooldown of the given number of seconds after every run before it executes on your changes again.

```
- $ ./prr watch -c 15 ./subconcepts-of-buddhism
+ $ prr watch -c 15 ./subconcepts-of-buddhism
```


@@ -418,6 +426,8 @@ stats:

* OpenAI/chat - https://platform.openai.com/docs/guides/chat
* Anthropic/complete - https://console.anthropic.com/docs/api
* Google Vertex AI PaLM - https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview
* Starcoder - https://huggingface.co/bigcode/starcoder

## Development

@@ -450,6 +460,11 @@ OPENAI_API_KEY="sk-..."
ANTHROPIC_API_KEY="sk-ant-..."

DEFAULT_SERVICE="openai/chat/gpt-3.5-turbo"
# https://console.cloud.google.com
GOOGLE_PROJECT="gcp-project-id"
GOOGLE_LOCATION="us-central1"
# https://huggingface.co/settings/tokens
HF_TOKEN="hf_..."
```

You can also use `DEFAULT_SERVICE` to specify the model you want to use by default, but otherwise you're good to go!
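
Note that `DEFAULT_SERVICE` uses the same `provider/service/model` format as the `models` list in prompt files, so making StarCoder the default should presumably look like this (an assumption based on the examples above; this diff doesn't show it):

```
DEFAULT_SERVICE="bigcode/starcoder/starcoder"
```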
11 changes: 11 additions & 0 deletions examples/code/completion.yaml
@@ -0,0 +1,11 @@
version: 1
prompt:
  content: "def prompt_text_from_template(self, template):"
services:
  models:
    - 'bigcode/starcoder/starcoder'
  options:
    temperature: 0.7
    max_tokens: 250
    top_p: 0.9
    top_k: 40
144 changes: 141 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

55 changes: 55 additions & 0 deletions prr/services/providers/bigcode/starcoder.py
@@ -0,0 +1,55 @@
from text_generation import Client

from prr.runner.request import ServiceRequest
from prr.runner.response import ServiceResponse
from prr.utils.config import load_config

config = load_config()

HF_TOKEN = config.get("HF_TOKEN", None)
API_URL = "https://api-inference.huggingface.co/models/bigcode/starcoder"

# Special tokens StarCoder uses for fill-in-the-middle (FIM) completion.
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"

# Marker prompt authors can place where the model should fill code in.
FIM_INDICATOR = "<FILL_HERE>"


class ServiceBigcodeStarcoder:
    provider = "bigcode"
    service = "starcoder"

    def run(self, prompt, service_config):
        self.service_config = service_config
        self.prompt = prompt

        # Hosted inference endpoint, authenticated with the Hugging Face token.
        client = Client(
            API_URL,
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        )

        options = self.service_config.options

        prompt_text = prompt.template_text()

        service_request = ServiceRequest(service_config, prompt_text)

        response = client.generate(
            prompt_text,
            temperature=options.temperature,
            max_new_tokens=options.max_tokens,
            top_p=options.top_p,
            # repetition_penalty=options.repetition_penalty,
        )

        service_response = ServiceResponse(
            response.generated_text,
            {
                "tokens_used": response.details.generated_tokens,
                "stop_reason": response.details.finish_reason,
            },
        )

        return service_request, service_response
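
The `FIM_*` constants above are defined but not referenced in `run()` as shown in this diff. For illustration, here is a minimal sketch of how a fill-in-the-middle prompt could be assembled from them, using a hypothetical `build_fim_prompt` helper that is not part of this PR:

```python
def build_fim_prompt(prompt_text):
    # Pass plain prompts through untouched.
    if FIM_INDICATOR not in prompt_text:
        return prompt_text
    # Rewrite "code before <FILL_HERE> code after" into StarCoder's
    # fill-in-the-middle token order: prefix, suffix, then middle.
    prefix, suffix = prompt_text.split(FIM_INDICATOR, 1)
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
```

The token ordering (`<fim_prefix>…<fim_suffix>…<fim_middle>`) matches what the StarCoder model card documents for infilling; how prr ultimately wires this in may differ.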
2 changes: 2 additions & 0 deletions prr/services/service_registry.py
@@ -1,4 +1,5 @@
from prr.services.providers.anthropic.complete import ServiceAnthropicComplete
from prr.services.providers.bigcode.starcoder import ServiceBigcodeStarcoder
from prr.services.providers.google.chat import ServiceGoogleChat
from prr.services.providers.google.complete import ServiceGoogleComplete
from prr.services.providers.openai.chat import ServiceOpenAIChat
@@ -17,6 +18,7 @@ def register(self, service_class):
    def register_all_services(self):
        self.register(ServiceOpenAIChat)
        self.register(ServiceAnthropicComplete)
        self.register(ServiceBigcodeStarcoder)
        self.register(ServiceGoogleComplete)
        self.register(ServiceGoogleChat)
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -15,6 +15,8 @@ rich = "^13.3.5"
Jinja2 = "^3.1.2"
pyyaml = "^6.0"
google-cloud-aiplatform = "^1.25.0"
huggingface-hub = "^0.14.1"
text-generation = "^0.5.2"

[tool.poetry.dev-dependencies]
pre-commit = "^2.12.0"