diff --git a/ingest.py b/ingest.py
index 68392bd1..f3b90a33 100644
--- a/ingest.py
+++ b/ingest.py
@@ -3,6 +3,7 @@
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 
 import click
+import torch
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
@@ -89,7 +90,7 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
 @click.command()
 @click.option(
     "--device_type",
-    default="cuda",
+    default="cuda" if torch.cuda.is_available() else "cpu",
     type=click.Choice(
         [
             "cpu",
diff --git a/run_localGPT.py b/run_localGPT.py
index e0d97a5c..928c17c5 100644
--- a/run_localGPT.py
+++ b/run_localGPT.py
@@ -129,7 +129,7 @@ def load_model(device_type, model_id, model_basename=None):
 @click.command()
 @click.option(
     "--device_type",
-    default="cuda",
+    default="cuda" if torch.cuda.is_available() else "cpu",
     type=click.Choice(
         [
             "cpu",
@@ -219,7 +219,7 @@ def main(device_type, show_sources):
     # model_id = "TheBloke/orca_mini_3B-GGML"
     # model_basename = "orca-mini-3b.ggmlv3.q4_0.bin"
 
-    model_id="TheBloke/Llama-2-7B-Chat-GGML"
+    model_id = "TheBloke/Llama-2-7B-Chat-GGML"
     model_basename = "llama-2-7b-chat.ggmlv3.q4_0.bin"
 
     llm = load_model(device_type, model_id=model_id, model_basename=model_basename)
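
The recurring change in both files is a device-aware CLI default: the `--device_type` option now falls back to `"cpu"` when no CUDA device is visible to PyTorch, instead of hard-coding `"cuda"` and failing on CPU-only machines. Below is a minimal standalone sketch of that pattern, not code from the diff itself: the script name, `main` function, and trimmed device list are illustrative, and it assumes `click` and `torch` are installed.

```python
import click
import torch

# "cuda" is only a safe default when PyTorch can actually see a GPU;
# torch.cuda.is_available() returns False on CPU-only hosts, so fall back.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@click.command()
@click.option(
    "--device_type",
    default=DEFAULT_DEVICE,
    type=click.Choice(["cpu", "cuda"]),  # trimmed; the real scripts list more backends
    help="Device to run on (auto-detected by default).",
)
def main(device_type):
    """Echo the resolved device so the default is easy to verify."""
    click.echo(f"device_type resolved to: {device_type}")

if __name__ == "__main__":
    main()
```

Note that the default is evaluated once, when the decorator runs at import time; that matches the behavior of the inline `"cuda" if torch.cuda.is_available() else "cpu"` expression in the patched scripts.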