"""Minimal local RAG over a single PDF.

Extracts text from ``./documents/$PDF_DOCUMENT``, chunks and embeds it with a
llama.cpp embedding model, stores the vectors in a sqlite-vec ``vec0`` table,
and answers questions interactively with a llama.cpp chat model.
"""

import os
import re
import sqlite3
import sys

import numpy as np
import sqlite_vec
from dotenv import load_dotenv
from llama_cpp import Llama
from pypdf import PdfReader
from sqlite_vec import serialize_float32

load_dotenv()

DEBUG = False

LLM_MODEL = os.getenv("LLM_MODEL")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
PDF_DOCUMENT = os.getenv("PDF_DOCUMENT")

# Words ignored by the keyword re-ranker.
# TODO: put this in a config file or something
STOP_WORDS = frozenset(
    {
        "the",
        "is",
        "a",
        "an",
        "who",
        "what",
        "when",
        "where",
        "why",
        "how",
        "and",
        "or",
        "to",
        "of",
        "in",
        "on",
        "for",
        "with",
        "as",
        "by",
    }
)
_WORD_RE = re.compile(r"\b\w+\b")

os.makedirs("./models", exist_ok=True)
os.makedirs("./documents", exist_ok=True)

conn = sqlite3.connect("./vector_db.sqlite")
conn.enable_load_extension(True)
sqlite_vec.load(conn)
# Turn extension loading back off once sqlite-vec is in, so nothing else can
# load arbitrary shared libraries through this connection.
conn.enable_load_extension(False)

llm = Llama(
    model_path=f"./models/{LLM_MODEL}",
    n_gpu_layers=-1,  # -1 offloads all layers to the GPU; use 0 for CPU-only
    n_ctx=6096,  # context window size in tokens
    verbose=False,
    log_level="error",
    # seed=1337,  # Uncomment to set a specific seed
    # NOTE: temperature, repeat_penalty, top_p and top_k are sampling options
    # of the completion call in llama-cpp-python; pass them to llm(...) in
    # ask_llm rather than to this constructor.
)

_embedding_model = None


def get_embedding_model():
    """Return the embedding model, loading it on first use and caching it."""
    global _embedding_model
    if _embedding_model is None:
        print("Loading embedded model...")
        _embedding_model = Llama(
            model_path=f"./models/{EMBEDDING_MODEL}",
            embedding=True,
            verbose=False,
            log_level="error",
        )
    return _embedding_model


def init_db(dim: int):
    """Enable WAL and create the ``chunks`` vec0 table for ``dim``-wide vectors."""
    conn.execute("PRAGMA journal_mode=WAL;")
    conn.execute(
        f"""
        CREATE VIRTUAL TABLE IF NOT EXISTS chunks USING vec0(
            id INTEGER,
            embedding float[{dim}],
            text TEXT
        );
        """
    )


def load_pdf(path):
    """Extract the text of every page of the PDF at ``path``.

    Pages with no extractable text (e.g. scanned images) are skipped and
    counted. Each kept page is followed by a newline.
    """
    reader = PdfReader(path)
    pages = []
    empty_pages = 0

    for page in reader.pages:
        page_text = page.extract_text()
        if not page_text:
            empty_pages += 1
            continue
        pages.append(page_text + "\n")

    # join instead of repeated += keeps this linear in the document size.
    text = "".join(pages)

    print(f"Empty pages: {empty_pages}/{len(reader.pages)}")
    print(f"Total extracted chars: {len(text)}")

    return text


def chunk_text(text, max_chars=1200):
    """Greedily pack newline-separated lines into chunks of < ``max_chars``.

    A single line longer than ``max_chars`` becomes its own (oversized) chunk.
    Empty or whitespace-only chunks are never returned, so blank input yields
    an empty list.
    """
    chunks = []
    current = ""

    for line in text.split("\n"):
        # Flush only a non-empty buffer: flushing an empty one is what used to
        # produce "" chunks when the first line was already too long.
        if current and len(current) + len(line) >= max_chars:
            chunks.append(current.strip())
            current = ""
        current += line + "\n"

    if current:
        chunks.append(current.strip())

    return [chunk for chunk in chunks if chunk]


def normalize(vec):
    """Return ``vec`` scaled to unit L2 norm as a list of float32 values.

    A zero vector is returned unchanged instead of dividing by zero (which
    would yield NaNs).
    """
    v = np.asarray(vec, dtype=np.float32)
    norm = np.linalg.norm(v)
    if norm == 0:
        return v.tolist()
    return (v / norm).tolist()


def embed_chunks(chunks, batch_size=1):
    """Embed ``chunks`` in batches and return their unit-normalized vectors."""
    if not chunks:
        return []

    all_embeddings = []
    model = get_embedding_model()

    for i in range(0, len(chunks), batch_size):
        batch = chunks[i : i + batch_size]
        result = model.create_embedding(batch)
        all_embeddings.extend(normalize(e["embedding"]) for e in result["data"])
        print(f"Embedded {i + len(batch)}/{len(chunks)}")

    return all_embeddings


def store_embeddings(chunks, embeddings):
    """Create the vector table if needed and insert every (chunk, embedding).

    Row ids are the chunk's index in ``chunks``. All inserts run in one
    transaction: they are committed together, or rolled back on error.

    Raises:
        ValueError: if there is nothing to store (the vector width would be
            unknown).
    """
    if not embeddings:
        raise ValueError("store_embeddings: no embeddings to store")

    init_db(len(embeddings[0]))

    with conn:
        conn.executemany(
            "INSERT INTO chunks (id, embedding, text) VALUES (?, ?, ?)",
            (
                (i, serialize_float32(emb), chunk)
                for i, (chunk, emb) in enumerate(zip(chunks, embeddings))
            ),
        )


def tokenize(text):
    """Return the set of lowercase word tokens in ``text``, minus STOP_WORDS."""
    words = set(_WORD_RE.findall(text.lower()))
    return words - STOP_WORDS


def keyword_score(query, text):
    """Score ``text`` by keyword overlap with ``query``.

    The base score is the fraction of the query's non-stop-word tokens found in
    ``text``. A bonus of +1.0 is added when the whole query appears verbatim
    (case-insensitive). A query with no usable tokens scores 0.
    """
    q = tokenize(query)
    if not q:
        return 0

    score = len(q & tokenize(text)) / len(q)

    if query.lower() in text.lower():
        score += 1.0

    return score


def query(question, top_k=3, initial_k=10):
    """Retrieve the ``top_k`` most relevant chunks for ``question``.

    The ``initial_k`` nearest chunks are fetched by vector similarity. They
    are then re-ranked by keyword_score, with vector distance breaking ties.

    Returns:
        A list of ``(id, text)`` tuples, best first.
    """
    model = get_embedding_model()
    query_embedding = normalize(
        model.create_embedding([question])["data"][0]["embedding"]
    )

    rows = conn.execute(
        """
        SELECT id, text, distance
        FROM chunks
        WHERE embedding MATCH ?
          AND k = ?
        ORDER BY distance
        """,
        (serialize_float32(query_embedding), initial_k),  # type: ignore
    ).fetchall()

    scored = [
        (cid, text, keyword_score(question, text), distance)
        for cid, text, distance in rows
    ]
    # Highest keyword score first; among equal scores, closest vector first.
    scored.sort(key=lambda row: (-row[2], row[3]))

    if DEBUG:
        print("\n--- RETRIEVAL DEBUG ---")
        for cid, text, s, distance in scored[:5]:
            print(f"[{cid}] score={s:.2f} dist={distance:.4f} | {text[:120]}\n")

    return [(cid, text) for cid, text, _, _ in scored[:top_k]]


def ask_llm(context_chunks, question, max_tokens=200):
    """Stream an answer to ``question`` grounded in ``context_chunks`` to stdout.

    Args:
        context_chunks: ``(id, text)`` pairs, as returned by query().
        question: The user's question.
        max_tokens: Upper bound on generated tokens.
    """
    context = "\n\n".join(f"[{cid}] {text}" for cid, text in context_chunks)

    prompt = f"""You are a precise assistant.
Use ONLY the provided context to answer.
Cite sources at the end of your sentences using bracket IDs.
If unsure , say "I don't know based on the provided context."

Context:
{context}

Question: {question}

Answer:"""

    # Every stop sequence must be non-empty: an empty string matches any text
    # and would end generation before anything is emitted.
    stream = llm(
        prompt,
        max_tokens=max_tokens,
        stop=["</s>", "<|end|>", "Question:"],
        stream=True,
    )

    print("\nANSWER:\n")
    for chunk in stream:
        token = chunk["choices"][0]["text"]  # type: ignore
        print(token, end="", flush=True)
    print()


def main():
    """Ingest the PDF on first run, then answer questions until exit/EOF."""
    print("Loading DB...")
    exists = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
    ).fetchone()

    count = 0
    if exists:
        count = conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]

    if count == 0:
        print("No data found. Ingesting PDF...")
        text = load_pdf(f"./documents/{PDF_DOCUMENT}")
        chunks = chunk_text(text)
        if not chunks:
            print("No extractable text found in the PDF; nothing to index.")
            return
        embeddings = embed_chunks(chunks)
        store_embeddings(chunks, embeddings)

    print("\nRAG is ready. Ask questions (type 'exit' to quit)")

    while True:
        print()
        try:
            question = input("Question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly.
            print()
            break

        if question.lower() in ("exit", "quit"):
            break
        if not question:
            continue

        results = query(question)
        ask_llm(results, question)


if __name__ == "__main__":
    main()
    print("Goodbye!")
    sys.exit()