This project is an example of how to build a REST API for image classification using PyTorch and Flask. The API accepts an image of any size, preprocesses it, and runs it through a pre-trained PyTorch model to make a prediction. The predicted class and its confidence score are returned as JSON.
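The README does not show how a model prediction becomes the JSON response. Below is a minimal sketch, assuming the model emits one raw score (logit) per class and that "confidence" means the softmax probability of the top class. The helper names and the JSON keys `class` and `confidence` are hypothetical and are not taken from `app.py`.

```python
import json
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw model scores into probabilities that sum to 1."""
    # Subtract the largest logit first so math.exp cannot overflow.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def build_response(logits: Sequence[float], class_names: Sequence[str]) -> str:
    """Pick the most probable class and serialize it as a JSON string."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    # Hypothetical key names; the real API may use different ones.
    return json.dumps({"class": class_names[best], "confidence": round(probs[best], 4)})
```

For example, `build_response([2.0, 0.5, 0.1], ["cat", "dog", "bird"])` returns `{"class": "cat", "confidence": 0.7285}`. In PyTorch the same step is usually `torch.softmax(logits, dim=1)` followed by `max`.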
- Python 3.x
- PyTorch
- Flask
- Clone the repository:
  ```bash
  git clone https://github.com/username/repo.git
  ```
- Change into the project directory:
  ```bash
  cd repo
  ```
- Install the required packages:
  ```bash
  pip install -r requirements.txt
  ```
- Run the training command:
  ```bash
  python main.py --mode train --save_path [PATH]
  ```
- Run the test command:
  ```bash
  python main.py --mode test --model_path [PATH]
  ```
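Training, testing, and serving all go through the same `main.py` entry point, selected by `--mode`. A minimal `argparse` sketch of how such flags could be wired up is shown below. It assumes only the flag names used in the commands of this README; the per-mode validation rules and the `parse_args` helper are guesses, not the project's actual `main.py`.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the --mode/--save_path/--model_path flags (hypothetical layout)."""
    parser = argparse.ArgumentParser(description="Train, test, or serve the image classifier.")
    parser.add_argument("--mode", choices=["train", "test", "serve"], required=True,
                        help="Which task to run.")
    parser.add_argument("--save_path", help="Where training writes the model (train mode).")
    parser.add_argument("--model_path", help="Checkpoint to load (test and serve modes).")
    args = parser.parse_args(argv)

    # Require the path flag each mode needs; parser.error prints usage and exits.
    if args.mode == "train" and not args.save_path:
        parser.error("--save_path is required when --mode is train")
    if args.mode in ("test", "serve") and not args.model_path:
        parser.error("--model_path is required when --mode is test or serve")
    return args
```

Calling `parse_args(["--mode", "test", "--model_path", "model.pt"])` returns a namespace with `mode == "test"`; omitting the required path makes argparse print a usage message and exit with status 2.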
- Start the Flask app:
  ```bash
  python main.py --mode 'serve' --model_path 'path/to/the/checkpoint'
  ```
- Send a POST request to the `/classify` endpoint with an image file attached:
  ```bash
  curl -X POST -F "image=@/path/to/image.jpg" http://localhost:5000/classify
  ```
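The same request can be sent from Python. Below is a minimal sketch using only the standard library, assuming the server runs at `http://localhost:5000` and reads the upload from a form field named `image`, as in the curl example. The helper names `build_classify_request` and `classify` are hypothetical.

```python
import mimetypes
import urllib.request
import uuid
from pathlib import Path


def build_classify_request(image_path: str,
                           url: str = "http://localhost:5000/classify") -> urllib.request.Request:
    """Build a multipart/form-data POST equivalent to the curl command above."""
    path = Path(image_path)
    boundary = uuid.uuid4().hex
    content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
    # A single form part named "image", matching curl's -F "image=@..." field.
    body = b"".join([
        f"--{boundary}\r\n".encode(),
        f'Content-Disposition: form-data; name="image"; filename="{path.name}"\r\n'.encode(),
        f"Content-Type: {content_type}\r\n\r\n".encode(),
        path.read_bytes(),
        f"\r\n--{boundary}--\r\n".encode(),
    ])
    return urllib.request.Request(
        url,
        data=body,
        method="POST",
        headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
    )


def classify(image_path: str) -> str:
    """Send the request and return the raw JSON text from the server."""
    with urllib.request.urlopen(build_classify_request(image_path)) as resp:
        return resp.read().decode("utf-8")
```

With the server running, `print(classify("/path/to/image.jpg"))` prints the JSON response. If the third-party `requests` package is available, `requests.post(url, files={"image": f})` builds the same multipart body for you.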
- The PyTorch model is located in `model.py`.
- The Flask app is located in `app.py`.
- The image preprocessing pipeline is defined in `app.py`.
- The server configuration is defined in `config.py`.
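Because the API accepts images of any size, the preprocessing step has to bring each one to the fixed input size the model expects. The sketch below shows the arithmetic behind a common pipeline: resize the shorter side, center-crop, then normalize each channel. It assumes the widely used ImageNet values (resize to 256, crop to 224, ImageNet mean/std). These constants and function names are hypothetical and are not read from `app.py`. In PyTorch code this is typically written as `torchvision.transforms.Compose([Resize(256), CenterCrop(224), ToTensor(), Normalize(mean, std)])`.

```python
from typing import Tuple

# Hypothetical settings: common ImageNet defaults, not taken from app.py.
RESIZE_SHORT_SIDE = 256
CROP_SIZE = 224
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def resized_size(width: int, height: int, short_side: int = RESIZE_SHORT_SIDE) -> Tuple[int, int]:
    """Scale so the shorter edge equals short_side, keeping the aspect ratio."""
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side


def center_crop_box(width: int, height: int, crop: int = CROP_SIZE) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of a crop x crop square centred in the image."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb: Tuple[int, int, int]) -> Tuple[float, float, float]:
    """Map 0-255 channel values to [0, 1], then standardize each channel."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD))
```

For example, a 640x480 upload is resized to 341x256, and the crop box is then (58, 16, 282, 240), giving a 224x224 input.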
- Fork the repository.
- Create a new branch.
- Make your changes and commit them.
- Push to the branch.
- Submit a pull request.
This project is licensed under the MIT License - see the `LICENSE` file for details.
- This project was inspired by the *PyTorch Image Classification Tutorial* and *Building a Simple Flask API for Image Recognition*.