# minimal_rag_pdf/main.py
# 2026-05-02 10:22:56 -07:00
#
# 286 lines
# 6.2 KiB
# Python

import os
import re
import sqlite3
import sys
import numpy as np
import sqlite_vec
from dotenv import load_dotenv
from llama_cpp import Llama
from pypdf import PdfReader
from sqlite_vec import serialize_float32
# Read LLM_MODEL / EMBEDDING_MODEL / PDF_DOCUMENT from a local .env file, if any.
load_dotenv()

# When True, query() prints the re-ranked candidates before returning them.
DEBUG = False

# File names resolved relative to ./models and ./documents below.
# NOTE: os.getenv returns None for an unset variable, which would make the
# paths below literally end in "None".
LLM_MODEL = os.getenv("LLM_MODEL")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
PDF_DOCUMENT = os.getenv("PDF_DOCUMENT")

os.makedirs("./models", exist_ok=True)
os.makedirs("./documents", exist_ok=True)

# Single shared connection; the sqlite-vec extension provides the vec0
# virtual table used by init_db() and query().
conn = sqlite3.connect("./vector_db.sqlite")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
# Extension loading is only needed for the call above; switch it back off so
# nothing else can load native code through this connection.
conn.enable_load_extension(False)

# Chat/completion model used by ask_llm().
llm = Llama(
    model_path=f"./models/{LLM_MODEL}",
    n_gpu_layers=-1,  # -1 offloads every layer to the GPU; use 0 for CPU-only
    n_ctx=6096,  # context window size in tokens
    verbose=False,
    # NOTE(review): log_level is not a documented constructor argument of
    # llama_cpp.Llama -- confirm it has any effect.
    log_level="error",
    # Optional settings, disabled by default.
    # NOTE(review): temperature/repeat_penalty/top_p/top_k look like
    # generation-time options -- confirm the constructor honours them.
    # seed = 1337, # Uncomment to set a specific seed
    # temperature=0.2,
    # repeat_penalty=1.15,
    # top_p=0.9,
    # top_k=40,
)
# Module-level cache for the embedding model; filled in on first use.
_embedding_model = None


def get_embedding_model():
    """Return the shared embedding model, loading it on the first call.

    The model file is looked up under ``./models`` using the
    ``EMBEDDING_MODEL`` setting. Later calls reuse the cached instance.
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model

    print("Loading embedded model...")
    _embedding_model = Llama(
        model_path=f"./models/{EMBEDDING_MODEL}",
        embedding=True,
        verbose=False,
        log_level="error",
    )
    return _embedding_model
def init_db(dim: int):
    """Create the ``chunks`` vec0 table if it does not exist yet.

    Args:
        dim: Number of components in each stored embedding. It is spliced
            into the DDL text (a column type cannot be a bound parameter),
            so it is validated as a positive integer first.

    Raises:
        TypeError: If *dim* is not integer-like.
        ValueError: If *dim* is not positive.
    """
    import operator  # local import: only needed for the validation below

    # operator.index accepts int-like values and rejects str/float, so
    # nothing but a plain integer can reach the SQL text.
    dim = operator.index(dim)
    if dim <= 0:
        raise ValueError(f"embedding dimension must be positive, got {dim}")

    conn.execute("PRAGMA journal_mode=WAL;")
    conn.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks USING vec0(
id INTEGER,
embedding float[{dim}],
text TEXT
);
""")
def load_pdf(path):
    """Extract the text of every page of the PDF at *path*.

    Pages that yield no text are skipped and counted. Each remaining page's
    text is followed by a newline. Two summary lines (empty-page count and
    total character count) are printed before returning.

    Returns:
        The concatenated page text as a single string.
    """
    reader = PdfReader(path)
    page_texts = []
    empty_pages = 0
    for page in reader.pages:
        extracted = page.extract_text()
        if extracted:
            page_texts.append(extracted + "\n")
        else:
            empty_pages += 1
    text = "".join(page_texts)
    print(f"Empty pages: {empty_pages}/{len(reader.pages)}")
    print(f"Total extracted chars: {len(text)}")
    return text
def chunk_text(text, max_chars=1200):
    """Greedily pack the newline-separated lines of *text* into chunks.

    Lines are appended to the current chunk while it stays under
    *max_chars*; otherwise the chunk is closed and a new one is started.
    A single line longer than *max_chars* is hard-split into pieces of at
    most *max_chars* characters. Chunks that are empty after stripping
    whitespace are never emitted.

    Args:
        text: The text to split.
        max_chars: Upper bound on the length of each returned chunk.

    Returns:
        A list of non-empty, whitespace-stripped chunks.

    Raises:
        ValueError: If *max_chars* is less than 1.
    """
    if max_chars < 1:
        raise ValueError("max_chars must be a positive integer")

    chunks = []

    def emit(candidate):
        # Drop whitespace-only chunks: they carry nothing to retrieve.
        candidate = candidate.strip()
        if candidate:
            chunks.append(candidate)

    current = ""
    for paragraph in text.split("\n"):
        if len(current) + len(paragraph) < max_chars:
            current += paragraph + "\n"
            continue
        emit(current)
        # A line that alone exceeds the limit is cut into max_chars pieces
        # so that no chunk ever exceeds the limit.
        while len(paragraph) > max_chars:
            emit(paragraph[:max_chars])
            paragraph = paragraph[max_chars:]
        current = paragraph + "\n"
    emit(current)
    return chunks
def normalize(vec):
    """Scale *vec* to unit Euclidean length.

    Args:
        vec: A sequence of numbers.

    Returns:
        A list of Python floats (computed in float32). A vector whose norm
        is zero, including an empty one, is returned unscaled instead of
        being divided by zero, which would fill it with NaNs.
    """
    v = np.asarray(vec, dtype=np.float32)
    norm = np.linalg.norm(v)
    if norm == 0:
        return v.tolist()
    return (v / norm).tolist()
def embed_chunks(chunks, batch_size=1):
    """Embed every chunk and return the normalized vectors in input order.

    Chunks are sent to the embedding model *batch_size* at a time, and a
    progress line is printed after each batch.

    Args:
        chunks: Sequence of text chunks.
        batch_size: Number of chunks passed to the model per call.

    Returns:
        A list with one normalized embedding per chunk.
    """
    all_embeddings = []
    embedder = get_embedding_model()
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i : i + batch_size]
        response = embedder.create_embedding(batch)
        for item in response["data"]:
            all_embeddings.append(normalize(item["embedding"]))
        print(f"Embedded {i + len(batch)}/{len(chunks)}")
    return all_embeddings
def store_embeddings(chunks, embeddings):
    """Persist chunks and their embeddings into the ``chunks`` table.

    The table is created on demand with the dimension of the first
    embedding. Each row's ``id`` is the chunk's position in *chunks*. All
    rows are written in one transaction: either every row is committed or,
    if an insert fails, none are.

    Args:
        chunks: Text chunks.
        embeddings: One embedding per chunk, in the same order.

    Raises:
        ValueError: If the inputs are empty or differ in length. zip()
            would otherwise silently drop the unmatched tail.
    """
    chunks = list(chunks)
    embeddings = list(embeddings)
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"got {len(chunks)} chunks but {len(embeddings)} embeddings"
        )
    if not embeddings:
        raise ValueError("no embeddings to store")

    dim = len(embeddings[0])
    init_db(dim)

    rows = (
        (i, serialize_float32(emb), chunk)
        for i, (chunk, emb) in enumerate(zip(chunks, embeddings))
    )
    # The connection context manager commits on success and rolls back if
    # any insert raises.
    with conn:
        conn.executemany(
            "INSERT INTO chunks (id, embedding, text) VALUES (?, ?, ?)",
            rows,
        )
# Words ignored when comparing a question against a chunk.
# TODO: put this in a config file or something
_STOP_WORDS = frozenset(
    {
        "the",
        "is",
        "a",
        "an",
        "who",
        "what",
        "when",
        "where",
        "why",
        "how",
        "and",
        "or",
        "to",
        "of",
        "in",
        "on",
        "for",
        "with",
        "as",
        "by",
    }
)


def tokenize(text):
    """Return the set of distinct lowercase words in *text*, minus stop words."""
    found = re.findall(r"\b\w+\b", text.lower())
    return set(found) - _STOP_WORDS
def keyword_score(query, text):
    """Score how well *text* matches the words of *query*.

    The base score is the fraction of the query's non-stop-word tokens that
    also occur in *text*. A bonus of 1.0 is added when the whole query
    appears in *text* as a case-insensitive substring.

    Returns:
        0 when the query has no tokens left after stop-word removal,
        otherwise a float between 0.0 and 2.0.
    """
    query_terms = tokenize(query)
    text_terms = tokenize(text)
    if not query_terms:
        return 0

    shared = sum(1 for term in query_terms if term in text_terms)
    coverage = shared / len(query_terms)
    phrase_bonus = 1.0 if query.lower() in text.lower() else 0.0
    return coverage + phrase_bonus
def query(question, top_k=3, initial_k=10):
    """Retrieve the chunks most relevant to *question*.

    A vector search fetches the nearest candidates, which are then
    re-ranked by keyword_score(). Candidates are requested in ascending
    distance order and the re-rank sort is stable, so chunks with equal
    keyword scores stay ordered by vector similarity.

    Args:
        question: The user's question.
        top_k: Number of chunks to return.
        initial_k: Number of nearest neighbours fetched before re-ranking.
            If it is smaller than *top_k*, *top_k* neighbours are fetched
            instead so the caller is not silently short-changed.

    Returns:
        Up to *top_k* ``(id, text)`` tuples, best match first.
    """
    model = get_embedding_model()
    query_embedding = normalize(
        model.create_embedding([question])["data"][0]["embedding"]
    )
    candidates = max(initial_k, top_k)
    rows = conn.execute(
        """
        SELECT id, text FROM chunks
        WHERE embedding MATCH ? AND k = ?
        ORDER BY distance
        """,
        (serialize_float32(query_embedding), candidates),  # type: ignore
    ).fetchall()

    scored = [(cid, text, keyword_score(question, text)) for cid, text in rows]
    # list.sort is stable (also with reverse=True), so equal scores keep
    # the distance order from the SQL query.
    scored.sort(key=lambda item: item[2], reverse=True)

    if DEBUG:
        print("\n--- RETRIEVAL DEBUG ---")
        for cid, text, s in scored[:5]:
            print(f"[{cid}] score={s:.2f} | {text[:120]}\n")

    return [(cid, text) for cid, text, _ in scored[:top_k]]
def ask_llm(context_chunks, question):
    """Stream the LLM's answer to *question* to stdout.

    Args:
        context_chunks: Iterable of ``(id, text)`` pairs. Each is rendered
            as ``[id] text`` so the model can cite it by bracket ID.
        question: The user's question.

    The answer is printed token by token as it is generated. Nothing is
    returned.
    """
    labelled = [f"[{cid}] {text}" for cid, text in context_chunks]
    context = "\n\n".join(labelled)
    prompt = f"""You are a precise assistant.
Use ONLY the provided context to answer.
Cite sources at the end of your sentences using bracket IDs.
If unsure , say "I don't know based on the provided context."
Context:
{context}
Question:
{question}
Answer:"""
    # Generation stops at an end marker or if the model starts a new
    # "Question:" section.
    stop_sequences = ["</s>", "<|end|>", "Question:"]
    stream = llm(prompt, max_tokens=200, stop=stop_sequences, stream=True)

    print("\nANSWER:\n")
    for chunk in stream:
        piece = chunk["choices"][0]["text"]  # type: ignore
        print(piece, end="", flush=True)
    print()
def main():
    """Run the interactive question-answering loop.

    If the ``chunks`` table is missing or empty, the configured PDF is
    loaded, chunked, embedded and stored first. Questions are then read
    from stdin until the user types ``exit`` or ``quit``, or closes the
    prompt with Ctrl-D or Ctrl-C. Blank input is ignored.
    """
    print("Loading DB...")
    exists = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
    ).fetchone()
    count = 0
    if exists:
        count = conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]

    if count == 0:
        print("No data found. Ingesting PDF...")
        text = load_pdf(f"./documents/{PDF_DOCUMENT}")
        chunks = chunk_text(text)
        embeddings = embed_chunks(chunks)
        store_embeddings(chunks, embeddings)

    print("\nRAG is ready. Ask questions (type 'exit' to quit)")
    while True:
        print()
        try:
            question = input("Question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt is a normal way to leave, not
            # an error worth a traceback.
            print()
            break
        if not question:
            # Nothing to embed or answer; prompt again.
            continue
        if question.lower() in ("exit", "quit"):
            break
        results = query(question)
        ask_llm(results, question)
# Script entry point: run the interactive session only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
    print("Goodbye!")
    # Called without an argument, so the process exits with status 0.
    sys.exit()