any help with DeBERTa #406

Open
traderpedroso opened this issue Jul 20, 2023 · 1 comment

Comments


traderpedroso commented Jul 20, 2023

I have converted the model into both formats supported by rust-bert: the .ot (Torch) weights and the newly supported ONNX format. Despite that, I have not been able to get the implementation working. I tried to replicate a process that had previously worked for me with BART, but without success. Could anyone provide a preliminary guide or some help on running zero-shot classification with DeBERTa-v3? In Python it is simply:

from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli")
sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
candidate_labels = ["politics", "economy", "entertainment", "environment"]
output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
print(output)
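For reference, rust-bert does ship DeBERTa-v2 model code (DeBERTa-v3 uses the v2 architecture), so a config along the lines below might be a starting point. This is an untested sketch with several assumptions: that ModelType::DebertaV2 is accepted by the zero-shot pipeline in your rust-bert version, that the field types match your release (newer versions wrap model_resource differently to accommodate ONNX), and that the SentencePiece tokenizer file is named spm.model as in the Hugging Face DeBERTa-v3 repos.

```rust
// Hypothetical, untested sketch -- mirrors the BART config below but swaps
// in DeBERTa-v2 assumptions; field shapes vary across rust-bert versions.
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationConfig;
use rust_bert::resources::LocalResource;
use std::path::PathBuf;
use tch::Device;

fn deberta_config(base_path: &str) -> ZeroShotClassificationConfig {
    let base = PathBuf::from(base_path);
    ZeroShotClassificationConfig {
        // Assumption: the zero-shot pipeline accepts this model type.
        model_type: ModelType::DebertaV2,
        model_resource: Box::new(LocalResource::from(base.join("rust_model.ot"))),
        config_resource: Box::new(LocalResource::from(base.join("config.json"))),
        // DeBERTa-v3 uses a SentencePiece tokenizer, not vocab.json/merges.txt;
        // "spm.model" is the file name in the upstream Hugging Face repos.
        vocab_resource: Box::new(LocalResource::from(base.join("spm.model"))),
        merges_resource: None,
        lower_case: false,
        strip_accents: None,
        add_prefix_space: None,
        device: Device::cuda_if_available(),
    }
}
```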

Below is my working implementation with Facebook's BART-large. It exposes a REST endpoint and a ZeroMQ endpoint for streaming, using batched inference.

// extern crate declarations are unnecessary on the 2018+ editions

use actix_web::{web, App, HttpResponse, HttpServer, Responder};
use anyhow::Result;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::zero_shot_classification::{
    ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use rust_bert::resources::LocalResource;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::path::PathBuf;
use std::thread;
use tch::Device;
use zmq::{Context, SocketType};

#[derive(Deserialize, Serialize)]
struct MyData {
    sentence: Vec<String>,
    categories: Vec<String>,
}

fn generation_config(base_path: &str) -> ZeroShotClassificationConfig {
    // Path::join handles the separator whether or not base_path ends in '/'
    let base = PathBuf::from(base_path);
    let model_path = base.join("rust_model.ot");
    let config_path = base.join("config.json");
    let vocab_path = base.join("vocab.json");
    let merges_path = base.join("merges.txt");
    ZeroShotClassificationConfig {
        model_type: ModelType::Bart,
        model_resource: Box::new(LocalResource::from(model_path)),
        config_resource: Box::new(LocalResource::from(config_path)),
        vocab_resource: Box::new(LocalResource::from(vocab_path)),
        merges_resource: Some(Box::new(LocalResource::from(merges_path))),
        lower_case: false,
        strip_accents: None,
        add_prefix_space: None,
        device: Device::cuda_if_available(),
    }
}
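Building the resource paths with string concatenation only works if the base path ends in a slash; std's Path::join inserts the separator when needed. A minimal std-only illustration:

```rust
use std::path::{Path, PathBuf};

fn main() {
    // join adds a separator only when one is missing
    let with_slash = Path::new("/root/rustmodels/zeroshot/").join("config.json");
    let without_slash = Path::new("/root/rustmodels/zeroshot").join("config.json");
    assert_eq!(with_slash, PathBuf::from("/root/rustmodels/zeroshot/config.json"));
    assert_eq!(without_slash, PathBuf::from("/root/rustmodels/zeroshot/config.json"));
    println!("both join to {:?}", with_slash);
}
```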

fn error_response() -> Result<String, serde_json::Error> {
    let error_message = "Malformed JSON.";
    let example_json = r#"{
        "sentence": ["Who are you going to vote for in 2020?"],
        "categories": ["politics", "economy", "sports"]
    }"#;

    let error_response = json!({
        "error": error_message,
        "example": serde_json::from_str::<Value>(example_json)?
    });

    serde_json::to_string(&error_response)
}


fn predict(
    model: &ZeroShotClassificationModel,
    data: &[String],
    candidate_labels: &[&str],
    batch_size: usize,
) -> Vec<String> {
    let mut predictions = Vec::new();

    for batch_start in (0..data.len()).step_by(batch_size) {
        let batch_end = std::cmp::min(batch_start + batch_size, data.len());
        let batch = &data[batch_start..batch_end];

        let output = model
            .predict_multilabel(
                &batch.iter().map(AsRef::as_ref).collect::<Vec<&str>>(),
                candidate_labels,
                Some(Box::new(|label: &str| {
                    format!("This example is about {label}.")
                })),
                // NOTE: per the pipeline signature, this last argument is
                // max_length (token truncation), not the batch size.
                batch_size,
            )
            .unwrap();

        for item in output {
            let prediction = item
                .iter()
                .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
                .unwrap()
                .text
                .clone();
            predictions.push(prediction);
        }
    }

    predictions
}
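The manual step_by/min batching and the float argmax above can lean on std helpers: slice::chunks yields the same batches, and f64::total_cmp (Rust 1.62+) avoids the partial_cmp(...).unwrap() on scores. A self-contained sketch, with plain (label, score) tuples standing in for the pipeline's Label values:

```rust
fn main() {
    let data: Vec<String> = (1..=5).map(|i| format!("sentence {i}")).collect();

    // chunks(2) produces the same slices as the step_by/min loop: 2, 2, 1 items
    let batch_lens: Vec<usize> = data.chunks(2).map(|b| b.len()).collect();
    assert_eq!(batch_lens, vec![2, 2, 1]);

    // argmax over f64 scores without partial_cmp().unwrap()
    let scores = [("politics", 0.91_f64), ("economy", 0.05), ("sports", 0.04)];
    let best = scores
        .iter()
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(label, _)| *label)
        .unwrap();
    assert_eq!(best, "politics");
    println!("batches: {batch_lens:?}, best label: {best}");
}
```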
 


fn run_zeromq_server() -> Result<()> {
    let ctx = Context::new();
    let socket = ctx.socket(SocketType::REP)?;
    socket.bind("tcp://*:6044")?;

    let base_path = "/root/rustmodels/zeroshot/";

    let config = generation_config(base_path);

    let sequence_classification_model = ZeroShotClassificationModel::new(config)?;

    loop {
        let message = socket.recv_string(0)?.unwrap();
        let data_result: Result<HashMap<String, Vec<String>>, serde_json::Error> =
            serde_json::from_str(&message);

        match data_result {
            Ok(data) => {
                if let (Some(sentence), Some(candidate_labels)) = (data.get("sentence"), data.get("categories")) {
                    let candidate_labels: Vec<&str> = candidate_labels.iter().map(AsRef::as_ref).collect();

                    let predictions = predict(
                        &sequence_classification_model,
                        sentence,
                        &candidate_labels,
                        24,
                    );

                    let result = serde_json::to_string(&predictions)?;
                    socket.send(&result, 0)?;
                } else {
                    let error_json = error_response()?;
                    socket.send(&error_json, 0)?;
                }
            }
            Err(_) => {
                let error_json = error_response()?;
                socket.send(&error_json, 0)?;
            }
        }
    }
}


async fn handle_predict(data: web::Json<MyData>) -> impl Responder {
    let context = zmq::Context::new();
    let requester = context.socket(zmq::REQ).unwrap();
    requester.connect("tcp://127.0.0.1:6044").unwrap();

    let json_data = serde_json::to_string(&*data).unwrap();
    if requester.send(json_data.as_bytes(), 0).is_err() {
        return HttpResponse::InternalServerError()
            .body(error_response().unwrap_or_else(|_| String::from("Internal server error.")));
    }

    let reply = match requester.recv_msg(0) {
        Ok(reply) => reply,
        Err(_) => {
            return HttpResponse::InternalServerError()
                .body(error_response().unwrap_or_else(|_| String::from("Internal server error.")))
        }
    };

    let reply_str = std::str::from_utf8(&reply).unwrap().to_owned();

    HttpResponse::Ok().body(reply_str)
}

async fn run_http_server() -> std::io::Result<()> {
    HttpServer::new(|| App::new().route("/predict", web::post().to(handle_predict)))
        .bind("0.0.0.0:8081")?
        .run()
        .await
}

fn main() {
    let handle_zeromq = thread::spawn(|| {
        run_zeromq_server().expect("ZeroMQ server failed.");
    });

    let handle_http = thread::spawn(|| {
        let sys = actix_web::rt::System::new();
        sys.block_on(run_http_server()).expect("HTTP server failed.");
    });
    

    handle_zeromq.join().unwrap();
    handle_http.join().unwrap();
}
@Philipp-Sc

@traderpedroso Have you made any progress?
