batched inference for grammatical score #4

Open · Jack000 opened this issue Jul 5, 2022 · 3 comments

Comments

@Jack000 commented Jul 5, 2022

I noticed that the lm_score code processes a single sentence at a time, which is pretty slow when you're processing a large amount of data. I wrote a batched version, though it's a bit ugly; it increases processing speed by about 8x on a single 3090.

import math

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import BertForMaskedLM, BertTokenizerFast

# defined here so the snippet is self-contained; the repo's lm_score code already has a device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_lm_score(sentences, batch_tokens=42000):

    def score_batch(batch, tokenizer, model):
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)
        batch_scores = []

        with torch.no_grad():
            # mask out padding positions; cross_entropy ignores label -100 by default
            labels = inputs["input_ids"].clone()
            labels[inputs["input_ids"] == tokenizer.pad_token_id] = -100
            out = model(input_ids=inputs["input_ids"], labels=labels, attention_mask=inputs["attention_mask"], token_type_ids=inputs["token_type_ids"])
            logits = out['logits']

            # per-sentence perplexity from the batch logits
            for j in range(labels.shape[0]):
                loss = F.cross_entropy(logits[j].view(-1, tokenizer.vocab_size), labels[j].view(-1))
                batch_scores.append(math.exp(loss.item()))

        return batch_scores

    model_name = 'bert-base-cased'
    model = BertForMaskedLM.from_pretrained(model_name).to(device)
    model.eval()
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    lm_score = []

    # sort sentences by length so similarly sized sentences are padded together
    # (tokenizing here would be too slow, so string length is used as an approximation)
    sentences_flat = []
    for sent in sentences:
        for s in sent:
            sentences_flat.append((s, len(s)))

    sentences_flat.sort(key=lambda x: x[1], reverse=True)

    batches = []

    # greedily fill each batch until the character budget (batch_tokens) is exceeded
    current_batch_count = 0
    current_batch = []
    for sent in sentences_flat:
        current_batch.append(sent[0])
        current_batch_count += sent[1]
        if current_batch_count > batch_tokens:
            batches.append(current_batch)
            current_batch_count = 0
            current_batch = []

    if len(current_batch) > 0:
        batches.append(current_batch)

    score_dict = {}

    for batch in tqdm(batches):
        batch_score = score_batch(batch, tokenizer, model)
        for j, sent in enumerate(batch):
            score_dict[sent] = batch_score[j]

    for sentence in sentences:
        if len(sentence) == 0:
            lm_score.append(0.0)
            continue
        score_i = 0.0
        for x in sentence:
            if x in score_dict:
                score_i += score_dict[x]
            else:
                # should not happen in practice; fall back to a large penalty score
                score_i += 10000
        score_i /= len(sentence)
        lm_score.append(score_i)
    return lm_score
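
For reference, a rough usage sketch (the documents below are made up): get_lm_score expects a list of documents, each a list of sentence strings, as the flattening loop at the top of the function assumes, and returns one average per-sentence perplexity per document, with lower values generally indicating more fluent text.

# Made-up input: a list of documents, each a list of sentence strings.
docs = [
    ["This is a well-formed sentence.", "Here is another fluent one."],
    ["sentence grammatical not is this"],
    [],  # empty documents are scored 0.0
]
scores = get_lm_score(docs, batch_tokens=42000)
print(scores)  # one average perplexity per document; lower means more fluent
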
@WanzhengZhu (Owner)

Thank you, Jack! Can you open a pull request?

@Jack000 (Author) commented Jul 6, 2022

Ah, there are two things I'm not sure about.

Hugging Face seems to have changed the API for the model's forward call: the above code works with transformers 4.20 (the latest release) but not with the version pinned in this repo (3.3.1), so the code would have to be changed if you want to keep the current transformers version.
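
If I had to guess, the main incompatibility is the output format: 4.x returns a ModelOutput object by default, while 3.3.1 defaults to a plain tuple, so out['logits'] would fail there. A version-tolerant variant of the forward call in score_batch (untested on 3.3.1; return_dict was added in transformers 3.1 but only became the default in 4.x) could look like:

            out = model(input_ids=inputs["input_ids"], labels=labels,
                        attention_mask=inputs["attention_mask"],
                        token_type_ids=inputs["token_type_ids"],
                        return_dict=True)  # request the keyed output on both old and new versions
            logits = out.logits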

The batched code also requires a new parameter for either a batch size or a number of tokens per batch. This parameter would need to be set depending on how much VRAM you have, and I'm not sure how you'd like to expose the option in your code.
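
Purely as an illustration (the flag name and argparse wiring here are hypothetical, not how the repo is actually set up), the knob could be passed straight through like this:

import argparse

# Hypothetical wiring: expose the per-batch budget as a CLI flag
# so users can lower it when they run out of VRAM.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_tokens", type=int, default=42000,
                    help="approximate characters per batch; lower this on smaller GPUs")
args = parser.parse_args()

lm_scores = get_lm_score(sentences, batch_tokens=args.batch_tokens)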

@WanzhengZhu (Owner)

Ahhh, I see. Thanks for pointing that out. I will check it out.
