Skip to content

Commit

Permalink
Merge pull request #6 from atrifat/feat-support-custom-and-hybrid-model
Browse files Browse the repository at this point in the history
Support custom and hybrid model (Fast Inference for CPU-only device)
  • Loading branch information
atrifat authored Jun 25, 2024
2 parents 46de22e + 1930949 commit ce88861
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 10 deletions.
17 changes: 16 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,19 @@ ENABLE_CACHE=false
CACHE_DURATION_SECONDS=60

# (Optional. Default: auto. Options: auto,cpu,cuda) Set torch default device for detoxify library. Automatically detect if cuda/gpu device is available
TORCH_DEVICE=auto
TORCH_DEVICE=auto

# (Required. Default: detoxify. Options: detoxify, hybrid, custom). Set classification model to be used in prediction. "hybrid" and "custom" can be used as fast model if application run on machine without GPU.
HATE_SPEECH_MODEL=detoxify

# (Optional. Default: "". Options: "any potential words") Some potential toxic words can be included to assist hybrid model detection. Hybrid approach uses both custom and detoxify model based on probability thresold.
# POTENTIAL_TOXIC_WORDS="f**k,n***a,ni**er"

# (Optional. Default: 0.5 . Options: float value between 0.0 and 1.0) Probability thresold when using "hybrid" model. The thresold will determine whether to continue classify using detoxify model after using custom model
HYBRID_THRESOLD_CHECK=0.5

# (Optional. Default: "./experiments/model_voting_partial_best.pkl") Custom model path. Custom Pretrained model (pickle) of scikit-learn which implement predict_proba
# CUSTOM_MODEL_PATH="./experiments/model_voting_partial_best.pkl"

# (Optional. Default: "./experiments/vectorizer_count_no_stop_words.pkl") Custom vectorizer path. Custom Vectorizer model (pickle) of scikit-learn which implement vector transform for text
# CUSTOM_VECTORIZER_PATH="./experiments/vectorizer_count_no_stop_words.pkl"
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ ENV HOME=/home/user

COPY --from=builder --chown=user:user /builder/venv /app/venv

COPY --chown=user:user app.py app.py
COPY --chown=user:user app.py app.py

RUN mkdir -p /app/experiments
COPY --chown=user:user experiments/*.pkl experiments/

RUN chown -R user:user /app && chown -R user:user /home/user

Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# hate-speech-detector-api

A Simple PoC (Proof of Concept) of Hate-speech (Toxic content) Detector API Server using model from [detoxify](https://github.com/unitaryai/detoxify). Detoxify (unbiased model) achieves score of 93.74% compared to top leaderboard score with 94.73% in [Jigsaw Unintended Bias in Toxicity Classification](https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification).
A Simple PoC (Proof of Concept) of Hate-speech (Toxic content) Detector API Server using model from [detoxify](https://github.com/unitaryai/detoxify) and/or [custom traditional machine learning](experiments/model_voting_partial_best.pkl) model. Detoxify (unbiased model) achieves AUC score of 93.74% compared to top leaderboard score with AUC 94.73% in [Jigsaw Unintended Bias in Toxicity Classification](https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification). To those who are interested in training custom machine learning model based on [Jigsaw Unintended Bias in Toxicity Classification](https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification) can take a look at our [Jupyter Notebook](experiments/hate-speech-classification.ipynb).

hate-speech-detector-api is a core dependency of [nostr-filter-relay](https://github.com/atrifat/nostr-filter-relay).

## Demo

A demo instance is available on [HuggingFace Spaces - https://atrifat-hate-speech-detector-api-demo.hf.space](https://atrifat-hate-speech-detector-api-demo.hf.space). There is no guarantee for the uptime, but feel free to test.
A demo gradio showcase is available on [HuggingFace Spaces - https://huggingface.co/spaces/rifatramadhani/hate-speech-detector](https://huggingface.co/spaces/rifatramadhani/hate-speech-detector). There is no guarantee for the uptime, but feel free to test.

## Requirements

Expand Down
98 changes: 94 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from flask import Flask, request, jsonify
import functools
import datetime
import torch
from detoxify import Detoxify
import logging
import torch
from flask_caching import Cache
import pickle

load_dotenv()

Expand All @@ -16,11 +17,27 @@
APP_ENV = os.getenv("APP_ENV", "production")
LISTEN_HOST = os.getenv("LISTEN_HOST", "0.0.0.0")
LISTEN_PORT = os.getenv("LISTEN_PORT", "7860")

CUSTOM_MODEL_PATH = os.getenv(
"CUSTOM_MODEL_PATH",
os.path.dirname(os.path.abspath(__file__))
+ "/experiments/model_voting_partial_best.pkl",
)
CUSTOM_VECTORIZER_PATH = os.getenv(
"CUSTOM_VECTORIZER_PATH",
os.path.dirname(os.path.abspath(__file__))
+ "/experiments/vectorizer_count_no_stop_words.pkl",
)

DETOXIFY_MODEL = os.getenv("DETOXIFY_MODEL", "unbiased-small")
HATE_SPEECH_MODEL = os.getenv("HATE_SPEECH_MODEL", "detoxify")
CACHE_DURATION_SECONDS = int(os.getenv("CACHE_DURATION_SECONDS", 60))
ENABLE_CACHE = os.getenv("ENABLE_CACHE", "false") == "true"
POTENTIAL_TOXIC_WORDS = list(
filter(None, os.getenv("POTENTIAL_TOXIC_WORDS", "").split(","))
)
HYBRID_THRESOLD_CHECK = float(os.getenv("HYBRID_THRESOLD_CHECK", 0.5))
TORCH_DEVICE = os.getenv("TORCH_DEVICE", "auto")

APP_VERSION = "0.2.0"

# Setup logging configuration
Expand All @@ -47,7 +64,22 @@
else:
torch_device = TORCH_DEVICE

model = Detoxify(DETOXIFY_MODEL, device=torch_device)
if HATE_SPEECH_MODEL in ["hybrid", "detoxify"]:
model = Detoxify(DETOXIFY_MODEL, device=torch_device)

if HATE_SPEECH_MODEL in ["hybrid", "custom"]:
try:
with open(CUSTOM_VECTORIZER_PATH, "rb") as f:
vectorizer = pickle.load(f)
except Exception as e:
vectorizer = None

try:
with open(CUSTOM_MODEL_PATH, "rb") as f:
model_custom = pickle.load(f)
except Exception as e:
raise e


app = Flask(__name__)

Expand Down Expand Up @@ -107,6 +139,55 @@ def perform_hate_speech_analysis(query):
return result


def perform_hate_speech_analysis_custom(query):
query_vector = vectorizer.transform([query]) if vectorizer != None else ["query"]

result = {
"identity_attack": 0.0,
"insult": 0.0,
"obscene": 0.0,
"severe_toxicity": 0.0,
"sexual_explicit": 0.0,
"threat": 0.0,
"toxicity": 0.0,
}

temp_result = model_custom.predict_proba(query_vector)
result["toxicity"] = temp_result[0][1].round(3).astype("float")

return result


def perform_hate_speech_analysis_hybrid(
query, thresold_check=0.5, potential_toxic_words=[]
):
result = {
"identity_attack": 0.0,
"insult": 0.0,
"obscene": 0.0,
"severe_toxicity": 0.0,
"sexual_explicit": 0.0,
"threat": 0.0,
"toxicity": 0.0,
}
temp_result_custom = perform_hate_speech_analysis_custom(query)

has_potential_toxic_word = False
for word in potential_toxic_words:
if word in query:
has_potential_toxic_word = True
break

if temp_result_custom["toxicity"] > thresold_check or has_potential_toxic_word:
temp_result_detoxify = perform_hate_speech_analysis(query)
if temp_result_detoxify["toxicity"] > thresold_check:
result = temp_result_detoxify
else:
result = temp_result_custom

return result


@app.errorhandler(Exception)
def handle_exception(error):
res = {"error": str(error)}
Expand All @@ -120,7 +201,16 @@ def predict():
data = request.json
q = data["q"]
start_time = datetime.datetime.now()
result = perform_hate_speech_analysis(q)

if HATE_SPEECH_MODEL == "custom":
result = perform_hate_speech_analysis_custom(q)
elif HATE_SPEECH_MODEL == "hybrid":
result = perform_hate_speech_analysis_hybrid(
q, HYBRID_THRESOLD_CHECK, POTENTIAL_TOXIC_WORDS
)
else:
result = perform_hate_speech_analysis(q)

end_time = datetime.datetime.now()
elapsed_time = end_time - start_time
logging.debug("elapsed predict time: %s", str(elapsed_time))
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ detoxify==0.5.1
Flask==3.0.0
Flask-Caching==2.3.0
gunicorn==21.2.0
pandas==2.1.1
python-dotenv==1.0.0
pandas==2.2.2
python-dotenv==1.0.0
scikit-learn==1.5.0

0 comments on commit ce88861

Please sign in to comment.