From ad25362ba4e6570e70723c4268262a6cb7f13cb8 Mon Sep 17 00:00:00 2001 From: Pablo Gonzalez Date: Sun, 16 Jun 2024 22:46:02 -0500 Subject: [PATCH] MoE reference implementation (#1714) * MoE reference implementation * Fix evaluation script targets * Add evaluation script docker * Rename model to mixtral-8x7b * Format python files * Add skip tokens flag * Add MBXP stop sequence * Add download link for preprocessed dataset --- language/mixtral-8x7b/Dockerfile | 48 +++ language/mixtral-8x7b/Dockerfile.eval | 63 +++ language/mixtral-8x7b/README.md | 249 ++++++++++++ language/mixtral-8x7b/SUT.py | 449 +++++++++++++++++++++ language/mixtral-8x7b/build.sh | 8 + language/mixtral-8x7b/dataset.py | 115 ++++++ language/mixtral-8x7b/evaluate-accuracy.py | 217 ++++++++++ language/mixtral-8x7b/evaluate_mbxp.py | 154 +++++++ language/mixtral-8x7b/launch.sh | 37 ++ language/mixtral-8x7b/main.py | 168 ++++++++ language/mixtral-8x7b/run_accuracy.sh | 22 + language/mixtral-8x7b/run_offline.sh | 10 + language/mixtral-8x7b/run_server.sh | 12 + language/mixtral-8x7b/user.conf | 5 + mlperf.conf | 9 +- 15 files changed, 1565 insertions(+), 1 deletion(-) create mode 100644 language/mixtral-8x7b/Dockerfile create mode 100644 language/mixtral-8x7b/Dockerfile.eval create mode 100644 language/mixtral-8x7b/README.md create mode 100644 language/mixtral-8x7b/SUT.py create mode 100644 language/mixtral-8x7b/build.sh create mode 100644 language/mixtral-8x7b/dataset.py create mode 100644 language/mixtral-8x7b/evaluate-accuracy.py create mode 100644 language/mixtral-8x7b/evaluate_mbxp.py create mode 100644 language/mixtral-8x7b/launch.sh create mode 100644 language/mixtral-8x7b/main.py create mode 100644 language/mixtral-8x7b/run_accuracy.sh create mode 100644 language/mixtral-8x7b/run_offline.sh create mode 100644 language/mixtral-8x7b/run_server.sh create mode 100644 language/mixtral-8x7b/user.conf diff --git a/language/mixtral-8x7b/Dockerfile b/language/mixtral-8x7b/Dockerfile new file mode 100644 index 000000000..b04910d73 --- /dev/null +++ b/language/mixtral-8x7b/Dockerfile @@ -0,0 +1,48 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 +SHELL ["/bin/bash", "-c"] + +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 + +ENV TZ=US/Pacific +ENV DEBIAN_FRONTEND=noninteractive + +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone +RUN rm -rf /var/lib/apt/lists/* && rm /etc/apt/sources.list.d/* \ + && apt update \ + && apt install -y --no-install-recommends build-essential autoconf \ + libtool git ccache curl wget pkg-config sudo ca-certificates \ + automake libssl-dev bc python3-dev python3-pip google-perftools \ + gdb libglib2.0-dev clang sshfs libre2-dev libboost-dev \ + libnuma-dev numactl sysstat sshpass ntpdate less iputils-ping \ + && apt -y autoremove \ + && apt remove -y cmake \ + && apt install -y --no-install-recommends pkg-config zip g++ zlib1g-dev \ + unzip libarchive-dev +RUN apt install -y --no-install-recommends rsync + +# Install setuptools +RUN python3 -m pip install --upgrade pip \ + && python3 -m pip install --upgrade setuptools wheel virtualenv + +# Install conda +WORKDIR /tmp +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh \ + && bash Miniconda3-* -b -p /opt/miniconda3 +ENV PATH="$PATH:/opt/miniconda3/bin" +RUN conda create -n llama2-70b python=3.10 +RUN chmod -R 777 /opt/miniconda3 diff --git a/language/mixtral-8x7b/Dockerfile.eval b/language/mixtral-8x7b/Dockerfile.eval new file mode 100644 index 000000000..9fc13772b --- /dev/null +++ b/language/mixtral-8x7b/Dockerfile.eval @@ -0,0 +1,63 @@ +# Use Ubuntu 22.04 as the base image +FROM ubuntu:22.04 +ARG DEBIAN_FRONTEND=noninteractive + +# Update package lists +RUN apt-get update + +# Install Python 3 and pip +RUN apt-get install -y python3 python3-pip git + +# Set Python 3 as the default python interpreter +RUN ln -s /usr/bin/python3 /usr/bin/python + +# Verify installation +RUN python --version +RUN pip --version +RUN git --version + +# Install requirements +RUN pip install transformers nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0 huggingface-cli + +# Clone and install mxeval +RUN git clone https://github.com/amazon-science/mxeval.git +RUN pip install -e mxeval + +# Get language dependencies +RUN apt install -y wget + +# Ruby +RUN apt install -y curl libssl-dev libreadline-dev zlib1g-dev autoconf bison build-essential libyaml-dev libreadline-dev libncurses5-dev libffi-dev libgdbm-dev + +# PHP +RUN apt install -y software-properties-common ca-certificates lsb-release apt-transport-https +RUN add-apt-repository ppa:ondrej/php +RUN apt-get update +RUN apt install -y php8.0 +# RUN apt install -y php-{pear,cgi,pdo,common,curl,mbstring,gd,mysqlnd,gettext,bcmath,json,xml,fpm,intl,zip} + +# JAVA +RUN apt-get install -y openjdk-8-jdk + +# JAVASCRIPT +RUN apt install -y npm + +# SCALA +RUN apt-get install -y scala + +# C# +RUN apt-get install -y dotnet6 + +# Kotlin +RUN apt install -y zip unzip + +SHELL ["/bin/bash", "-c"] +WORKDIR "/mxeval" +RUN sed -i 's/sudo//g' /mxeval/language_setup/ubuntu.sh +RUN sed -i 's/source/PS1=1 source/g' /mxeval/language_setup/ubuntu.sh # Need this to make sure that the "source ~/.bashrc" lines work correctly +RUN sed -i 's/npx tsc/tsc/g' /mxeval/mxeval/execution.py # npx tsc runs into permission issues + +RUN PATH="$HOME/.rbenv/bin:$PATH" bash /mxeval/language_setup/ubuntu.sh + +WORKDIR "/" +CMD bash diff --git a/language/mixtral-8x7b/README.md b/language/mixtral-8x7b/README.md new file mode 100644 index 000000000..e301891d9 --- /dev/null +++ 
b/language/mixtral-8x7b/README.md
@@ -0,0 +1,249 @@
+# Reference Implementation for Mixtral-8x7B-instruct-v0.1
+
+**Basic implementation for Mixtral-8x7B-instruct-v0.1. A few noteworthy items:**
+
++ The dataset was constructed by randomly sampling 5K examples from the validation split of each of three datasets: open_orca_gpt4, GSM8K and MBXP.
++ The streamer used to communicate with loadgen adds considerable overhead; this implementation is only meant to be functional, not performant.
++ For custom/optimized implementations of this benchmark it is important to include the following:
+  - For the Server scenario, it is necessary to call `lg.FirstTokenComplete(response)` for each query. This way the first token will be reported and its latency will be measured.
+  - For all scenarios, when calling `lg.QuerySamplesComplete(response)`, each element in the response must be a `lg.QuerySampleResponse` that contains the number of tokens (it can be created this way: `lg.QuerySampleResponse(qitem.id, bi[0], bi[1], n_tokens)`). The number of tokens reported should match the number of tokens in your answer; this is checked in [TEST06](../../compliance/nvidia/TEST06/).
+
+## Automated command to run the benchmark via MLCommons CM
+
+TODO
+
+## Prepare environment
+
+Copy the mlperf.conf file to this folder.
+```
+cp ../../mlperf.conf .
+```
+
+For a CPU-only run:
+
+```
+conda create -n Mixtral-8x7B python=3.9
+conda activate Mixtral-8x7B
+
+# Install packages
+conda install pybind11==2.10.4 -c conda-forge -y
+python -m pip install torch==2.2.0.dev20231006+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
+pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0
+pip install git+https://github.com/amazon-science/mxeval.git@e09974f990eeaf0c0e8f2b5eaff4be66effb2c86
+
+export CUR_DIR=${PWD}
+cd /loadgen
+
+python -m pip install .
+```
+
+For a GPU-based run:
+
+A Dockerfile is provided, along with scripts to help launch it. First, add any docker volume mounts you want in
+`launch.sh`. There is a section at the top of the file that looks like:
+```
+# Add any volume mounts here with the following syntax
+# /path/to/src:/path/to/dir/in/container
+MOUNTS=(
+    $MLCOMMONS_REPO_PATH:$MLCOMMONS_REPO_PATH
+)
+```
+
+For example, if you have a raid space located at `/raid/data` on your local machine, you can add it to the same path in the container like so:
+```
+# Add any volume mounts here with the following syntax
+# /path/to/src:/path/to/dir/in/container
+MOUNTS=(
+    $MLCOMMONS_REPO_PATH:$MLCOMMONS_REPO_PATH
+    /raid/data:/raid/data
+)
+```
+Once you have added all your mounts, launch the container with `bash launch.sh`.
+
+Inside the container, set up the environment with `bash build.sh`. This will install all the dependencies from the
+CPU-only setup, as well as the GPU versions of applicable libraries such as PyTorch.
+
+
+## Get Model
+### MLCommons Members Download
+
+TODO: Create fixed MLCommons download link.
+For now it can be downloaded from [Hugging Face](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/tree/main)
+
+## Get Dataset
+
+TODO: Create scripts and procedure to download all of the parts of the dataset
+
+### Preprocessed
+
+#### Using Rclone
+We make many of the MLPerf inference models and datasets available using Rclone. For compatibility, you can use Rclone to get the preprocessed dataset:
+
+To run Rclone on Windows, you can download the executable [here](https://rclone.org/install/#windows).
+To install Rclone on Linux/macOS/BSD systems, run: +```bash +sudo -v ; curl https://rclone.org/install.sh | sudo bash +``` +Once Rclone is installed, cd into the folder where you want to place the dataset and run: +```bash +rclone copyurl https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl ./ -a -P +``` +#### Using wget + +Alternatively, you can simply cd into the folder where you want to place the dataset and run +```bash +wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl +``` + +### Unprocessed + +TODO: Share instructions and scripts + +## Run Performance Benchmarks + +### Offline +``` +python -u main.py --scenario Offline \ + --model-path ${CHECKPOINT_PATH} \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 15000 \ + --device cpu \ + --dataset-path ${DATASET_PATH} \ + --output-log-dir offline-logs + +``` + +For a GPU-based run: +``` +python3 -u main.py --scenario Offline \ + --model-path ${CHECKPOINT_PATH} \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 15000 \ + --dataset-path ${DATASET_PATH} \ + --output-log-dir offline-logs \ + --dtype float32 \ + --device cuda:0 2>&1 | tee offline_performance_log.log +``` + +### Server +``` +python -u main.py --scenario Server \ + --model-path ${CHECKPOINT_PATH} \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 15000 \ + --device cpu \ + --dataset-path ${DATASET_PATH} \ + --output-log-dir server-logs +``` + +The ServerSUT was not tested for GPU runs. + + +## Run Accuracy Benchmarks + +### Offline +``` +OUTPUT_LOG_DIR=offline-accuracy-logs + +mkdir -p "run_outputs" # The script will dump all the outputs to 'run_outputs'. + +python -u main.py --scenario Offline \ + --model-path ${CHECKPOINT_PATH} \ + --accuracy \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 15000 \ + --dataset-path ${DATASET_PATH} \ + --output-log-dir ${OUTPUT_LOG_DIR} \ + --device cpu + + +ACCURACY_LOG_FILE=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json +if [ -e ${ACCURACY_LOG_FILE} ]; then + python evaluate-accuracy.py --checkpoint-path ${CHECKPOINT_PATH} \ + --mlperf-accuracy-file ${ACCURACY_LOG_FILE} --dataset-file ${DATASET_PATH} --dtype int32 +fi + +# Optional: Create a pickled pandas DataFrame that is the original dataset with extra columns with output data from the +# accuracy run. The following columns will be added: +# - "gen_output_tok_id": A list of ints representing the tokenized output sequence. +# - "gen_output_text": A str representing the untokenized output sequence. +# - "gen_output_tok_len": An int representing the number of output tokens. +# - "rouge1": The rouge1 score for this sample +# - "rouge2": The rouge2 score for this sample +# - "rougeL": The rougeL score for this sample +# This file will by default be saved to 'full_output.pkl'. You can modify this with --output-pkl-path. +python consolidate_results.py --dataset-path ${DATASET_PATH} --model-dir ${CHECKPOINT_PATH} +``` + +For the GPU run - The above steps have been automated in `run_accuracy.sh`. You can also modify this script to use +`--device cpu` to adapt it to a CPU-only run. 
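+
+If you want to sanity-check the accuracy log before running the full evaluation, the short sketch below decodes the
+first few entries of `mlperf_log_accuracy.json`. It is not part of the benchmark; it assumes the log stores int32
+token IDs (as the evaluation commands above assume) and that the tokenizer can be loaded from your checkpoint path.
+```python
+# Illustrative sketch only: inspect a few entries of the accuracy log.
+import json
+
+import numpy as np
+from transformers import AutoTokenizer
+
+checkpoint_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # or your local ${CHECKPOINT_PATH}
+accuracy_log = "offline-accuracy-logs/mlperf_log_accuracy.json"
+
+tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, use_fast=False)
+with open(accuracy_log) as f:
+    entries = json.load(f)
+
+for entry in entries[:3]:
+    # Each entry stores the generated token IDs for one sample as a hex string
+    tokens = np.frombuffer(bytes.fromhex(entry["data"]), dtype=np.int32)
+    print(f"qsl_idx={entry['qsl_idx']}  n_tokens={len(tokens)}")
+    print(tokenizer.decode(tokens.tolist(), skip_special_tokens=True)[:200])
+```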
+
+
+### Server
+```
+OUTPUT_LOG_DIR=server-accuracy-logs
+
+python -u main.py --scenario Server \
+        --model-path ${CHECKPOINT_PATH} \
+        --accuracy \
+        --mlperf-conf mlperf.conf \
+        --user-conf user.conf \
+        --total-sample-count 15000 \
+        --dataset-path ${DATASET_PATH} \
+        --output-log-dir ${OUTPUT_LOG_DIR} \
+        --device cpu
+
+
+ACCURACY_LOG_FILE=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json
+if [ -e ${ACCURACY_LOG_FILE} ]; then
+        python evaluate-accuracy.py --checkpoint-path ${CHECKPOINT_PATH} \
+                --mlperf-accuracy-file ${ACCURACY_LOG_FILE} --dataset-file ${DATASET_PATH} --dtype int32
+fi
+```
+
+The ServerSUT was not tested for GPU runs.
+
+### Evaluation
+Recreating the environment for evaluating the quality metrics can be quite tedious, so we provide a Dockerfile and recommend using docker for this task.
+1. Build the evaluation container:
+```bash
+docker build . -f Dockerfile.eval -t evaluation
+```
+2. Run the container in interactive mode, mounting the current directory:
+```bash
+sudo docker run -it -v $(pwd):/eval -t evaluation
+```
+3. Inside the container, log in to Hugging Face and run the evaluation:
+```bash
+cd eval
+huggingface-cli login --token [huggingface_token]
+python -u evaluate-accuracy.py --checkpoint-path mistralai/Mixtral-8x7B-instruct-v0.1 \
+        --mlperf-accuracy-file [path_to_mlperf_accuracy_file] \
+        --dataset-file [path_to_dataset] \
+        --n_workers 8
+```
+
+
+## Accuracy Target
+
+Reference scores:
+Open Orca:
+```json
+{'rouge1': 45.4911, 'rouge2': 23.2829, 'rougeL': 30.3615, 'rougeLsum': 42.4333}
+```
+GSM8K:
+```json
+{'gsm8k_accuracy': 73.78}
+```
+MBXP:
+```json
+{'mbxp_accuracy': 60.16}
+```
+For official submissions, 99% of each reference score is enforced. Additionally, the generated tokens_per_sample must be within 90%-110% of the reference:
+```json
+{'tokens_per_sample': 145.9}
+```
\ No newline at end of file
diff --git a/language/mixtral-8x7b/SUT.py b/language/mixtral-8x7b/SUT.py
new file mode 100644
index 000000000..9ed44dbf5
--- /dev/null
+++ b/language/mixtral-8x7b/SUT.py
@@ -0,0 +1,449 @@
+import os
+import time
+import numpy as np
+import array
+import torch
+from torch.nn.functional import pad
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList
+from transformers.generation.streamers import BaseStreamer
+
+import pickle
+import time
+import threading
+import tqdm
+import queue
+
+import logging
+from typing import TYPE_CHECKING, Optional, List
+from pathlib import Path
+
+import mlperf_loadgen as lg
+from dataset import Dataset
+
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger("Mixtral-8x7B-Instruct-v0.1")
+
+gen_kwargs = {
+    "early_stopping": True,
+    "max_new_tokens": 1024,
+    "min_new_tokens": 1,
+    "num_beams": 1,
+    "do_sample": False
+}
+
+class StopAfterSequence(LogitsProcessor):
+    """Logits processor (to use with HuggingFace `generate()` method:
+    https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/
+    text_generation#transformers.generation_utils.GenerationMixin).
+
+    This logits processor forces the model to stop generating new tokens
+    once it has produced the specified stopping sequence.
+
+    Args:
+        stop_seq (List[int]): Token IDs of the stopping sequence.
+        eos_token_id (int): ID of the EOS token.
+ device (str): Device that the model is running + """ + def __init__(self, eos_token_id: int, stop_seq: List[int] = [13, 13940, 28832, 13], device="cpu"): + super().__init__() + assert(len(stop_seq) >= 1) + self.device = device + self.stop_seq = torch.tensor(stop_seq, dtype=torch.long).to(device) + self.stop_seq_length = len(stop_seq) + self.eos_token_id = eos_token_id + + def check_stop_condition(self, input_ids: torch.LongTensor): + stop_condition_met = (input_ids[:, -self.stop_seq_length:] == self.stop_seq).all(dim=1) + return stop_condition_met + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if input_ids.size(1) > self.stop_seq_length: + forced_eos = torch.full((scores.size(1),), -float("inf")).to(self.device) + forced_eos[self.eos_token_id] = 0 + scores[self.check_stop_condition(input_ids)] = forced_eos + return scores + + +class FirstTokenStreamer(BaseStreamer): + """ Streams first tokens to a 'holder' """ + + def __init__(self, first_token, tokens_cache=[], + is_first_token=True, response_ids=[]): + """ Response ids added to 'sign' the first token""" + + self.first_token = first_token # Queue for first token + self.is_first_token = is_first_token + + # Cache for subsequent generated tokens + self.tokens_cache = tokens_cache + + self.response_ids = response_ids + + # The first tokens sent to the streamer are actually the input prompts + self.is_prompt = True + + def put(self, value): + """ Caches the tokens as they're generated. Assumes bs=1 """ + + # Prompts are streamed first so we need to skip the first time value + # that arrives + if self.is_prompt: + self.is_prompt = False + return + + value = value.item() + if self.is_first_token: + + # Add generated first token together with its query response_id to + # first tokens queue + self.first_token.put((value, self.response_ids[0])) + + self.is_first_token = False + return + + self.tokens_cache.append(value) + + def end(self): + pass + + def get_out_tokens(self): + return self.tokens_cache + + +class SUT(): + def __init__(self, + model_path=None, + dtype="bfloat16", + device="cpu", + batch_size=None, + total_sample_count=24576, + dataset_path=None, + use_cached_outputs=False, + # Set this to True *only for test accuracy runs* in case your + # prior session was killed partway through + workers=1): + + self.model_path = model_path or "mistralai/Mixtral-8x7B-Instruct-v0.1" + self.device = device + + if not batch_size: + if device == "cpu": + batch_size = 1 + else: + batch_size = 32 # Reduce to 8 if using 4 GPUs, 16 for 8. + self.batch_size = batch_size + + # dtype + if dtype == 'bfloat16': + self.amp_enabled = True + self.amp_dtype = torch.bfloat16 + elif dtype == 'float16': + self.amp_enabled = True + self.amp_dtype = torch.float16 + else: + self.amp_enabled = False + self.amp_dtype = torch.float32 + + if 'cuda' in self.device: + assert torch.cuda.is_available(), "torch gpu is not available, exiting..." 
+ + self.dataset_path = dataset_path + self.data_object = Dataset(self.model_path, + dataset_path=self.dataset_path, + total_sample_count=total_sample_count, + device=self.device) + self.qsl = lg.ConstructQSL(self.data_object.total_sample_count, self.data_object.perf_count, + self.data_object.LoadSamplesToRam, self.data_object.UnloadSamplesFromRam) + + self.load_model() + + self.num_workers = workers + self.worker_threads = [None] * self.num_workers + self.query_queue = queue.Queue() + + self.use_cached_outputs = use_cached_outputs + self.sample_counter = 0 + self.sample_counter_lock = threading.Lock() + + def start(self): + # Create worker threads + for j in range(self.num_workers): + worker = threading.Thread(target=self.process_queries) + worker.start() + self.worker_threads[j] = worker + + def stop(self): + for _ in range(self.num_workers): + self.query_queue.put(None) + + for worker in self.worker_threads: + worker.join() + + def process_queries(self): + """Processor of the queued queries. User may choose to add batching logic """ + + while True: + qitem = self.query_queue.get() + if qitem is None: + break + + query_ids = [q.index for q in qitem] + + fname = "q" + "_".join([str(i) for i in query_ids]) + fname = f"run_outputs/{fname}.pkl" + _p = Path(fname) + if self.use_cached_outputs and _p.exists(): + # Read cache + with _p.open(mode="rb") as f: + d = pickle.load(f) + processed_output = d["outputs"] + tik1 = None + tik2 = None + tik3 = None + tok = None + else: + # Construct / collate batch + max_seq_len = 1024 + + tik1 = time.time() + + input_ids_tensor = [] + input_masks_tensor = [] + input_len = [] + input_dataset = [] + for q in qitem: + input_ids_tensor.append(pad(self.data_object.input_ids[q.index], + (max_seq_len - + self.data_object.input_lens[q.index], 0, 0, 0), + value=self.tokenizer.pad_token_id)) + input_masks_tensor.append(pad(self.data_object.attention_masks[q.index], + (max_seq_len - + self.data_object.input_lens[q.index], 0, 0, 0), + value=0)) + input_len.append(self.data_object.input_lens[q.index]) + + # In case we predict code generation, we can specify an additional stop sequence + input_dataset.append(self.data_object.dataset_names[q.index]) + input_ids_tensor = torch.cat(input_ids_tensor) + input_masks_tensor = torch.cat(input_masks_tensor) + + assert input_ids_tensor.shape == input_masks_tensor.shape + assert input_ids_tensor.shape[0] <= self.batch_size + + tik2 = time.time() + logits_processor = LogitsProcessorList([StopAfterSequence(self.tokenizer.eos_token_id, device=self.device)]) + for i in range(len(input_ids_tensor)): + ids, masks, dataset = input_ids_tensor[i:i+1], input_masks_tensor[i:i+1], input_dataset[i] + pred_output_tokens = [] + if dataset == "MBXP": + out = self.model.generate( + input_ids=ids, + attention_mask=masks, + pad_token_id=self.tokenizer.pad_token_id, + logits_processor=logits_processor, + **gen_kwargs + ) + else: + out = self.model.generate( + input_ids=ids, + attention_mask=masks, + pad_token_id=self.tokenizer.pad_token_id, + **gen_kwargs + ) + pred_output_tokens.append(out) + pred_output_tokens = torch.cat(pred_output_tokens) + tik3 = time.time() + + processed_output = self.data_object.postProcess(pred_output_tokens, + input_seq_lens=input_len, + query_id_list=query_ids) + + for i in range(len(qitem)): + n_tokens = processed_output[i].shape[0] + response_array = array.array( + "B", processed_output[i].tobytes()) + bi = response_array.buffer_info() + response = [ + lg.QuerySampleResponse( + qitem[i].id, + bi[0], + bi[1], + n_tokens)] 
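+                # The reported n_tokens must match the number of tokens in the generated answer (verified by TEST06)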
+ lg.QuerySamplesComplete(response) + + tok = time.time() + + with self.sample_counter_lock: + self.sample_counter += len(qitem) + print(f"Samples run: {self.sample_counter}") + if tik1: + print(f"\tBatchMaker time: {tik2 - tik1}") + print(f"\tInference time: {tik3 - tik2}") + print(f"\tPostprocess time: {tok - tik3}") + print(f"\t==== Total time: {tok - tik1}") + else: + print(f"\tLoaded from cache: {_p}") + + def load_model(self): + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, + device_map="auto", + low_cpu_mem_usage=True, + torch_dtype=self.amp_dtype + ) + print("Loaded model") + + self.device = torch.device(self.device) + if self.device == "cpu": + # Force CPU if your system has GPU and you specifically want + # CPU-only run + self.model = self.model.to(self.device) + + self.model.eval() + self.model = self.model.to(memory_format=torch.channels_last) + + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + model_max_length=1024, + padding_side="left", + use_fast=False,) + + self.tokenizer.pad_token = self.tokenizer.eos_token + print("Loaded tokenizer") + + def get_sut(self): + self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries) + return self.sut + + def get_qsl(self): + return self.qsl + + def predict(self, **kwargs): + raise NotImplementedError + + def issue_queries(self, query_samples): + """ Receives samples from loadgen and adds them to queue. Users may choose to batch here""" + + list_prompts_tokens = [] + list_prompts_attn_masks = [] + + print(f"IssueQuery started with {len(query_samples)} samples") + while len(query_samples) > 0: + self.query_queue.put(query_samples[:self.batch_size]) + query_samples = query_samples[self.batch_size:] + print(f"IssueQuery done") + + def flush_queries(self): + pass + + def __del__(self): + pass + + +class SUTServer(SUT): + def __init__(self, model_path=None, dtype="bfloat16", device="cpu", + total_sample_count=24576, dataset_path=None, workers=1): + + super().__init__( + model_path=model_path, + dtype=dtype, + device=device, + total_sample_count=total_sample_count, + dataset_path=dataset_path, + workers=workers) + + self.first_token_queue = queue.Queue() + + def start(self): + + # Create worker threads + for j in range(self.num_workers): + worker = threading.Thread(target=self.process_queries) + worker.start() + self.worker_threads[j] = worker + + # Create first token response thread + self.ft_response_thread = threading.Thread( + target=self.process_first_tokens) + self.ft_response_thread.start() + + def process_first_tokens(self): + + while True: + first_token_item = self.first_token_queue.get() + + if first_token_item is None: + log.info("Exiting First token response thread") + break + + first_tokens, response_id = first_token_item + + response_data = array.array("B", np.array( + first_tokens, np.float32).tobytes()) + bi = response_data.buffer_info() + response = [lg.QuerySampleResponse(response_id, bi[0], bi[1])] + lg.FirstTokenComplete(response) + + def process_queries(self): + """Processor of the queued queries. User may choose to add batching logic """ + while True: + + qitem = self.query_queue.get() + if qitem is None: + break + + input_ids_tensor = self.data_object.input_ids[qitem.index] + input_masks_tensor = self.data_object.attention_masks[qitem.index] + dataset = self.data_object.dataset_names[qitem.index] + + # TODO: This PoC is super slow with significant overhead. 
Best to + # create a patch to `generate` + tokens_cache = [] + tokens_streamer = FirstTokenStreamer( + self.first_token_queue, + tokens_cache=tokens_cache, + is_first_token=True, + response_ids=[ + qitem.id]) + + logits_processor = LogitsProcessorList([StopAfterSequence(self.tokenizer.eos_token_id, device=self.device)]) + if dataset == "MBXP": + _ = self.model.generate(input_ids=input_ids_tensor, + attention_mask=input_masks_tensor, + pad_token_id=self.tokenizer.pad_token_id, + streamer=tokens_streamer, + logits_processor=logits_processor, + **gen_kwargs + ) + else: + _ = self.model.generate(input_ids=input_ids_tensor, + attention_mask=input_masks_tensor, + pad_token_id=self.tokenizer.pad_token_id, + streamer=tokens_streamer, + **gen_kwargs + ) + + output_tokens = tokens_streamer.get_out_tokens() + n_tokens = len(output_tokens) + response_array = array.array( + "B", np.array( + output_tokens, np.int32).tobytes()) + bi = response_array.buffer_info() + response = [lg.QuerySampleResponse( + qitem.id, bi[0], bi[1], n_tokens)] + lg.QuerySamplesComplete(response) + + def issue_queries(self, query_samples): + + self.query_queue.put(query_samples[0]) + + def stop(self): + for _ in range(self.num_workers): + self.query_queue.put(None) + + for worker in self.worker_threads: + worker.join() + + self.first_token_queue.put(None) + self.ft_response_thread.join() diff --git a/language/mixtral-8x7b/build.sh b/language/mixtral-8x7b/build.sh new file mode 100644 index 000000000..87afb992f --- /dev/null +++ b/language/mixtral-8x7b/build.sh @@ -0,0 +1,8 @@ +set -e + +conda install pybind11==2.10.4 -c conda-forge -y +conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch-nightly -c nvidia +python -m pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0 + + +cd ../../loadgen && python3 -m pip install . diff --git a/language/mixtral-8x7b/dataset.py b/language/mixtral-8x7b/dataset.py new file mode 100644 index 000000000..d2cafac63 --- /dev/null +++ b/language/mixtral-8x7b/dataset.py @@ -0,0 +1,115 @@ +import random +import os +import time +import numpy as np +import torch +from datasets import load_dataset, load_from_disk +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch.nn.functional import pad +from torch.utils.data import DataLoader +from typing import Optional, Dict, Sequence +import io +# import utils +import copy +import pickle + +import logging +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("Llama-70B-Dataset") + + +class Dataset(): + def __init__(self, model_name=None, total_sample_count=15000, + perf_count_override=None, dataset_path=None, device="cpu"): + self.model_name = model_name or "mistralai/Mixtral-8x7B-v0.1" + self.dataset_path = dataset_path + self.max_length = 1024 + self.device = device + + # self.total_sample_count = total_sample_count + + self.load_tokenizer() + self.load_processed_dataset() + + self.total_sample_count = min(len(self.input_ids), total_sample_count) + self.perf_count = perf_count_override or self.total_sample_count + + def load_tokenizer(self): + """ Returns tokenizer """ + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + model_max_length=1024, + padding_side="left", + use_fast=False,) + + self.tokenizer.pad_token = self.tokenizer.eos_token + + def load_processed_dataset(self): + if not os.path.isfile(self.dataset_path): + log.warn( + "Processed pickle file {} not found. 
Please check that the path is correct".format( + self.dataset_path)) + + print("Loading dataset...") + import pandas as pd + processed_data = pd.read_pickle(self.dataset_path) + + input_tokens = processed_data['tok_input'] + + self.input_ids = [] + self.input_lens = [] + self.attention_masks = [] + self.dataset_names = [] + + for ids in input_tokens: + input_ids = torch.tensor(ids, dtype=torch.int32).view( + 1, -1).to(self.device) + attn_mask = torch.ones_like(input_ids) + self.input_ids.append(input_ids) + self.attention_masks.append(attn_mask) + self.input_lens.append(input_ids.shape[-1]) + + for dataset in processed_data['dataset']: + self.dataset_names.append(dataset) + print("Finished loading dataset.") + + def postProcess(self, out_tokens, input_seq_lens=None, + query_id_list=None, sample_index_list=None): + """ Postprocesses output prediction """ + + # TODO: Create response object in postProcess(?) + """ + preds = [] + for i in range(out_tokens.shape[0]): + #pred = out_tokens[i].reshape(-1).cpu().numpy() # Slice up to original input length as below? + + input_len = input_seq_lens[i] if input_seq_lens else 0 + pred = out_tokens[i, input_len:].reshape(-1).cpu().numpy() + preds.append(pred) + """ + # Everything is padded to max_len (1024), so prune the input and parse + # to numpy + output_seq = out_tokens[:, 1024:].cpu().numpy() + assert len(query_id_list) == output_seq.shape[0] + + # Save outputs + if not os.path.exists("run_outputs"): + os.makedirs("run_outputs") + fname = "q" + "_".join([str(i) for i in query_id_list]) + fname = f"run_outputs/{fname}.pkl" + with open(fname, mode='wb') as f: + d = {"query_ids": query_id_list, + "outputs": output_seq} + print(f"Saving outputs to {fname}") + pickle.dump(d, f) + + return output_seq + + def LoadSamplesToRam(self, sample_list): + pass + + def UnloadSamplesFromRam(self, sample_list): + pass + + def __del__(self): + pass diff --git a/language/mixtral-8x7b/evaluate-accuracy.py b/language/mixtral-8x7b/evaluate-accuracy.py new file mode 100644 index 000000000..e20834c41 --- /dev/null +++ b/language/mixtral-8x7b/evaluate-accuracy.py @@ -0,0 +1,217 @@ +import argparse +from transformers import AutoTokenizer +import nltk +import evaluate +import numpy as np +import pandas as pd +import json +import re + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint-path", required=True, + help="Path to Llama2-70b-hf-chat checkpoint") + parser.add_argument("--mlperf-accuracy-file", required=True, + help="path to mlperf_log_accuracy.json") + parser.add_argument("--dataset-file", required=True, + help="path to processed validation dataset") + parser.add_argument("--n_workers", default=2, type=int, + help="Number of workers used for the MBXP evaluation") + parser.add_argument("--verbose", action="store_true", + help="verbose messages") + parser.add_argument("--dtype", default="int64", + help="dtype of the accuracy log", choices=["int32", "int64", "float"]) + args = parser.parse_args() + return args + + +def get_groundtruth(processed_dataset_file): + data = pd.read_pickle(processed_dataset_file) + return data + + +# Functions for evaluating GSM8K +def find_numbers(x: str) -> list[str]: + """Finds all numbers in a string.""" + # Search for number, possibly negative (hyphen), with thousand separators + # (comma), and with a decimal point (period inbetween digits). 
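+    # e.g. find_numbers('costs 1,234.5 dollars') -> ['1,234.5']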
+ numbers = re.compile( + r'-?[\d,]*\.?\d+', + re.MULTILINE | re.DOTALL | re.IGNORECASE, + ).findall(x) + return numbers + + +def find_number(x: str, + answer_delimiter: str = 'The answer is') -> str: + """Finds the most relevant number in a string.""" + # If model uses the answer delimiter, then select the first number following + # that format. + if answer_delimiter in x: + answer = x.split(answer_delimiter)[-1] + numbers = find_numbers(answer) + if numbers: + return numbers[0] + + # In general, select the last number in the string. + numbers = find_numbers(x) + if numbers: + return numbers[-1] + return '' + + +def maybe_remove_comma(x: str) -> str: + # Example: 5,600 -> 5600 + return x.replace(',', '') + + +def try_float(x: str): + try: + ret = float(x) + except BaseException: + ret = None + return ret + +# Functions for evaluating OpenOrca + + +def postprocess_text(preds, targets): + preds = [pred.strip() for pred in preds] + targets = [target.strip() for target in targets] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] + + return preds, targets + +# Functions for MBXP + + +def create_mbxp_dict(row, response): + lang, entry_point = row["id"].split("_", 1) + return { + "lang": lang, + "prompt": row["input"], + "test_code": row["gt_output"], + "entry_point": entry_point, + "response": response + } + + +def main(): + + args = get_args() + dataset_path = args.dataset_file + checkpoint_path = args.checkpoint_path + metric = evaluate.load("rouge") + nltk.download('punkt') + + tokenizer = AutoTokenizer.from_pretrained( + checkpoint_path, + model_max_length=2048, + padding_side="left", + use_fast=False,) + + data = get_groundtruth(args.dataset_file) + query_types, gt_outputs = data["dataset"], data["gt_output"] + + target_required_GSM8K = [] + target_required_OpenOrca = [] + results_MBXP = [] + preds_token_GSM8K = [] + preds_token_OpenOrca = [] + preds_token_MBXP = [] + + eval_dtype = np.int64 + if args.dtype == "int32": + eval_dtype = np.int32 + elif args.dtype == "float": + eval_dtype = np.float32 + + with open(args.mlperf_accuracy_file, "r") as f: + results = json.load(f) + + seen = set() + gen_tok_len = 0 + gen_num = 0 + for pred in results: + gen_num += 1 + qsl_idx = pred['qsl_idx'] + if qsl_idx in seen: + continue + + seen.add(qsl_idx) + + query_type = query_types.iloc[qsl_idx] + if query_type == "GSM8K": + target = gt_outputs.iloc[qsl_idx] + target_required_GSM8K.append(target) + pred = np.frombuffer(bytes.fromhex(pred['data']), eval_dtype) + + gen_tok_len += len(pred) + preds_token_GSM8K.append(pred) + elif query_type == "OpenOrca": + target = gt_outputs.iloc[qsl_idx] + target_required_OpenOrca.append(target) + pred = np.frombuffer(bytes.fromhex(pred['data']), eval_dtype) + + gen_tok_len += len(pred) + preds_token_OpenOrca.append(pred) + else: + target = data.iloc[qsl_idx] + pred = np.frombuffer(bytes.fromhex(pred['data']), eval_dtype) + pred_str = tokenizer.decode(pred, skip_special_tokens=True) + results_MBXP.append(create_mbxp_dict(target, pred_str)) + + gen_tok_len += len(pred) + + # OpenOrca metric + preds_decoded_text = tokenizer.batch_decode( + preds_token_OpenOrca, skip_special_tokens=True) + + preds, targets = postprocess_text( + preds_decoded_text, target_required_OpenOrca) + result = metric.compute( + predictions=preds, references=targets, use_stemmer=True, use_aggregator=False) + result = {k: round(np.mean(v) * 100, 4) for k, v in 
result.items()} + prediction_lens = [len(pred) for pred in preds] + # GSM8K metric + preds_decoded_text = tokenizer.batch_decode( + preds_token_GSM8K, skip_special_tokens=True) + pred_nums = [ + maybe_remove_comma( + find_number( + pred_text.split("\nQ:")[0])) for pred_text in preds_decoded_text] + gsm8k_total = len(target_required_GSM8K) + correct = 0 + for idx in range(len(target_required_GSM8K)): + ref = try_float(target_required_GSM8K[idx]) + tgt = try_float(pred_nums[idx]) + if tgt is None: + continue + correct += (ref == tgt) + + gsm8k_accuracy = 100.0 * correct / gsm8k_total + + # MBXP metric + from evaluate_mbxp import evaluate_mbxp + mbxp_accuracy = evaluate_mbxp(results_MBXP, args.n_workers) + + result = { + **result, + 'gen_len': np.sum(prediction_lens), + 'gen_num': gen_num, + 'gen_tok_len': gen_tok_len, + 'tokens_per_sample': round(gen_tok_len / gen_num, 1), + 'gsm8k_accuracy': gsm8k_accuracy, + 'mbxp_accuracy': mbxp_accuracy + } + + print("\nResults\n") + print(result) + + +if __name__ == "__main__": + main() diff --git a/language/mixtral-8x7b/evaluate_mbxp.py b/language/mixtral-8x7b/evaluate_mbxp.py new file mode 100644 index 000000000..e7d55d169 --- /dev/null +++ b/language/mixtral-8x7b/evaluate_mbxp.py @@ -0,0 +1,154 @@ +import argparse +import json +import multiprocessing +import pickle +import queue +import re +import timeit + +import pandas as pd +from tqdm import tqdm + +from mxeval.execution import check_correctness as check_correctness_python +from mxeval.execution import ( + check_correctness_cpp, + check_correctness_csharp, + check_correctness_go, + check_correctness_java, + check_correctness_javascript, + check_correctness_kotlin, + check_correctness_perl, + check_correctness_php, + check_correctness_ruby, + check_correctness_scala, + check_correctness_swift, + check_correctness_typescript, +) + + +def postprocess_golang(code: str) -> str: + multi_line_imports = re.compile( + r"^import \(\n(.+)((?:\n.+)+)\n\)", re.MULTILINE) + line_imports = re.compile(r"^import \".*\"") + func_main = re.compile(r"^func main.*^}", re.MULTILINE | re.DOTALL) + + code = code.replace("package main", "") # Remove package main + code = multi_line_imports.sub("", code) + code = line_imports.sub("", code) + code = func_main.sub("", code) + + return code + + +def postprocess_scala(code: str) -> str: + code = code.replace("object Main extends App {", "") + code = "".join(code.splitlines(True)[:-1]) + return code + + +def postprocess_python(code: str) -> str: + return code.lstrip() + + +def worker(inp_queue, out_queue): + while True: + try: + problem = inp_queue.get(timeout=5) + except queue.Empty: + break + + key = f"{problem['lang']}_{problem['entry_point']}" + checker = eval(f"check_correctness_{problem['lang']}") + + problem["task_id"] = key + problem["test"] = problem["test_code"] + + solution = problem["response"] + + try: + solution = solution[:solution.index("```")] + except ValueError: + # Happens when a code block isn't closed properly + pass + + if problem["lang"] == "go": + solution = postprocess_golang(solution) + elif problem["lang"] == "python": + solution = postprocess_python(solution) + elif problem["lang"] == "scala": + solution = postprocess_scala(solution) + + # Mixtral likes escaping underscores for some reason, so let's remove + # these + solution = solution.replace("\\_", "_") + + # The evaluation script evaluates `code = prompt + solution + tests` + # But Mixtral regenerates the prompt in its output, so we should remove + # this + problem["prompt"] = "" + try: + 
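+            # Invoke the language-specific mxeval correctness checker resolved via eval() above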
result = checker(problem, solution, timeout=20.0) + out_queue.put( + (key, + problem["lang"], + result["passed"], + result["result"], + problem["response"])) + except Exception as e: + print(e) + out_queue.put( + (key, problem["lang"], False, "", problem["response"])) + + +def evaluate_mbxp(results, n_workers): + by_lang = {} + for problem in results: + by_lang.setdefault(problem["lang"], []).append(problem) + + inp_queue = multiprocessing.Queue() + out_queue = multiprocessing.Queue() + + n_problems = 0 + + for lang, problems in by_lang.items(): + if lang not in ["cpp", "python", "php", + "javascript", "ruby", "typescript"]: + continue + + n_problems += len(problems) + for problem in problems: + inp_queue.put(problem) + + start = timeit.default_timer() + workers = [] + for _ in range(n_workers): + w = multiprocessing.Process(target=worker, args=(inp_queue, out_queue)) + w.start() + workers.append(w) + + passes = {} + n_passed = 0 + lang_passed = {} + lang_counts = {} + for i in tqdm(range(n_problems)): + key, lang, passed, result, response = out_queue.get() + passes[key] = { + "passed": passed, + "result": result, + "response": response} + n_passed += passed + + lang_passed.setdefault(lang, 0) + lang_passed[lang] += passed + + lang_counts.setdefault(lang, 0) + lang_counts[lang] += 1 + + end = timeit.default_timer() + print(f"Processed {n_problems} in {end - start}s") + print(f"{100 * n_passed / n_problems : .02f}% pass@1") + print(lang_passed, lang_counts) + with open("evaluated_test.json", "w") as f: + json.dump(passes, f, indent=2) + + return 100 * n_passed / n_problems diff --git a/language/mixtral-8x7b/launch.sh b/language/mixtral-8x7b/launch.sh new file mode 100644 index 000000000..c3389c516 --- /dev/null +++ b/language/mixtral-8x7b/launch.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +MLCOMMONS_REPO_PATH="$(dirname "$(dirname "$PWD")")" + +# Add any volume mounts here with the following syntax +# /path/to/src:/path/to/dir/in/container +MOUNTS=( + $MLCOMMONS_REPO_PATH:$MLCOMMONS_REPO_PATH +) + +# Set up docker environment file for current user +rm -f .docker_env +echo "CI_BUILD_USER=`id -u -n`" >> .docker_env +echo "CI_BUILD_UID=`id -u`" >> .docker_env +echo "CI_BUILD_GROUP=`id -g -n`" >> .docker_env +echo "CI_BUILD_GID=`id -g`" >> .docker_env +cat .docker_env + +# Build container +docker build . 
-t llm/gpubringup + +# Build mount flags +declare -a MOUNT_FLAGS +for _mount in ${MOUNTS[@]}; do + _split=($(echo $_mount | tr ':' '\n')); + MOUNT_FLAGS+=("--mount type=bind,source=${_split[0]},target=${_split[1]}"); +done + +set -x +nvidia-docker run -it --rm --net=host --runtime=nvidia --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ + --cap-add=SYS_PTRACE --cap-add=SYS_ADMIN --cap-add=DAC_READ_SEARCH \ + --security-opt seccomp=unconfined \ + -w $PWD \ + --env-file `pwd`/.docker_env \ + ${MOUNT_FLAGS[*]} \ + llm/gpubringup \ + bash ./with_the_same_user diff --git a/language/mixtral-8x7b/main.py b/language/mixtral-8x7b/main.py new file mode 100644 index 000000000..396948ba0 --- /dev/null +++ b/language/mixtral-8x7b/main.py @@ -0,0 +1,168 @@ +import subprocess +import mlperf_loadgen as lg +import argparse +import os +import logging +import sys +from SUT import SUT, SUTServer + +sys.path.insert(0, os.getcwd()) + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("Mixtral-8x7B-Instruct-v0.1-MAIN") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--scenario", + type=str, + choices=[ + "Offline", + "Server"], + default="Offline", + help="Scenario") + parser.add_argument( + "--model-path", + type=str, + default="mistralai/Mixtral-8x7B-Instruct-v0.1", + help="Model name") + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="path to processed validation dataset") + parser.add_argument( + "--accuracy", + action="store_true", + help="Run accuracy mode") + parser.add_argument( + "--dtype", + type=str, + default="float32", + help="data type of the model, choose from float16, bfloat16 and float32") + parser.add_argument( + "--device", + type=str, + choices=[ + "cpu", + "cuda:0"], + default="cpu", + help="device to use") + parser.add_argument( + "--audit-conf", + type=str, + default="audit.conf", + help="audit config for LoadGen settings during compliance runs") + parser.add_argument( + "--mlperf-conf", + type=str, + default="mlperf.conf", + help="mlperf rules config") + parser.add_argument( + "--user-conf", + type=str, + default="user.conf", + help="user config for user LoadGen settings such as target QPS") + # TODO: This interpretation of 'total-sample-count' is a little + # misleading. Fix it + parser.add_argument( + "--total-sample-count", + type=int, + default=24576, + help="Number of samples to use in benchmark.") + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Model batch-size to use in benchmark.") + parser.add_argument( + "--output-log-dir", + type=str, + default="output-logs", + help="Where logs are saved") + parser.add_argument( + "--enable-log-trace", + action="store_true", + help="Enable log tracing. 
This file can become quite large") + parser.add_argument( + "--num-workers", + type=int, + default=1, + help="Number of workers to process queries") + + args = parser.parse_args() + return args + + +scenario_map = { + "offline": lg.TestScenario.Offline, + "server": lg.TestScenario.Server, +} + +sut_map = { + "offline": SUT, + "server": SUTServer +} + + +def main(): + args = get_args() + + settings = lg.TestSettings() + settings.scenario = scenario_map[args.scenario.lower()] + # Need to update the conf + settings.FromConfig(args.mlperf_conf, "mixtral-8x7b", args.scenario) + settings.FromConfig(args.user_conf, "mixtral-8x7b", args.scenario) + + if args.accuracy: + settings.mode = lg.TestMode.AccuracyOnly + log.warning( + "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet") + else: + settings.mode = lg.TestMode.PerformanceOnly + + os.makedirs(args.output_log_dir, exist_ok=True) + log_output_settings = lg.LogOutputSettings() + log_output_settings.outdir = args.output_log_dir + log_output_settings.copy_summary_to_stdout = True + log_settings = lg.LogSettings() + log_settings.log_output = log_output_settings + log_settings.enable_trace = args.enable_log_trace + + sut_cls = sut_map[args.scenario.lower()] + + sut = sut_cls( + model_path=args.model_path, + dtype=args.dtype, + batch_size=args.batch_size, + dataset_path=args.dataset_path, + total_sample_count=args.total_sample_count, + device=args.device, + ) + + # Start sut before loadgen starts + sut.start() + lgSUT = lg.ConstructSUT(sut.issue_queries, sut.flush_queries) + log.info("Starting Benchmark run") + lg.StartTestWithLogSettings( + lgSUT, + sut.qsl, + settings, + log_settings, + args.audit_conf) + + # Stop sut after completion + sut.stop() + + log.info("Run Completed!") + + log.info("Destroying SUT...") + lg.DestroySUT(lgSUT) + + log.info("Destroying QSL...") + lg.DestroyQSL(sut.qsl) + + +if __name__ == "__main__": + main() diff --git a/language/mixtral-8x7b/run_accuracy.sh b/language/mixtral-8x7b/run_accuracy.sh new file mode 100644 index 000000000..80db6b604 --- /dev/null +++ b/language/mixtral-8x7b/run_accuracy.sh @@ -0,0 +1,22 @@ +CHECKPOINT_PATH="${CHECKPOINT_PATH:mistralai/Mixtral-8x7B-Instruct-v0.1}" +DATASET_PATH="${DATASET_PATH:dataset/2024_06_06_mixtral_15k_v4.pkl}" + +mkdir -p "run_outputs" + +python3 -u main.py --scenario Offline \ + --model-path ${CHECKPOINT_PATH} \ + --accuracy \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 15000 \ + --dataset-path ${DATASET_PATH} \ + --output-log-dir offline_accuracy_loadgen_logs \ + --dtype float32 \ + --device cuda:0 2>&1 | tee offline_accuracy_log.log + +python3 evaluate-accuracy.py --checkpoint-path ${CHECKPOINT_PATH} \ + --mlperf-accuracy-file offline_accuracy_loadgen_logs/mlperf_log_accuracy.json \ + --dataset-file ${DATASET_PATH} \ + --dtype int32 + +python3 consolidate_results.py --dataset-path ${DATASET_PATH} --model-dir ${CHECKPOINT_PATH} diff --git a/language/mixtral-8x7b/run_offline.sh b/language/mixtral-8x7b/run_offline.sh new file mode 100644 index 000000000..f451fdc24 --- /dev/null +++ b/language/mixtral-8x7b/run_offline.sh @@ -0,0 +1,10 @@ +CHECKPOINT_PATH="${CHECKPOINT_PATH:mistralai/Mixtral-8x7B-Instruct-v0.1}" +DATASET_PATH="${DATASET_PATH:dataset/2024_06_06_mixtral_15k_v4.pkl}" + +python -u main.py --scenario Offline \ + --model-path ${CHECKPOINT_PATH} \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 15000 \ + --dataset-path ${DATASET_PATH} \ + --device cpu 2>&1 
| tee server_log.log diff --git a/language/mixtral-8x7b/run_server.sh b/language/mixtral-8x7b/run_server.sh new file mode 100644 index 000000000..7c4e4a05e --- /dev/null +++ b/language/mixtral-8x7b/run_server.sh @@ -0,0 +1,12 @@ + + +CHECKPOINT_PATH="${CHECKPOINT_PATH:mistralai/Mixtral-8x7B-Instruct-v0.1}" +DATASET_PATH="${DATASET_PATH:dataset/2024_06_06_mixtral_15k_v4.pkl}" + +python -u main.py --scenario Server \ + --model-path ${CHECKPOINT_PATH} \ + --mlperf-conf mlperf.conf \ + --user-conf user.conf \ + --total-sample-count 15000 \ + --dataset-path ${DATASET_PATH} \ + --device cpu 2>&1 | tee server_log.log diff --git a/language/mixtral-8x7b/user.conf b/language/mixtral-8x7b/user.conf new file mode 100644 index 000000000..c5f3ca488 --- /dev/null +++ b/language/mixtral-8x7b/user.conf @@ -0,0 +1,5 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds +# \ No newline at end of file diff --git a/mlperf.conf b/mlperf.conf index dd835563d..5a3c78b22 100644 --- a/mlperf.conf +++ b/mlperf.conf @@ -43,6 +43,7 @@ retinanet.MultiStream.target_latency = 528 # LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenario gptj.*.sample_concatenate_permutation = 1 llama2-70b.*.sample_concatenate_permutation = 1 +mixtral-8x7B.*.sample_concatenate_permutation = 1 *.Server.target_latency = 10 *.Server.target_latency_percentile = 99 @@ -58,14 +59,19 @@ gptj.Server.target_latency = 20000 stable-diffusion-xl.Server.target_latency = 20000 # Llama2-70b benchmarks measures token latencies llama2-70b.*.use_token_latencies = 1 +mixtral-8x7b.*.use_token_latencies = 1 # gptj benchmark infers token latencies gptj.*.infer_token_latencies = 1 gptj.*.token_latency_scaling_factor = 69 -# Only ttft and tpot are tracked for the llama2-70b benchmark therefore target_latency = 0 +# Only ttft and tpot are tracked for the llama2-70b & mixtral-8x7B benchmark therefore target_latency = 0 llama2-70b.Server.target_latency = 0 llama2-70b.Server.ttft_latency = 2000 llama2-70b.Server.tpot_latency = 200 +mixtral-8x7b.Server.target_latency = 0 +mixtral-8x7b.Server.ttft_latency = 2000 +mixtral-8x7b.Server.tpot_latency = 200 + *.Offline.target_latency_percentile = 90 *.Offline.min_duration = 600000 @@ -83,6 +89,7 @@ rnnt.Offline.min_query_count = 2513 3d-unet.Offline.min_query_count = 43 stable-diffusion-xl.Offline.min_query_count = 5000 llama2-70b.Offline.min_query_count = 24576 +mixtral-8x7b.Offline.min_query_count = 15000 # These fields should be defined and overridden by user.conf. *.SingleStream.target_latency = 10