# Source listing metadata: 286 lines, 6.2 KiB, Python
import os
|
|
import re
|
|
import sqlite3
|
|
import sys
|
|
|
|
import numpy as np
|
|
import sqlite_vec
|
|
from dotenv import load_dotenv
|
|
from llama_cpp import Llama
|
|
from pypdf import PdfReader
|
|
from sqlite_vec import serialize_float32
|
|
|
|
load_dotenv()

DEBUG = False

# File names read from the environment (.env). They are joined onto
# ./models/ and ./documents/ when the models and the PDF are loaded.
LLM_MODEL = os.getenv("LLM_MODEL")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
PDF_DOCUMENT = os.getenv("PDF_DOCUMENT")

os.makedirs("./models", exist_ok=True)
os.makedirs("./documents", exist_ok=True)

# Single shared connection to the on-disk vector store, with the sqlite-vec
# extension loaded so `vec0` virtual tables and MATCH queries are available.
conn = sqlite3.connect("./vector_db.sqlite")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
# Disable extension loading again once sqlite-vec is in, so no later SQL on
# this connection can load arbitrary shared libraries.
conn.enable_load_extension(False)

# Generation model. Sampling settings (temperature, top_p, top_k,
# repeat_penalty) are per-call arguments of llm(...), not constructor
# arguments, e.g.:
#     llm(prompt, temperature=0.2, repeat_penalty=1.15, top_p=0.9, top_k=40)
llm = Llama(
    model_path=f"./models/{LLM_MODEL}",
    n_gpu_layers=-1,  # -1 offloads all layers to the GPU; use 0 for CPU-only
    n_ctx=6096,  # context window size in tokens
    verbose=False,
    # NOTE(review): confirm the installed llama-cpp-python accepts `log_level`;
    # verbose=False is what actually silences llama.cpp output.
    log_level="error",
    # seed=1337,  # uncomment for reproducible sampling
)

# Embedding model, created lazily on first use by get_embedding_model().
_embedding_model = None
|
|
|
|
|
|
def get_embedding_model():
    """Return the shared embedding model, loading it on first use.

    The model is built from ./models/{EMBEDDING_MODEL} with embedding output
    enabled and cached in the module-level ``_embedding_model``. Ingestion and
    querying therefore reuse the same instance instead of reloading it.
    """
    global _embedding_model

    if _embedding_model is None:
        print("Loading embedding model...")
        _embedding_model = Llama(
            model_path=f"./models/{EMBEDDING_MODEL}",
            embedding=True,
            verbose=False,
            log_level="error",
        )

    return _embedding_model
|
|
|
|
|
|
def init_db(dim: int):
    """Make sure the ``chunks`` vec0 table exists, sized for ``dim``-d vectors.

    The connection is first switched to WAL journaling. The table has an
    ``id`` column, a ``float[dim]`` embedding column and a ``text`` column.
    Because of ``IF NOT EXISTS``, calling this again is harmless. An existing
    table keeps whatever dimension it was created with.
    """
    create_chunks_sql = f"""
        CREATE VIRTUAL TABLE IF NOT EXISTS chunks USING vec0(
            id INTEGER,
            embedding float[{dim}],
            text TEXT
        );
    """

    conn.execute("PRAGMA journal_mode=WAL;")
    conn.execute(create_chunks_sql)
|
|
|
|
|
|
def load_pdf(path):
    """Extract and concatenate the text of every page of the PDF at ``path``.

    Each page with extracted text contributes that text followed by a
    newline. Pages whose extraction returns an empty result are skipped and
    counted. The empty-page count and total character count are printed.

    Returns:
        The concatenated text of all non-empty pages.
    """
    reader = PdfReader(path)
    page_texts = []
    empty_pages = 0

    for page in reader.pages:
        page_text = page.extract_text()

        if not page_text:
            empty_pages += 1
            continue

        page_texts.append(page_text + "\n")

    # Join once instead of growing a string with += on every page.
    text = "".join(page_texts)

    print(f"Empty pages: {empty_pages}/{len(reader.pages)}")
    print(f"Total extracted chars: {len(text)}")

    return text
|
|
|
|
|
|
def chunk_text(text, max_chars=1200):
    """Split ``text`` into chunks of at most ``max_chars`` characters.

    The text is split on newlines, and lines are packed greedily into the
    current chunk. A new chunk starts when the next line would push the
    current one past ``max_chars``.

    A single line longer than ``max_chars`` is cut into ``max_chars``-sized
    pieces, so no chunk exceeds the limit. Chunks are stripped of surrounding
    whitespace, and whitespace-only chunks are dropped. Empty input therefore
    yields ``[]`` rather than ``[""]``.

    Raises:
        ValueError: if ``max_chars`` is less than 1.
    """
    if max_chars < 1:
        raise ValueError(f"max_chars must be at least 1, got {max_chars}")

    chunks = []
    current = ""

    def flush(buffer):
        # Keep only chunks with real content; an empty chunk would otherwise
        # be embedded and stored as a meaningless row.
        stripped = buffer.strip()
        if stripped:
            chunks.append(stripped)

    for line in text.split("\n"):
        # range() yields nothing for an empty line, so keep it as one ""
        # piece to preserve the line break inside the chunk.
        pieces = [
            line[start : start + max_chars]
            for start in range(0, len(line), max_chars)
        ] or [""]

        for piece in pieces:
            if len(current) + len(piece) < max_chars:
                current += piece + "\n"
            else:
                flush(current)
                current = piece + "\n"

    flush(current)
    return chunks
|
|
|
|
|
|
def normalize(vec):
    """Scale ``vec`` to unit Euclidean length and return it as a list of floats.

    The computation is done in float32. An all-zero vector has no direction,
    so it is returned unchanged (as zeros) instead of turning into NaNs via
    0/0. NaNs would poison every distance computed against it.
    """
    v = np.asarray(vec, dtype=np.float32)
    norm = np.linalg.norm(v)
    if norm == 0:
        return v.tolist()
    return (v / norm).tolist()
|
|
|
|
|
|
def embed_chunks(chunks, batch_size=1):
    """Embed ``chunks`` with the shared embedding model, ``batch_size`` per call.

    Each call's response ``data`` entries are passed through ``normalize`` and
    collected in the order returned. Progress is printed after every batch.

    Returns:
        A flat list of normalized embedding vectors.
    """
    model = get_embedding_model()
    vectors = []
    total = len(chunks)

    for start in range(0, total, batch_size):
        batch = chunks[start : start + batch_size]
        response = model.create_embedding(batch)
        vectors.extend(normalize(item["embedding"]) for item in response["data"])
        print(f"Embedded {start + len(batch)}/{total}")

    return vectors
|
|
|
|
|
|
def store_embeddings(chunks, embeddings):
    """Insert ``chunks`` and their ``embeddings`` into the ``chunks`` table.

    The table is created on demand using the dimension of the first
    embedding. Row ids are the 0-based positions in ``chunks``. All rows are
    written in a single transaction, which is committed on success and rolled
    back if any insert fails.

    Raises:
        ValueError: if the two sequences differ in length, which ``zip``
            would otherwise silently truncate, or if there is nothing to
            store.
    """
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"Got {len(chunks)} chunks but {len(embeddings)} embeddings"
        )
    if not embeddings:
        raise ValueError("No embeddings to store (was any text extracted?)")

    init_db(len(embeddings[0]))

    rows = (
        (i, serialize_float32(emb), chunk)
        for i, (chunk, emb) in enumerate(zip(chunks, embeddings))
    )

    # The connection context manager commits on success, rolls back on error.
    with conn:
        conn.executemany(
            "INSERT INTO chunks (id, embedding, text) VALUES (?, ?, ?)", rows
        )
|
|
|
|
|
|
# Common English function words ignored when comparing query and chunk terms.
# TODO: move this list into configuration.
_STOP_WORDS = frozenset(
    {
        "the",
        "is",
        "a",
        "an",
        "who",
        "what",
        "when",
        "where",
        "why",
        "how",
        "and",
        "or",
        "to",
        "of",
        "in",
        "on",
        "for",
        "with",
        "as",
        "by",
    }
)

# Runs of word characters (letters, digits, underscore).
_WORD_RE = re.compile(r"\b\w+\b")


def tokenize(text):
    """Return the set of lowercase word tokens in ``text``, minus stop words.

    The stop-word set and the regex are module constants, so they are built
    once rather than on every call.
    """
    return {w for w in _WORD_RE.findall(text.lower()) if w not in _STOP_WORDS}
|
|
|
|
|
|
def keyword_score(query, text):
    """Score how well ``text`` matches ``query`` lexically.

    The base score is the fraction of the query's non-stop-word tokens that
    also occur in ``text``. A bonus of 1.0 is added when the whole query
    string appears verbatim (case-insensitively) in ``text``. Returns 0 when
    the query has no usable tokens.

    Note: the parameter ``query`` shadows the module-level ``query()``
    function inside this body.
    """
    query_terms = tokenize(query)
    text_terms = tokenize(text)

    if not query_terms:
        return 0

    coverage = len(query_terms & text_terms) / len(query_terms)
    exact_bonus = 1.0 if query.lower() in text.lower() else 0.0
    return coverage + exact_bonus
|
|
|
|
|
|
def query(question, top_k=3, initial_k=10):
    """Retrieve the ``top_k`` chunks most relevant to ``question``.

    Retrieval has two stages. First, a KNN search on the ``chunks`` vec0
    table fetches the ``initial_k`` nearest candidates to the question's
    normalized embedding, in distance order. Second, those candidates are
    re-ranked by ``keyword_score``. The sort is stable, so candidates with
    equal keyword scores keep their vector-similarity order.

    Returns:
        A list of ``(id, text)`` tuples, best first.
    """
    model = get_embedding_model()

    query_embedding = normalize(
        model.create_embedding([question])["data"][0]["embedding"]
    )

    # ORDER BY distance makes the candidate order explicit; it is the
    # tie-breaker for equal keyword scores below.
    rows = conn.execute(
        """
        SELECT id, text FROM chunks
        WHERE embedding MATCH ? AND k = ?
        ORDER BY distance
        """,
        (serialize_float32(query_embedding), initial_k),  # type: ignore
    ).fetchall()

    scored = sorted(
        ((cid, text, keyword_score(question, text)) for cid, text in rows),
        key=lambda row: row[2],
        reverse=True,
    )

    if DEBUG:
        print("\n--- RETRIEVAL DEBUG ---")
        for cid, text, s in scored[:5]:
            print(f"[{cid}] score={s:.2f} | {text[:120]}\n")

    return [(cid, text) for cid, text, _ in scored[:top_k]]
|
|
|
|
|
|
def ask_llm(context_chunks, question, max_tokens=200):
    """Stream an answer to ``question`` grounded in ``context_chunks``.

    Each chunk is rendered as ``[id] text`` so the model can cite it by
    bracket ID. Generated tokens are printed as they arrive.

    Args:
        context_chunks: iterable of ``(id, text)`` pairs, e.g. from ``query``.
        question: the user's question.
        max_tokens: upper bound on generated tokens (default 200, as before).

    Returns:
        The full generated answer text. Existing callers may ignore it.
    """
    context = "\n\n".join(f"[{cid}] {text}" for cid, text in context_chunks)

    prompt = f"""You are a precise assistant.

Use ONLY the provided context to answer.

Cite sources at the end of your sentences using bracket IDs.

If unsure, say "I don't know based on the provided context."

Context:
{context}

Question:
{question}

Answer:"""

    stream = llm(
        prompt,
        max_tokens=max_tokens,
        stop=["</s>", "<|end|>", "Question:"],
        stream=True,
    )

    print("\nANSWER:\n")

    pieces = []
    for chunk in stream:
        token = chunk["choices"][0]["text"]  # type: ignore
        pieces.append(token)
        # Flush each token so the answer appears incrementally.
        print(token, end="", flush=True)

    print()
    return "".join(pieces)
|
|
|
|
|
|
def main():
    """Ingest the PDF on first run, then answer questions interactively.

    If the ``chunks`` table is missing or empty, the PDF at
    ./documents/{PDF_DOCUMENT} is extracted, chunked, embedded and stored.

    The prompt loop then reads questions until one of these happens:
    'exit' or 'quit' is entered, input ends (Ctrl-D), or Ctrl-C is pressed
    at the prompt. Blank input is ignored.
    """
    print("Loading DB...")

    exists = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
    ).fetchone()

    count = 0
    if exists:
        count = conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]

    if count == 0:
        print("No data found. Ingesting PDF...")

        text = load_pdf(f"./documents/{PDF_DOCUMENT}")
        chunks = chunk_text(text)
        embeddings = embed_chunks(chunks)
        store_embeddings(chunks, embeddings)

    print("\nRAG is ready. Ask questions (type 'exit' to quit)")

    while True:
        print()

        try:
            question = input("Question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C at the prompt: leave the loop cleanly
            # instead of dumping a traceback.
            print()
            break

        if not question:
            continue

        if question.lower() in ("exit", "quit"):
            break

        results = query(question)
        ask_llm(results, question)
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        main()
        print("Goodbye!")
    finally:
        # Close explicitly: when the last connection closes, SQLite
        # checkpoints the WAL file and removes it.
        conn.close()

    sys.exit()
|