main.py
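"""Question answering over a local Chroma vector store.

Downloads the configured model from the Hugging Face Hub, loads it with
LlamaCpp, and answers questions through a RetrievalQAWithSourcesChain.
"""
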
import logging
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.llms import LlamaCpp
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.callbacks.manager import CallbackManager
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from huggingface_hub import hf_hub_download
from config import (
    PERSIST_DIRECTORY,
    MODEL_DIRECTORY,
    SOURCE_DIR,
    EMBEDDING_MODEL,
    DEVICE_TYPE,
    CHROMA_SETTINGS,
    MODEL_NAME,
    MODEL_FILE,
    N_GPU_LAYERS,
    MAX_TOKEN_LENGTH,
)


def load_model(device_type: str = DEVICE_TYPE, model_id: str = MODEL_NAME, model_basename: str = MODEL_FILE, LOGGING=logging):
    """Download the model weights from the Hugging Face Hub (if needed) and load them with LlamaCpp."""
    # Stream generated tokens to stdout as they are produced.
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    try:
        model_path = hf_hub_download(
            repo_id=model_id,
            filename=model_basename,
            resume_download=True,
            cache_dir=MODEL_DIRECTORY,
        )
        kwargs = {
            "model_path": model_path,
            "max_tokens": MAX_TOKEN_LENGTH,
            "n_ctx": MAX_TOKEN_LENGTH,
            "n_batch": 512,
            "callback_manager": callback_manager,
            "verbose": False,
            "f16_kv": True,
            "streaming": True,
        }
        if device_type.lower() == "mps":
            kwargs["n_gpu_layers"] = 1
        if device_type.lower() == "cuda":
            kwargs["n_gpu_layers"] = N_GPU_LAYERS  # set this based on your GPU
        llm = LlamaCpp(**kwargs)
        LOGGING.info(f"Loaded {model_id} locally")
        return llm  # Returns a LlamaCpp object
    except Exception as e:
        LOGGING.error(f"Failed to load {model_id}: {e}")
        raise


def retriver(device_type: str = DEVICE_TYPE, LOGGING=logging):
    """Build the retrieval QA chain over the persisted Chroma index and run a sample query."""
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={"device": device_type},
        cache_folder=MODEL_DIRECTORY,
    )
    # Open the Chroma index that was persisted during ingestion.
    db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embeddings,
    )
    retriever = db.as_retriever()
    LOGGING.info("Loaded Chroma DB Successfully")
    llm = load_model(device_type, model_id=MODEL_NAME, model_basename=MODEL_FILE, LOGGING=LOGGING)
    # Prompt in the Llama instruction format; RetrievalQAWithSourcesChain
    # supplies the retrieved documents through the "summaries" variable.
    template = """
    [INST]
    Context: {summaries}
    User: {question}
    [/INST]
    """
    prompt = PromptTemplate(input_variables=["summaries", "question"], template=template)
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        retriever=retriever,
        # chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
    )
    # Run a sample question; the streaming callback prints the answer as it is generated.
    response = chain(
        {"question": "What is the linux command to list files in a directory?"},
        return_only_outputs=True,
    )
    return response


if __name__ == '__main__':
    retriver()