This repo provides methods for using Meta's 7 billion parameter LLaMA model to generate song lyrics with a specific metric structure. If you've been yearning to rewrite the "Happy Birthday" song so that it's just about dogs, bragi can help :).
The core functionality of bragi is provided by the MetricGenerator class. If you're wondering, Bragi is the Norse god of poetry!
The library also provides wrappers around various methods for extracting metric information, such as syllable counts and rhyme schemes.
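To make "metric information" concrete, here is a minimal, hypothetical sketch of per-line syllable counts and a rhyme scheme. It is not bragi's implementation: bragi wraps espeak-based tooling, whereas this sketch uses a naive vowel-group heuristic for syllables and a shared-word-ending heuristic for rhymes. The helper names (count_syllables, syllables_per_line, rhyme_scheme) are illustrative only, and the "||" line separator is borrowed from the new_line_token used in the usage example.

```python
import re
from typing import List

# Hypothetical helpers; a naive heuristic, NOT espeak-based phonetics.
VOWEL_GROUPS = re.compile(r"[aeiouy]+")


def count_syllables(word: str) -> int:
    """Approximate a word's syllables by counting runs of vowels."""
    word = re.sub(r"[^a-z]", "", word.lower())
    if not word:
        return 0
    count = len(VOWEL_GROUPS.findall(word))
    # A trailing silent 'e' (as in "cake") usually adds no syllable,
    # but "-le" endings (as in "table") usually do.
    if word.endswith("e") and not word.endswith("le") and count > 1:
        count -= 1
    return max(count, 1)


def split_lines(text: str, new_line_token: str = "||") -> List[str]:
    """Split text on the line-break token, dropping empty lines."""
    return [line for line in text.split(new_line_token) if line.strip()]


def syllables_per_line(text: str, new_line_token: str = "||") -> List[int]:
    """Return the approximate syllable count of each line."""
    return [
        sum(count_syllables(w) for w in line.split())
        for line in split_lines(text, new_line_token)
    ]


def rhyme_scheme(text: str, new_line_token: str = "||", suffix_len: int = 2) -> str:
    """Label lines A, B, C, ... grouping lines whose last words share an ending."""
    labels = {}
    scheme = []
    for line in split_lines(text, new_line_token):
        words = re.findall(r"[a-z']+", line.lower())
        ending = words[-1][-suffix_len:] if words else ""
        if ending not in labels:
            labels[ending] = chr(ord("A") + len(labels))
        scheme.append(labels[ending])
    return "".join(scheme)


print(syllables_per_line("Happy birthday to you || Happy birthday to you ||"))
print(rhyme_scheme("I saw a cat || it ran away || it sat on a mat || then went to play"))
```

A real phonetic backend handles words like "fire" or "every" far better than vowel counting; the heuristic is only here to show the shape of the data (one number per line, one letter per line).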
MetricGenerator controls the metric structure of generated output by constraining the model's probability distribution over tokens. Specifically, at each inference step, tokens that would violate the target metric structure are masked. This is implemented via a custom logits warper.
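The masking idea can be sketched in plain Python. This is an illustration under stated assumptions, not bragi's logits warper (which operates on torch tensors inside the transformers generation loop): we assume each vocabulary entry has a precomputed syllable count, and that "violating the metric structure" means exceeding the syllables left on the current line. Tokens listed as free (mirroring the free_tokens argument in the usage example) are never masked. The function name mask_by_syllable_budget is hypothetical.

```python
import math
from typing import List, Sequence


def mask_by_syllable_budget(
    logits: Sequence[float],
    token_syllables: Sequence[int],
    remaining_budget: int,
    free_token_ids: Sequence[int] = (),
) -> List[float]:
    """Return a copy of `logits` where every token that would overshoot the
    remaining syllable budget gets -inf (probability zero after softmax).

    Hypothetical sketch: assumes one precomputed syllable count per token id.
    """
    if len(logits) != len(token_syllables):
        raise ValueError("need one syllable count per logit")
    free = set(free_token_ids)
    masked = []
    for token_id, (logit, syllables) in enumerate(zip(logits, token_syllables)):
        if token_id in free or syllables <= remaining_budget:
            masked.append(logit)  # allowed: fits in the line (or is free)
        else:
            masked.append(-math.inf)  # disallowed: would break the meter
    return masked


# Toy vocabulary: "dog", "puppy", "golden retriever", "||"
logits = [1.0, 2.0, 3.0, 0.5]
token_syllables = [1, 2, 5, 0]
print(mask_by_syllable_budget(logits, token_syllables, remaining_budget=2, free_token_ids=[3]))
```

Setting a logit to -inf rather than subtracting a penalty makes the constraint hard: after softmax the token's probability is exactly zero, so no sampling setting can select it.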
- Install cog
sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m`
sudo chmod +x /usr/local/bin/cog
- Clone the cog-llama repo:
git clone https://github.com/replicate/cog-llama.git
- Build the image
cd cog-llama
cog build
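- Open a shell in the container built by cog build:
cog run bash
- Inside that shell, install espeak. Each cog run invocation starts a new container, so packages installed with a separate cog run may not be present in later sessions:
apt-get update -y
apt-get install espeak -y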
- Clone this repo into the cog container
git clone https://github.com/joehoover/bragi.git
- Install requirements
cd bragi
pip install -r requirements.txt
- Install jupyterlab
pip install jupyterlab
- Launch jupyterlab
jupyter lab --allow-root
- Click the last link generated by the jupyter lab process.
- To run bragi outside of cog, you need to install espeak.
On macOS:
brew install espeak
On Linux (Debian/Ubuntu):
apt-get update -y
apt-get install espeak -y
- You also need to make sure torch is installed. I don't like installing torch with poetry, so it's not specified in the pyproject.toml. If your environment doesn't already have torch, run:
pip install torch
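A quick way to confirm both external dependencies are in place before loading a multi-gigabyte model is a small check like the one below. It is a hypothetical helper, not part of bragi: it only checks for an espeak executable on PATH and for an importable torch package, using standard-library calls.

```python
import importlib.util
import shutil


def check_environment() -> dict:
    """Report whether the espeak CLI and the torch package are available."""
    return {
        # shutil.which returns the executable's path, or None if not on PATH
        "espeak": shutil.which("espeak") is not None,
        # find_spec returns None when the package cannot be imported
        "torch": importlib.util.find_spec("torch") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'found' if ok else 'MISSING'}")
```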
See this notebook for a complete example. In general, usage looks like this:
from bragi.metric_generator import MetricGenerator
# Released transformers versions (4.28+) name these classes LlamaForCausalLM and LlamaTokenizer.
from transformers import LLaMAForCausalLM, LLaMATokenizer
import torch
CACHE_DIR = 'weights'
SEP = "<sep>"
MODEL_PATH = "/src/weights"
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Load model and tokenizer
model = LLaMAForCausalLM.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True).to(device)
tokenizer = LLaMATokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True)
# Initialize `MetricGenerator`
generator = MetricGenerator(model=model, tokenizer=tokenizer, device=device)
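# Example inputs (hypothetical values; see the notebook for the ones actually used).
# `text_init` appears to supply the template whose per-line syllable counts the
# output should follow, with '||' marking line breaks (matching `new_line_token`).
prompt = "Write a song about dogs:"
text_init = "Happy birthday to you || Happy birthday to you ||"
# Generate
torch.manual_seed(2)  # fix the seed so sampling is reproducible
output = generator(
    prompt=prompt,
    text_init=text_init,
    free_tokens=['||', '?', '.', ','],
    # syllable_budget=torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length=100,
    new_line_token='||',
    bad_words_ids=[[8876]],
)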
print('---text_init----')
print(text_init)
print('\n')
print('----output-----')
print(output)
print('\n')
print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")
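The call above combines the metric mask with ordinary sampling settings (do_sample=True, top_k=25, temperature=.7). As a rough illustration of how those pieces fit together, here is a plain-Python sketch that samples one token from logits that have already been masked. The function sample_next_token is hypothetical and is not part of bragi or transformers; transformers applies its own logits warpers internally, and their exact ordering may differ from this sketch.

```python
import math
import random
from typing import Optional, Sequence


def sample_next_token(
    logits: Sequence[float],
    temperature: float = 0.7,
    top_k: int = 25,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample a token id from temperature-scaled, top-k-filtered logits.

    Masked tokens (logit == -inf) are excluded, so they can never be chosen.
    """
    if temperature <= 0 or top_k < 1:
        raise ValueError("temperature must be > 0 and top_k >= 1")
    rng = rng or random.Random()
    # Lower temperature sharpens the distribution; -inf stays -inf.
    scaled = [x / temperature for x in logits]
    # Rank the unmasked tokens by scaled logit and keep the top k.
    ranked = sorted(
        (i for i, x in enumerate(scaled) if x != -math.inf),
        key=lambda i: scaled[i],
        reverse=True,
    )
    candidates = ranked[:top_k]
    if not candidates:
        raise ValueError("every token is masked")
    # Softmax weights (shifted by the max logit for numerical stability).
    best = scaled[candidates[0]]
    weights = [math.exp(scaled[i] - best) for i in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


masked_logits = [1.0, 2.0, -math.inf, 0.5]
print(sample_next_token(masked_logits, temperature=0.7, top_k=2, rng=random.Random(2)))
```

Because the mask is applied before sampling, top_k and temperature only redistribute probability among tokens that already fit the meter.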