Improving RAG

This commit is contained in:
Dita Aji Pratama 2026-06-24 17:11:03 +07:00
parent 059445de5e
commit 03221ca119
4 changed files with 167 additions and 108 deletions

View File

@ -122,7 +122,8 @@ SESSION_DB_PATH = os.path.expanduser(
# ─── RAG (YAML) ─────────────────────────────────────────────────────────────────
RAG_PERSIST_DIR = os.getenv("RAG_PERSIST_DIR", default=_yaml_get("rag", "persist_dir", default="chroma_db"))
# Expand "~" so values such as "~/.config/hendrik/rag" resolve to a real path,
# matching how SESSION_DB_PATH is handled. A missing/empty value falls back to
# the built-in default instead of reaching os.path.expanduser as None.
RAG_PERSIST_DIR = os.path.expanduser(
    os.getenv("RAG_PERSIST_DIR", default=_yaml_get("rag", "persist_dir", default="lancedb_data"))
    or "lancedb_data"
)
# An empty model path stays "" (expanduser leaves it unchanged).
RAG_MODEL_PATH = os.path.expanduser(
    os.getenv("RAG_MODEL_PATH", default=_yaml_get("rag", "model_path", default=""))
    or ""
)
# ─── Humanize Delay (YAML) ─────────────────────────────────────────────────────

View File

@ -34,7 +34,8 @@ llm:
- name : "z-ai/glm-5"
rag:
persist_dir: chroma_db # ChromaDB ONNX default (all-MiniLM-L6-v2, local)
persist_dir: "~/.config/hendrik/rag" # LanceDB Vector Store (all-MiniLM-L6-v2, local)
model_path: "~/.config/hendrik/models" # Custom path to store/load embedding model.
session:
db_path: "~/.config/hendrik/sessions.json"
@ -50,8 +51,7 @@ telegram:
allowed_group_ids: "" # comma-separated, empty = all groups
selective_response: true # true = only respond if mentioned/relevant
# Humanize Delay (anti-bot detection)
delay:
delay: # Humanize Delay (anti-bot detection)
read_min: 1.0 # seconds
read_max: 2.0 # seconds
typing_speed: 15.0 # characters per second

View File

@ -5,3 +5,7 @@ openpyxl>=3.1.0
slixmpp
python-telegram-bot>=20.0
tinydb>=4.8.0
lancedb
sentence-transformers
pandas
pylance

View File

@ -2,29 +2,78 @@ import glob as globmod
import json
import os
import time
import chromadb
from chromadb.config import Settings
import pandas as pd
import lancedb
from lancedb.pydantic import LanceModel
from sentence_transformers import SentenceTransformer
import config
# ── Embedding Setup ───────────────────────────────────────────────────────
# ── ChromaDB singleton ───────────────────────────────────────────────
def load_embedding_model():
    """Load the sentence-transformers embedding model according to configuration.

    Behaviour is driven by ``config.RAG_MODEL_PATH``:

    1. Empty path: load ``all-MiniLM-L6-v2`` by name (default cache).
    2. Path set but missing: load the model by name, then save it to that path.
    3. Path set and present: load the model directly from that path.

    Returns:
        The loaded ``SentenceTransformer``, or ``None`` if loading failed
        (the error is printed instead of being raised).
    """
    model_name = "all-MiniLM-L6-v2"
    # Expand "~" here: os.path.exists / os.makedirs do not, so a configured
    # "~/..." path would otherwise be treated as a literal "~" directory
    # relative to the current working directory.
    custom_path = os.path.expanduser(config.RAG_MODEL_PATH.strip())

    try:
        if not custom_path:
            # Case 1: use the default cache.
            print(f"[RAG] Loading embedding model '{model_name}' from default cache...")
            return SentenceTransformer(model_name)

        # Cases 2 & 3: a custom path is configured.
        if os.path.exists(custom_path):
            print(f"[RAG] Loading embedding model from custom path: {custom_path}")
            return SentenceTransformer(custom_path)

        print(f"[RAG] Custom path {custom_path} not found. Downloading model first...")
        model = SentenceTransformer(model_name)
        # Create the target directory before saving into it.
        os.makedirs(custom_path, exist_ok=True)
        model.save(custom_path)
        print(f"[RAG] Model successfully downloaded and saved to: {custom_path}")
        return model
    except Exception as e:
        print(f"[RAG] Critical Error loading embedding model: {e}")
        return None


_store = None


def _get_store():
    """Return the shared ChromaDB persistent client, creating it on first use."""
    global _store
    if _store is None:
        _store = chromadb.PersistentClient(
            path=config.RAG_PERSIST_DIR,
            settings=Settings(anonymized_telemetry=False),
        )
    return _store


def _collection(name):
    """Get or create collection — uses ChromaDB's default ONNX embedding (all-MiniLM-L6-v2)."""
    return _get_store().get_or_create_collection(name=name)
# Initialise the embedding model once, at import time.
embedding_model = load_embedding_model()


def get_embedding(text):
    """Return the embedding of ``text`` produced by the module-level model.

    Args:
        text: Input handed unchanged to ``embedding_model.encode``.

    Returns:
        The result of ``embedding_model.encode(text).tolist()``.

    Raises:
        RuntimeError: If the embedding model could not be loaded at startup.
    """
    if embedding_model is None:
        # RuntimeError is still caught by the ``except Exception`` handlers used
        # by the tool functions, but is more specific than a bare Exception.
        raise RuntimeError("Embedding model not loaded. Check your config or internet connection.")
    return embedding_model.encode(text).tolist()
# Simple schema, kept minimal to avoid Pydantic conflicts.
class DocumentSchema(LanceModel):
    """Row layout used when a RAG table is created."""

    # Document body.
    text: str
    # Document identifier.
    id: str
    # Metadata serialised as a string.
    metadata: str
    # Embedding of ``text``.
    # NOTE(review): LanceDB's docs declare vector columns with
    # ``lancedb.pydantic.Vector(dim)``; confirm a plain ``list[float]`` column
    # is accepted by vector search on the LanceDB version in use.
    vector: list[float]
# ── LanceDB singleton ───────────────────────────────────────────────────────
_db = None


def _get_db():
    """Return the shared LanceDB connection, opening it on the first call."""
    global _db
    if _db is not None:
        return _db
    # First use: connect to the directory configured in RAG_PERSIST_DIR.
    _db = lancedb.connect(config.RAG_PERSIST_DIR)
    return _db
def _get_table(name):
    """Return the LanceDB table ``name``, creating it if it does not exist.

    Args:
        name: Table (collection) name.

    Returns:
        The opened or newly created table; new tables use ``DocumentSchema``.
    """
    db = _get_db()
    if name in db.table_names():
        return db.open_table(name)
    # exist_ok=True opens the table instead of raising if it already exists.
    # That covers a table created between the check above and this call, and a
    # table missing from the listing.
    # NOTE(review): table_names() accepts page_token/limit — confirm whether
    # the listing can be truncated on the LanceDB version in use.
    return db.create_table(name, schema=DocumentSchema, exist_ok=True)
# ── Tool schemas ─────────────────────────────────────────────────────
@ -74,9 +123,8 @@ schema_search_knowledge = {
"name": "search_knowledge",
"description": (
"Semantically search a RAG collection. Optionally narrow with a "
"metadata filter using ChromaDB where syntax. "
"Examples: {'category': 'main_course'}, {'spice_level': {'$lte': 2}}, "
"{'allergens': {'$contains': 'seafood'}}."
"metadata filter using SQL-like syntax. "
"Example: \"metadata LIKE '%main_course%'\""
),
"parameters": {
"type": "object",
@ -95,8 +143,8 @@ schema_search_knowledge = {
"default": 5
},
"filter": {
"type": "object",
"description": "Optional metadata filter dict",
"type": "string",
"description": "Optional SQL-like filter for metadata JSON string",
"default": None
}
},
@ -185,7 +233,6 @@ schema_inspect_collection = {
}
}
schema_ingest_files = {
"type": "function",
"function": {
@ -229,135 +276,131 @@ schema_ingest_files = {
}
}
# ── Tool handlers ────────────────────────────────────────────────────
def _sanitize_meta(meta):
"""ChromaDB metadata only allows str/int/float/bool. Convert lists to JSON string, remove empty lists."""
out = {}
for k, v in meta.items():
if isinstance(v, list):
if len(v) == 0:
continue
out[k] = json.dumps(v, ensure_ascii=False)
elif isinstance(v, (str, int, float, bool)):
out[k] = v
else:
out[k] = str(v)
return out
def store_knowledge(collection, documents):
    """Embed and store documents in a RAG collection.

    Args:
        collection: Name of the collection (table); created if missing.
        documents: Iterable of dicts with ``"id"`` and ``"text"`` keys and an
            optional ``"metadata"`` dict, which is stored as a JSON string.

    Returns:
        A status message, or ``"Error: ..."`` if anything fails.
    """
    # Nothing to embed or write: skip opening the table and the empty add.
    if not documents:
        return f"Stored 0 document(s) in '{collection}'."
    try:
        table = _get_table(collection)
        rows = [
            {
                "id": doc["id"],
                "text": doc["text"],
                "metadata": json.dumps(doc.get("metadata", {}), ensure_ascii=False),
                "vector": get_embedding(doc["text"]),
            }
            for doc in documents
        ]
        table.add(rows)
        return f"Stored {len(documents)} document(s) in '{collection}'."
    except Exception as e:
        return f"Error: {e}"
def search_knowledge(collection, query, n_results=5, filter=None):
    """Semantically search a RAG collection.

    Args:
        collection: Name of the collection (table) to search.
        query: Query text; it is embedded and used for the vector search.
        n_results: Maximum number of rows to return.
        filter: Optional SQL-like filter string passed to ``where``.

    Returns:
        The matches joined by ``---`` separators (text truncated to 500
        characters), ``"No results found."``, or ``"Error: ..."`` on failure.
    """
    try:
        table = _get_table(collection)
        # Build the vector query once; the filter is only an extra clause.
        search = table.search(get_embedding(query))
        if filter:
            # NOTE(review): ``filter`` is caller-supplied text handed to the
            # query engine as-is; confirm that is acceptable for this tool.
            search = search.where(filter)
        df = search.limit(n_results).to_pandas()
        if df.empty:
            return "No results found."
        out = []
        for _, row in df.iterrows():
            txt = row["text"]
            if len(txt) > 500:
                txt = txt[:500] + "..."
            out.append(f"[{row['id']}]\n text: {txt}\n metadata: {row['metadata']}")
        return "\n---\n".join(out)
    except Exception as e:
        return f"Error: {e}"
def create_collection(name, description=""):
    """Ensure that the collection ``name`` exists.

    Args:
        name: Collection (table) name.
        description: Accepted for interface compatibility; this function does
            not store it.

    Returns:
        A status message, or ``"Error: ..."`` on failure.
    """
    try:
        _get_table(name)
    except Exception as e:
        return f"Error: {e}"
    return f"Collection '{name}' is ready."
def delete_collection(name):
    """Delete the collection ``name`` from the LanceDB store.

    Args:
        name: Collection (table) name.

    Returns:
        A status message, or ``"Error: ..."`` if the drop fails.
    """
    try:
        # Let LanceDB drop the table rather than removing a directory by hand:
        # the hand-built path skipped "~" expansion and assumed the on-disk
        # layout, so it could report success without deleting anything.
        _get_db().drop_table(name)
        return f"Deleted collection '{name}'."
    except Exception as e:
        return f"Error: {e}"
def list_collections():
    """List the available collections with their document counts.

    Returns:
        A multi-line listing, ``"No collections exist yet."`` when there are
        none, or ``"Error: ..."`` on failure.
    """
    try:
        db = _get_db()
        # NOTE(review): table_names() accepts page_token/limit — confirm all
        # tables are returned when there are many.
        names = db.table_names()
        if not names:
            return "No collections exist yet."
        out = ["Available collections:"]
        for name in names:
            # count_rows() avoids loading every row (and its vector) into a
            # DataFrame just to measure its length.
            cnt = db.open_table(name).count_rows()
            out.append(f"- {name} [{cnt} docs]")
        return "\n".join(out)
    except Exception as e:
        return f"Error: {e}"
def inspect_collection(collection, sample_size=3):
    """Summarise a collection: total count, sample rows, and metadata keys.

    Args:
        collection: Collection (table) name.
        sample_size: Maximum number of sample rows to show.

    Returns:
        A multi-line summary (sample text truncated to 200 characters), a
        message saying the collection is empty, or ``"Error: ..."`` on failure.
    """
    try:
        table = _get_table(collection)
        # Count and sample through the table API instead of materialising the
        # whole table, vectors included, in a DataFrame.
        cnt = table.count_rows()
        if cnt == 0:
            return f"Collection '{collection}' is empty."
        n = max(0, min(sample_size, cnt))
        rows = table.head(n).to_pylist() if n > 0 else []
        out = [f"Collection: {collection} | Total documents: {cnt}", f"Sample ({n}):"]
        keys = set()
        for row in rows:
            txt = row["text"]
            if len(txt) > 200:
                txt = txt[:200] + "..."
            meta = row["metadata"]
            out.append(f"\n [{row['id']}] text: {txt} metadata: {meta}")
            try:
                keys.update(json.loads(meta).keys())
            except (TypeError, ValueError, AttributeError):
                # Metadata that is not a JSON object contributes no keys.
                continue
        if keys:
            out.append(f"\nMetadata keys: {', '.join(sorted(keys))}")
        return "\n".join(out)
    except Exception as e:
        return f"Error: {e}"
def ingest_files(collection, paths, chunk_size=0, chunk_overlap=0, recursive=True):
try:
col = _collection(collection)
all_ids, all_texts, all_metas = [], [], []
table = _get_table(collection)
all_data = []
processed, skipped = 0, 0
# Expand glob patterns into real file paths
file_set = set()
for p in paths:
expanded = globmod.glob(p, recursive=recursive)
if expanded:
file_set.update(expanded)
else:
# Maybe it's a literal path that doesn't look like a glob
if os.path.isfile(p):
elif os.path.isfile(p):
file_set.add(p)
else:
skipped += 1
@ -381,7 +424,6 @@ def ingest_files(collection, paths, chunk_size=0, chunk_overlap=0, recursive=Tru
}
base_name = os.path.splitext(os.path.basename(fpath))[0]
# ── read content ──────────────────────────────────────────
if ext in (".xlsx", ".xlsm"):
try:
import openpyxl
@ -396,6 +438,7 @@ def ingest_files(collection, paths, chunk_size=0, chunk_overlap=0, recursive=Tru
for row in ws.iter_rows(values_only=True):
vals = [str(c) if c is not None else "" for c in row]
rows.append("\t".join(vals))
lines = rows
content = "\n".join(lines)
if not content.strip():
@ -416,22 +459,27 @@ def ingest_files(collection, paths, chunk_size=0, chunk_overlap=0, recursive=Tru
meta["chunk_index"] = cid
meta["chunk_lines"] = end - start
meta["chunk_start_line"] = start + 1
all_ids.append(doc_id)
all_texts.append(chunk_text)
all_metas.append(_sanitize_meta(meta))
all_data.append({
"id": doc_id,
"text": chunk_text,
"metadata": json.dumps(meta, ensure_ascii=False),
"vector": get_embedding(chunk_text)
})
cid += 1
step = chunk_size - chunk_overlap
start += step if step > 0 else 1
processed += 1
else:
doc_id = f"{base_name}_{sheet_name}"
all_ids.append(doc_id)
all_texts.append(content)
all_metas.append(_sanitize_meta(sheet_meta))
all_data.append({
"id": doc_id,
"text": content,
"metadata": json.dumps(sheet_meta, ensure_ascii=False),
"vector": get_embedding(content)
})
processed += 1
wb.close()
else:
# Plain-text files
try:
with open(fpath, "r", encoding="utf-8", errors="replace") as f:
lines = f.readlines()
@ -456,26 +504,32 @@ def ingest_files(collection, paths, chunk_size=0, chunk_overlap=0, recursive=Tru
meta["chunk_index"] = cid
meta["chunk_lines"] = end - start
meta["chunk_start_line"] = start + 1
all_ids.append(doc_id)
all_texts.append(chunk_text)
all_metas.append(_sanitize_meta(meta))
all_data.append({
"id": doc_id,
"text": chunk_text,
"metadata": json.dumps(meta, ensure_ascii=False),
"vector": get_embedding(chunk_text)
})
cid += 1
step = chunk_size - chunk_overlap
start += step if step > 0 else 1
processed += 1
else:
doc_id = base_name
all_ids.append(doc_id)
all_texts.append(content)
all_metas.append(_sanitize_meta(base_meta))
all_data.append({
"id": doc_id,
"text": content,
"metadata": json.dumps(base_meta, ensure_ascii=False),
"vector": get_embedding(content)
})
processed += 1
if all_ids:
col.add(ids=all_ids, documents=all_texts, metadatas=all_metas)
if all_data:
table.add(all_data)
parts = [f"Ingested {processed} file(s) into '{collection}'"]
if processed > 0:
parts.append(f"({len(all_ids)} document(s) total)")
parts.append(f"({len(all_data)} document(s) total)")
if skipped > 0:
parts.append(f"({skipped} file(s) skipped)")
return " ".join(parts)