Skip to content

Commit

Permalink
run flax model
Browse files Browse the repository at this point in the history
# Current blocker is that jax not supported linux library for ARM64
# jax-ml/jax#7097
  • Loading branch information
naltukhov committed Jan 8, 2023
1 parent 5d3f573 commit 6c74af5
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
10 changes: 5 additions & 5 deletions api/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.9-slim
FROM python:3.9

COPY ./app.py /joke_gen/
COPY ./model_utils.py /joke_gen/
Expand All @@ -11,8 +11,8 @@ ENV FLASK_RUN_HOST=0.0.0.0

WORKDIR /joke_gen

RUN pip install --upgrade pip
RUN pip install -r requirements.txt
EXPOSE 8888
RUN apt-get update && apt-get -y install pkg-config cmake curl
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
RUN pip install --upgrade pip && pip install -r requirements.txt

CMD ["python", "app.py"]
EXPOSE 8888
13 changes: 10 additions & 3 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,16 @@ def get_prediction() -> List[Dict]:


if __name__ == '__main__':
logging.debug('Starting loading model')
logging.debug('Starting downloading model')
# model variable refers to the global variable
model = T5GenerationModel()
model.load_model_from_file('model_weights/')
logging.debug('Model was successfully loaded!')
# model.load_model_from_file('model_weights/')

model.load_model_from_hub(model_name="naltukhov/joke-generator-t5-rus-finetune",
model_type="flax",
revision="a001d2b3c44d193f489f2e3704ca13776a57a43b",
use_auth_token=False,
force_download=True)

logging.debug('Model was successfully downloaded!')
app.run(host='0.0.0.0', port=8888)
13 changes: 9 additions & 4 deletions api/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
--extra-index-url https://download.pytorch.org/whl/cpu
Flask==2.2.2
huggingface-hub
transformers
sentencepiece
pandas

torch==1.13.1
torchvision==0.14.1
torchaudio==0.13.1

pandas==1.5.2
huggingface-hub
sentencepiece==0.1.97
transformers
flax
jax

# Current blocker is that jax not supported linux library for ARM64
# https://github.com/google/jax/issues/7097

0 comments on commit 6c74af5

Please sign in to comment.