Skip to content

Commit

Permalink
ensure nltk tokenizer works
Browse files Browse the repository at this point in the history
  • Loading branch information
odeda1 committed Sep 23, 2024
1 parent 14cdcab commit 0d7cf79
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions salt/logic/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import nltk
import pickle
import numpy as np
import pandas as pd
Expand All @@ -7,9 +8,9 @@
from itertools import chain
from typing import Dict, List
from salt.constants import NA
from nltk import sent_tokenize
from sentence_transformers import SentenceTransformer

nltk.download('punkt')

TEXTS_KEY = "texts"
VECTORS_KEY = "vectors"
Expand Down Expand Up @@ -42,7 +43,7 @@ def get_relevant_texts(texts: list[str]) -> list[str]:

def embed_texts(texts: List[str]) -> List[List[float]]:
model = get_model(MODEL_NAME)
texts_sentences = [sent_tokenize(text) for text in texts]
texts_sentences = [nltk.sent_tokenize(text) for text in texts]
texts_lengths = [len(sentences) for sentences in texts_sentences]
all_sentences = list(chain.from_iterable(texts_sentences))
if len(all_sentences) <= BATCH_SIZE:
Expand Down
5 changes: 3 additions & 2 deletions salt/resources/thin_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@
# scikit-learn==1.3.2
# sentence-transformers==2.2.2

import nltk
import pickle
import numpy as np
from tqdm.auto import tqdm
from nltk import sent_tokenize
from sklearn.linear_model import LogisticRegression
from sentence_transformers import SentenceTransformer

nltk.download('punkt')

def load_model(path: str) -> LogisticRegression:
with open(path, "rb") as f:
return pickle.load(f)


def vectorize(text: str, model: SentenceTransformer) -> np.array:
sentences = sent_tokenize(text)
sentences = nltk.sent_tokenize(text)
sentences_embeddings = model.encode(sentences)
return sentences_embeddings.mean(axis=0)

Expand Down

0 comments on commit 0d7cf79

Please sign in to comment.