diff --git a/mokuro/manga_page_ocr.py b/mokuro/manga_page_ocr.py index c64d012..abb9ae2 100644 --- a/mokuro/manga_page_ocr.py +++ b/mokuro/manga_page_ocr.py @@ -9,6 +9,7 @@ from mokuro import __version__ from mokuro.cache import cache from mokuro.utils import imread +import torch class InvalidImage(Exception): @@ -35,9 +36,11 @@ def __init__( self.disable_ocr = disable_ocr if not self.disable_ocr: - logger.info("Initializing text detector") + cuda = torch.cuda.is_available() + device = 'cuda' if cuda and not force_cpu else 'cpu' + logger.info(f"Initializing text detector, using device {device}") self.text_detector = TextDetector( - model_path=cache.comic_text_detector, input_size=detector_input_size, device="cpu", act="leaky" + model_path=cache.comic_text_detector, input_size=detector_input_size, device=device, act="leaky" ) self.mocr = MangaOcr(pretrained_model_name_or_path, force_cpu)