# hendrik/tools/rag.py
# (485 lines, 17 KiB, Python)

import glob as globmod
import json
import os
import time
import chromadb
from chromadb.config import Settings
import config
# ── ChromaDB singleton ───────────────────────────────────────────────
# Lazily-created ChromaDB client shared by every handler in this module.
_store = None


def _get_store():
    """Return the module-wide ChromaDB client, building it on first use.

    The client persists to ``config.RAG_PERSIST_DIR`` and is created with
    anonymized telemetry switched off.
    """
    global _store
    if _store is not None:
        return _store
    _store = chromadb.PersistentClient(
        path=config.RAG_PERSIST_DIR,
        settings=Settings(anonymized_telemetry=False),
    )
    return _store
def _collection(name):
    """Fetch the collection called *name*, creating it when it is missing.

    No embedding function is passed, so ChromaDB's default applies
    (noted in this file as the ONNX all-MiniLM-L6-v2 model).
    """
    store = _get_store()
    return store.get_or_create_collection(name=name)
# ── Tool schemas ─────────────────────────────────────────────────────
# Tool schema for store_knowledge: a collection name plus a list of
# {id, text, metadata} documents to persist.
schema_store_knowledge = {
    "type": "function",
    "function": {
        "name": "store_knowledge",
        "description": (
            "Store one or more documents with arbitrary metadata into a RAG collection. "
            "Metadata is a free-form dict — choose meaningful keys for future filtering "
            "(e.g., restaurant, category, allergens, spice_level, taste_profile, price"
            ", customer_id, dietary)."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "collection": {
                    "type": "string",
                    "description": "Target collection name (must be defined in config)",
                },
                "documents": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {
                                "type": "string",
                                "description": "Unique document ID",
                            },
                            "text": {
                                "type": "string",
                                "description": "Document body text",
                            },
                            "metadata": {
                                "type": "object",
                                "description": "Arbitrary key-value metadata",
                                "default": {},
                            },
                        },
                        "required": ["id", "text"],
                    },
                    "description": "List of documents to persist",
                },
            },
            "required": ["collection", "documents"],
        },
    },
}
# Tool schema for search_knowledge: semantic query against one collection,
# with an optional result cap and an optional metadata filter.
schema_search_knowledge = {
    "type": "function",
    "function": {
        "name": "search_knowledge",
        "description": (
            "Semantically search a RAG collection. Optionally narrow with a "
            "metadata filter using ChromaDB where syntax. "
            "Examples: {'category': 'main_course'}, {'spice_level': {'$lte': 2}}, "
            "{'allergens': {'$contains': 'seafood'}}."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "collection": {
                    "type": "string",
                    "description": "Collection name to search in",
                },
                "query": {
                    "type": "string",
                    "description": "Natural-language search query",
                },
                "n_results": {
                    "type": "integer",
                    "description": "Max results to return (default 5)",
                    "default": 5,
                },
                "filter": {
                    "type": "object",
                    "description": "Optional metadata filter dict",
                    "default": None,
                },
            },
            "required": ["collection", "query"],
        },
    },
}
# Tool schema for create_collection: a required name and an optional
# free-text description.
schema_create_collection = {
    "type": "function",
    "function": {
        "name": "create_collection",
        "description": (
            "Create a new RAG collection for a new topic/domain. Use a short, descriptive name "
            "with underscores (e.g., 'tanaman_hias', 'customer_profiles'). Optionally provide a description."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "Collection name (lowercase, underscores for spaces)",
                },
                "description": {
                    "type": "string",
                    "description": "What this collection stores",
                    "default": "",
                },
            },
            "required": ["name"],
        },
    },
}
# Tool schema for delete_collection: takes only the name of the collection
# to remove.
schema_delete_collection = {
    "type": "function",
    "function": {
        "name": "delete_collection",
        "description": (
            "Permanently delete an entire RAG collection and all documents in it. "
            "Use with caution — this cannot be undone."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "Collection name to delete",
                },
            },
            "required": ["name"],
        },
    },
}
# Tool schema for list_collections: no parameters.
schema_list_collections = {
    "type": "function",
    "function": {
        "name": "list_collections",
        "description": "List all existing RAG collections with their document count and description.",
        "parameters": {
            "type": "object",
            "properties": {},
        },
    },
}
# Tool schema for inspect_collection: a collection name and how many sample
# documents to show.
schema_inspect_collection = {
    "type": "function",
    "function": {
        "name": "inspect_collection",
        "description": (
            "Examine sample documents and metadata fields in a RAG collection. "
            "Always call this before search_knowledge to learn what metadata keys "
            "are available for filtering, then pass them in the filter parameter."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "collection": {
                    "type": "string",
                    "description": "Collection name to inspect",
                },
                "sample_size": {
                    "type": "integer",
                    "description": "Number of sample documents (default 3)",
                    "default": 3,
                },
            },
            "required": ["collection"],
        },
    },
}
# Tool schema for ingest_files: file paths / glob patterns to load into a
# collection, with optional line-based chunking.
schema_ingest_files = {
    "type": "function",
    "function": {
        "name": "ingest_files",
        "description": (
            "Read one or more files (supports glob patterns like *.py or src/**/*.md) "
            "and store their content into a RAG collection. "
            "Optionally chunk files into smaller pieces by line count. "
            "Automatically extracts metadata: filename, path, extension, size, modification time."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "collection": {
                    "type": "string",
                    "description": "Target collection name (will be created if it doesn't exist)",
                },
                "paths": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "File paths or glob patterns (e.g., ['*.txt', 'src/**/*.py'])",
                },
                "chunk_size": {
                    "type": "integer",
                    "description": "Lines per chunk (0 = whole file as one document)",
                    "default": 0,
                },
                "chunk_overlap": {
                    "type": "integer",
                    "description": "Line overlap between chunks (only used when chunk_size > 0)",
                    "default": 0,
                },
                "recursive": {
                    "type": "boolean",
                    "description": "Search directories recursively when using glob patterns",
                    "default": True,
                },
            },
            "required": ["collection", "paths"],
        },
    },
}
# ── Tool handlers ────────────────────────────────────────────────────
def _sanitize_meta(meta):
    """Coerce a free-form metadata mapping into flat, storable values.

    ChromaDB metadata only allows str/int/float/bool values, so everything
    else is converted or dropped:

    * ``str``/``int``/``float``/``bool`` values pass through unchanged;
    * non-empty lists, tuples and dicts become a JSON string (so the value
      can be parsed back later);
    * empty lists/tuples/dicts and ``None`` values are omitted;
    * any other object is stored as ``str(value)``.

    Args:
        meta: Mapping of metadata keys to arbitrary values. ``None`` or an
            empty mapping is accepted.

    Returns:
        A new dict; the input mapping is never mutated.
    """
    if not meta:
        return {}
    clean = {}
    for key, value in meta.items():
        if value is None:
            # Storing the literal string "None" would be misleading in filters.
            continue
        if isinstance(value, (str, int, float, bool)):
            clean[key] = value
        elif isinstance(value, (list, tuple, dict)):
            if not value:
                continue
            try:
                # default=str keeps one odd element (e.g. a datetime) from
                # failing the whole document.
                clean[key] = json.dumps(value, ensure_ascii=False, default=str)
            except (TypeError, ValueError):
                # Unserialisable keys or circular references: fall back to repr.
                clean[key] = str(value)
        else:
            clean[key] = str(value)
    return clean
def store_knowledge(collection, documents):
    """Persist a batch of documents into a RAG collection.

    Args:
        collection: Target collection name; fetched via ``_collection``, which
            creates it when missing.
        documents: List of dicts, each with required ``"id"`` and ``"text"``
            keys and an optional ``"metadata"`` dict.

    Returns:
        A status string. Failures are reported as ``"Error: ..."`` text
        instead of being raised, so the caller always receives a string.
    """
    if not documents:
        return "No documents provided."
    try:
        ids, texts, metas = [], [], []
        for doc in documents:
            ids.append(doc["id"])
            texts.append(doc["text"])
            # `or {}` covers an explicit "metadata": null from the caller.
            meta = _sanitize_meta(doc.get("metadata") or {})
            # NOTE(review): some ChromaDB releases reject an empty metadata
            # dict; None is passed for "no metadata" — confirm against the
            # pinned version.
            metas.append(meta or None)
        col = _collection(collection)
        # NOTE(review): Collection.add appears to ignore IDs that already
        # exist rather than overwrite them — confirm whether upsert is wanted.
        col.add(
            ids=ids,
            documents=texts,
            metadatas=metas if any(metas) else None,
        )
        return f"Stored {len(ids)} document(s) in '{collection}'."
    except KeyError as e:
        return f"Error: document is missing required field {e}"
    except Exception as e:
        return f"Error: {e}"
def search_knowledge(collection, query, n_results=5, filter=None):
    """Run a semantic query against one collection and format the hits as text.

    The collection is looked up with ``get_collection`` rather than
    get-or-create, so a mistyped name is reported as an error instead of
    silently creating an empty collection.

    Args:
        collection: Name of the collection to search.
        query: Natural-language query text.
        n_results: Maximum number of hits; must be a positive integer.
        filter: Optional metadata filter, forwarded as the query's ``where``
            argument when truthy.

    Returns:
        One entry per hit, separated by ``---`` lines. Each entry shows the
        document id, the first 500 characters of its text, and its metadata
        as JSON. The value labelled ``score`` is read from the result's
        ``distances`` field (a distance, so presumably lower means closer).
        Returns ``"No results found."`` when nothing matches and
        ``"Error: ..."`` on failure.
    """
    try:
        limit = int(n_results)
    except (TypeError, ValueError):
        return "Error: n_results must be a positive integer."
    if limit < 1:
        return "Error: n_results must be a positive integer."
    try:
        col = _get_store().get_collection(name=collection)
        kw = {"query_texts": [query], "n_results": limit}
        if filter:
            kw["where"] = filter
        r = col.query(**kw)
        if not r["ids"] or not r["ids"][0]:
            return "No results found."
        # Results are nested per query text; only one query is sent, so take [0].
        ids = r["ids"][0]
        docs = r["documents"][0]
        metas = r["metadatas"][0] if r.get("metadatas") else [None] * len(ids)
        dists = r["distances"][0] if r.get("distances") else [None] * len(ids)
        out = []
        for did, txt, meta, dist in zip(ids, docs, metas, dists):
            txt = txt or ""
            if len(txt) > 500:
                txt = txt[:500] + "..."
            # A document stored without metadata renders as {} rather than null.
            meta_str = json.dumps(meta or {}, ensure_ascii=False)
            dist_str = f" (score: {dist:.4f})" if dist is not None else ""
            out.append(f"[{did}]{dist_str}\n text: {txt}\n metadata: {meta_str}")
        return "\n---\n".join(out)
    except Exception as e:
        return f"Error: {e}"
def create_collection(name, description=""):
    """Create a collection (or reuse an existing one) and record its description.

    The description is only written when one is supplied, so calling this
    again on an existing collection without a description no longer blanks
    the one already stored.

    Args:
        name: Collection name; must be a non-empty string.
        description: Optional text stored under the ``"description"``
            metadata key.

    Returns:
        ``"Collection '<name>' is ready."`` on success, otherwise an
        ``"Error: ..."`` string.
    """
    if not isinstance(name, str) or not name.strip():
        return "Error: collection name must be a non-empty string."
    try:
        col = _get_store().get_or_create_collection(name=name)
        if description:
            col.modify(metadata={"description": description})
        return f"Collection '{name}' is ready."
    except Exception as e:
        return f"Error: {e}"
def delete_collection(name):
    """Remove the collection called *name* from the store.

    Returns a confirmation string, or ``"Error: ..."`` if the store raised.
    """
    try:
        store = _get_store()
        store.delete_collection(name)
    except Exception as exc:
        return f"Error: {exc}"
    return f"Deleted collection '{name}'."
def list_collections():
    """List every collection with its description (if any) and document count.

    Returns:
        A multi-line string with one ``- name (description) [N docs]`` line
        per collection, ``"No collections exist yet."`` when there are none,
        or ``"Error: ..."`` if the store raised.
    """
    try:
        store = _get_store()
        cols = store.list_collections()
        if not cols:
            return "No collections exist yet."
        out = ["Available collections:"]
        for col in cols:
            # NOTE(review): some ChromaDB releases (0.6.x, per its migration
            # notes) return bare names here instead of Collection objects;
            # resolve those so .metadata / .count() work either way.
            if isinstance(col, str):
                col = store.get_collection(name=col)
            meta = col.metadata or {}
            desc = meta.get("description", "")
            cnt = col.count()
            tag = f" ({desc})" if desc else ""
            out.append(f"- {col.name}{tag} [{cnt} docs]")
        return "\n".join(out)
    except Exception as e:
        return f"Error: {e}"
def inspect_collection(collection, sample_size=3):
    """Show a few sample documents and the metadata keys they carry.

    The collection is looked up with ``get_collection`` rather than
    get-or-create, so inspecting a mistyped name reports an error instead of
    creating an empty collection.

    Args:
        collection: Name of the collection to inspect.
        sample_size: Number of documents to sample; must be a positive
            integer and is capped at the collection's document count.

    Returns:
        A text report with the total count, each sampled document (text cut
        to 200 characters) and the sorted union of metadata keys seen in the
        sample. Keys used only by documents outside the sample are not
        listed. Returns ``"Error: ..."`` on failure.
    """
    try:
        wanted = int(sample_size)
    except (TypeError, ValueError):
        return "Error: sample_size must be a positive integer."
    if wanted < 1:
        return "Error: sample_size must be a positive integer."
    try:
        col = _get_store().get_collection(name=collection)
        cnt = col.count()
        if cnt == 0:
            return f"Collection '{collection}' is empty."
        n = min(wanted, cnt)
        r = col.get(limit=n, include=["documents", "metadatas"])
        ids = r["ids"]
        # Guard both optional fields the same way so zip() never sees None.
        docs = r.get("documents") or [""] * len(ids)
        metas = r.get("metadatas") or [None] * len(ids)
        out = [f"Collection: {collection} | Total documents: {cnt}", f"Sample ({n}):"]
        keys = set()
        for did, txt, meta in zip(ids, docs, metas):
            txt = txt or ""
            if len(txt) > 200:
                txt = txt[:200] + "..."
            meta_str = json.dumps(meta, ensure_ascii=False) if meta else "(none)"
            out.append(f"\n [{did}] text: {txt} metadata: {meta_str}")
            if meta:
                keys.update(meta)
        if keys:
            out.append(f"\nMetadata keys: {', '.join(sorted(keys))}")
        return "\n".join(out)
    except Exception as e:
        return f"Error: {e}"
def _chunk_lines(lines, chunk_size, chunk_overlap):
    """Yield ``(start_index, chunk)`` windows of *chunk_size* lines over *lines*.

    Consecutive windows share *chunk_overlap* lines. If the overlap leaves no
    forward progress the window advances one line at a time. Iteration stops
    as soon as a window reaches the last line, so no trailing window is
    emitted that is fully contained in the previous one.
    """
    step = chunk_size - max(chunk_overlap, 0)
    if step <= 0:
        step = 1
    total = len(lines)
    start = 0
    while start < total:
        end = min(start + chunk_size, total)
        yield start, lines[start:end]
        if end >= total:
            break
        start += step


def _read_workbook_sheets(fpath):
    """Return ``[(sheet_name, rows)]`` for a workbook, one tab-joined string per row.

    openpyxl is imported lazily so the module still loads without it. The
    workbook is closed even if reading a sheet fails.
    """
    import openpyxl

    wb = openpyxl.load_workbook(fpath, read_only=True, data_only=True)
    try:
        sheets = []
        for sheet_name in wb.sheetnames:
            rows = [
                "\t".join("" if cell is None else str(cell) for cell in row)
                for row in wb[sheet_name].iter_rows(values_only=True)
            ]
            sheets.append((sheet_name, rows))
        return sheets
    finally:
        wb.close()


def ingest_files(collection, paths, chunk_size=0, chunk_overlap=0, recursive=True):
    """Read files (or glob patterns) and store their content in a collection.

    ``.xlsx``/``.xlsm`` workbooks are read sheet by sheet through openpyxl;
    every other file is read as UTF-8 text with undecodable bytes replaced.
    Each document carries ``filename``, ``path``, ``extension``, ``size`` and
    ``mtime`` metadata, plus ``sheet`` for workbook sheets and
    ``chunk_index``/``chunk_lines``/``chunk_start_line`` when chunking.

    Document IDs are ``<stem>``, ``<stem>_<sheet>``, or those with a
    ``_chunk_<n>`` suffix. When a second file in the same call has a stem
    that is already taken, its relative path is used in place of the stem so
    the batch does not contain duplicate IDs.

    Args:
        collection: Target collection name. It is opened (and created if
            missing) only once at least one path has matched.
        paths: List of file paths or glob patterns; a single string is
            treated as a one-item list.
        chunk_size: Lines per chunk; ``0`` stores each file/sheet whole.
        chunk_overlap: Lines shared between consecutive chunks.
        recursive: Passed to ``glob.glob`` to enable ``**`` patterns.

    Returns:
        A summary string with counts of ingested files, stored documents and
        skipped files/patterns, ``"No matching files found."`` when nothing
        matched, or ``"Error: ..."`` on failure.
    """
    try:
        if isinstance(paths, str):
            # A bare string would otherwise be iterated character by character.
            paths = [paths]
        chunk_size = int(chunk_size or 0)
        chunk_overlap = int(chunk_overlap or 0)

        skipped = 0
        files = set()
        for pattern in paths:
            matches = globmod.glob(pattern, recursive=recursive)
            if not matches and os.path.isfile(pattern):
                # Literal path containing glob metacharacters (e.g. "[1].txt").
                matches = [pattern]
            if not matches:
                skipped += 1
                continue
            # Absolute paths collapse different spellings of the same file
            # ("a.txt" vs "./a.txt"), which would otherwise yield duplicate IDs.
            files.update(os.path.abspath(m) for m in matches)
        if not files:
            return "No matching files found."

        col = _collection(collection)
        all_ids, all_texts, all_metas = [], [], []
        used_bases = set()
        processed = 0

        for fpath in sorted(files):
            if not os.path.isfile(fpath):
                skipped += 1
                continue
            try:
                stat = os.stat(fpath)
            except OSError:
                skipped += 1
                continue
            try:
                rel_path = os.path.relpath(fpath)
            except ValueError:
                # relpath can fail on Windows when the file is on another drive.
                rel_path = fpath
            filename = os.path.basename(fpath)
            stem, ext = os.path.splitext(filename)
            ext = ext.lower()
            base_meta = {
                "filename": filename,
                "path": rel_path,
                "extension": ext,
                "size": stat.st_size,
                "mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(stat.st_mtime)),
            }

            # ── read content ──────────────────────────────────────────
            # Each section is (sheet_name or None, lines, joiner).
            if ext in (".xlsx", ".xlsm"):
                try:
                    sections = [
                        (sheet, rows, "\n")
                        for sheet, rows in _read_workbook_sheets(fpath)
                    ]
                except Exception:
                    # openpyxl missing or workbook unreadable: skip this file
                    # only, instead of aborting the whole ingest.
                    skipped += 1
                    continue
            else:
                try:
                    with open(fpath, "r", encoding="utf-8", errors="replace") as f:
                        sections = [(None, f.readlines(), "")]
                except OSError:
                    skipped += 1
                    continue

            # Keep the plain stem when it is free; fall back to the relative
            # path so files such as a/README.md and b/README.md do not clash.
            base_name = stem if stem not in used_bases else rel_path
            used_bases.update((stem, base_name))

            docs_before = len(all_ids)
            for sheet_name, lines, joiner in sections:
                if not joiner.join(lines).strip():
                    continue
                section_meta = dict(base_meta)
                id_base = base_name
                if sheet_name is not None:
                    section_meta["sheet"] = sheet_name
                    id_base = f"{base_name}_{sheet_name}"
                if chunk_size > 0:
                    windows = _chunk_lines(lines, chunk_size, chunk_overlap)
                    for cid, (start, chunk) in enumerate(windows):
                        meta = dict(section_meta)
                        meta["chunk_index"] = cid
                        meta["chunk_lines"] = len(chunk)
                        meta["chunk_start_line"] = start + 1
                        all_ids.append(f"{id_base}_chunk_{cid}")
                        all_texts.append(joiner.join(chunk))
                        all_metas.append(_sanitize_meta(meta))
                else:
                    all_ids.append(id_base)
                    all_texts.append(joiner.join(lines))
                    all_metas.append(_sanitize_meta(section_meta))

            # Count per file (not per sheet) so the summary matches its wording.
            if len(all_ids) > docs_before:
                processed += 1
            else:
                skipped += 1

        # NOTE(review): ChromaDB enforces a maximum batch size per add() call
        # (the value depends on the build); writing in slices keeps large
        # ingests under it — confirm the limit for the pinned version.
        # NOTE(review): add() appears to ignore IDs that already exist, so
        # re-ingesting a changed file may not refresh it — confirm whether
        # upsert is wanted.
        batch = 1000
        for lo in range(0, len(all_ids), batch):
            hi = lo + batch
            col.add(
                ids=all_ids[lo:hi],
                documents=all_texts[lo:hi],
                metadatas=all_metas[lo:hi],
            )

        parts = [f"Ingested {processed} file(s) into '{collection}'"]
        if processed > 0:
            parts.append(f"({len(all_ids)} document(s) total)")
        if skipped > 0:
            parts.append(f"({skipped} file(s) skipped)")
        return " ".join(parts)
    except Exception as e:
        return f"Error: {e}"