# Source listing: 539 lines, 19 KiB, Python
import glob as globmod
|
|
import json
|
|
import os
|
|
import time
|
|
import pandas as pd
|
|
import lancedb
|
|
from lancedb.pydantic import LanceModel
|
|
from sentence_transformers import SentenceTransformer
|
|
import config
|
|
|
|
# ── Embedding Setup ───────────────────────────────────────────────────────
|
|
|
|
def load_embedding_model():
    """Load the SentenceTransformer embedding model according to ``config``.

    Resolution order, driven by ``config.RAG_MODEL_PATH``:

    1. Empty or unset path -> load ``all-MiniLM-L6-v2`` through the library's
       default cache (``~/.cache/...``).
    2. Path set but missing, or an empty directory (e.g. left behind by an
       interrupted save) -> download the model, then save it to that path.
    3. Path set and populated -> load the model directly from that path.

    Returns:
        The loaded model, or ``None`` if anything failed (the error is
        printed; ``get_embedding`` raises later when it is actually used).
    """
    model_name = "all-MiniLM-L6-v2"
    # Tolerate an unset (None) config value instead of crashing at import time.
    custom_path = str(config.RAG_MODEL_PATH or "").strip()

    try:
        if not custom_path:
            # Case 1: default cache.
            print(f"[RAG] Loading embedding model '{model_name}' from default cache...")
            return SentenceTransformer(model_name)

        # Case 3: a populated directory can be loaded directly. An empty
        # directory cannot, so it falls through to the download branch.
        if os.path.isdir(custom_path) and os.listdir(custom_path):
            print(f"[RAG] Loading embedding model from custom path: {custom_path}")
            return SentenceTransformer(custom_path)

        # Case 2: download once, then persist to the custom path.
        print(f"[RAG] Custom path {custom_path} not found. Downloading model first...")
        model = SentenceTransformer(model_name)
        os.makedirs(custom_path, exist_ok=True)
        model.save(custom_path)
        print(f"[RAG] Model successfully downloaded and saved to: {custom_path}")
        return model

    except Exception as e:
        # Deliberate best-effort: the module stays importable without a model.
        print(f"[RAG] Critical Error loading embedding model: {e}")
        return None
|
|
|
|
# Initialize the model at import time; get_embedding() reads this module-level
# value, which is None when load_embedding_model() failed.
embedding_model = load_embedding_model()
|
|
|
|
def get_embedding(text):
    """Return the embedding of *text* as a plain Python list.

    Raises:
        RuntimeError: if the module-level ``embedding_model`` failed to load.
            (RuntimeError subclasses Exception, so existing broad handlers
            still catch it.)
    """
    if embedding_model is None:
        raise RuntimeError("Embedding model not loaded. Check your config or internet connection.")
    # encode() returns a NumPy array; tolist() makes it JSON/LanceDB friendly.
    return embedding_model.encode(text).tolist()
|
|
|
|
# Minimal flat schema (metadata kept as a JSON string) to avoid Pydantic conflicts.
class DocumentSchema(LanceModel):
    """Row schema used by _get_table when it creates a new collection table."""

    text: str  # document body; the handlers embed this text via get_embedding()
    id: str  # caller-chosen document ID; nothing in this module enforces uniqueness
    metadata: str  # json.dumps() of the metadata dict supplied by the caller
    # NOTE(review): a variable-length list[float] may not be auto-detected as the
    # vector column by LanceDB search (which looks for fixed-size lists);
    # lancedb.pydantic.Vector(dim) is the usual field type. Confirm against the
    # installed lancedb version before changing — existing tables keep their schema.
    vector: list[float]  # embedding produced by get_embedding()
|
|
|
|
# ── LanceDB singleton ───────────────────────────────────────────────────────
|
|
|
|
_db = None  # cached LanceDB connection, opened lazily by _get_db()


def _get_db():
    """Return the shared LanceDB connection, connecting on the first call."""
    global _db
    if _db is not None:
        return _db
    _db = lancedb.connect(config.RAG_PERSIST_DIR)
    return _db
|
|
|
|
def _get_table(name):
    """Open table *name*, creating it with ``DocumentSchema`` if it does not exist."""
    connection = _get_db()
    if name not in connection.table_names():
        return connection.create_table(name, schema=DocumentSchema)
    return connection.open_table(name)
|
|
|
|
# ── Tool schemas ─────────────────────────────────────────────────────
|
|
|
|
# Tool schema for store_knowledge(). The handler creates missing collections
# on the fly (via _get_table), so the collection need not pre-exist in config.
schema_store_knowledge = {
    "type": "function",
    "function": {
        "name": "store_knowledge",
        "description": (
            "Store one or more documents with arbitrary metadata into a RAG collection. "
            "Metadata is a free-form dict — choose meaningful keys for future filtering "
            "(e.g., restaurant, category, allergens, spice_level, taste_profile, price, "
            "customer_id, dietary)."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "collection": {
                    "type": "string",
                    "description": "Target collection name (created automatically if it doesn't exist)",
                },
                "documents": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string", "description": "Unique document ID"},
                            "text": {"type": "string", "description": "Document body text"},
                            "metadata": {
                                "type": "object",
                                "description": "Arbitrary key-value metadata",
                                "default": {},
                            },
                        },
                        "required": ["id", "text"],
                    },
                    "description": "List of documents to persist",
                },
            },
            "required": ["collection", "documents"],
        },
    },
}
|
|
|
|
# Tool schema for search_knowledge(): semantic search with an optional
# SQL-style predicate over the JSON-encoded metadata column.
schema_search_knowledge = {
    "type": "function",
    "function": {
        "name": "search_knowledge",
        "description": (
            "Semantically search a RAG collection. "
            "Optionally narrow with a metadata filter using SQL-like syntax. "
            "Example: \"metadata LIKE '%main_course%'\""
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "collection": {"type": "string", "description": "Collection name to search in"},
                "query": {"type": "string", "description": "Natural-language search query"},
                "n_results": {
                    "type": "integer",
                    "description": "Max results to return (default 5)",
                    "default": 5,
                },
                "filter": {
                    "type": "string",
                    "description": "Optional SQL-like filter for metadata JSON string",
                    "default": None,
                },
            },
            "required": ["collection", "query"],
        },
    },
}
|
|
|
|
# Tool schema for create_collection().
schema_create_collection = {
    "type": "function",
    "function": {
        "name": "create_collection",
        "description": (
            "Create a new RAG collection for a new topic/domain. "
            "Use a short, descriptive name with underscores "
            "(e.g., 'tanaman_hias', 'customer_profiles'). "
            "Optionally provide a description."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "Collection name (lowercase, underscores for spaces)",
                },
                "description": {
                    "type": "string",
                    "description": "What this collection stores",
                    "default": "",
                },
            },
            "required": ["name"],
        },
    },
}
|
|
|
|
# Tool schema for delete_collection() — destructive, so the description warns.
schema_delete_collection = {
    "type": "function",
    "function": {
        "name": "delete_collection",
        "description": (
            "Permanently delete an entire RAG collection and all documents in it. "
            "Use with caution — this cannot be undone."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Collection name to delete"},
            },
            "required": ["name"],
        },
    },
}
|
|
|
|
# Tool schema for list_collections(). The handler reports names and row counts
# only (descriptions passed to create_collection are not stored), so the tool
# description promises exactly that.
schema_list_collections = {
    "type": "function",
    "function": {
        "name": "list_collections",
        "description": "List all existing RAG collections with their document count.",
        "parameters": {"type": "object", "properties": {}},
    },
}
|
|
|
|
# Tool schema for inspect_collection(); nudges the model to discover metadata
# keys before building search filters.
schema_inspect_collection = {
    "type": "function",
    "function": {
        "name": "inspect_collection",
        "description": (
            "Examine sample documents and metadata fields in a RAG collection. "
            "Always call this before search_knowledge to learn what metadata keys are available "
            "for filtering, then pass them in the filter parameter."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "collection": {"type": "string", "description": "Collection name to inspect"},
                "sample_size": {
                    "type": "integer",
                    "description": "Number of sample documents (default 3)",
                    "default": 3,
                },
            },
            "required": ["collection"],
        },
    },
}
|
|
|
|
# Tool schema for ingest_files(): bulk-load files (glob patterns allowed),
# optionally split into overlapping line-based chunks.
schema_ingest_files = {
    "type": "function",
    "function": {
        "name": "ingest_files",
        "description": (
            "Read one or more files (supports glob patterns like *.py or src/**/*.md) and store "
            "their content into a RAG collection. Optionally chunk files into smaller pieces by "
            "line count. Automatically extracts metadata: filename, path, extension, size, "
            "modification time."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "collection": {
                    "type": "string",
                    "description": "Target collection name (will be created if it doesn't exist)",
                },
                "paths": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "File paths or glob patterns (e.g., ['*.txt', 'src/**/*.py'])",
                },
                "chunk_size": {
                    "type": "integer",
                    "description": "Lines per chunk (0 = whole file as one document)",
                    "default": 0,
                },
                "chunk_overlap": {
                    "type": "integer",
                    "description": "Line overlap between chunks (only used when chunk_size > 0)",
                    "default": 0,
                },
                "recursive": {
                    "type": "boolean",
                    "description": "Search directories recursively when using glob patterns",
                    "default": True,
                },
            },
            "required": ["collection", "paths"],
        },
    },
}
|
|
|
|
# ── Tool handlers ────────────────────────────────────────────────────
|
|
|
|
def store_knowledge(collection, documents):
    """Embed *documents* and append them to *collection*.

    Args:
        collection: Target collection name; created on first use by _get_table.
        documents: Iterable of dicts with ``id`` and ``text`` keys and an
            optional ``metadata`` dict (persisted as a JSON string). A missing
            or null ``metadata`` is stored as ``{}``.

    Returns:
        A status string. Failures (e.g. a document without ``text``) are
        reported as ``"Error: ..."`` rather than raised.
    """
    if not documents:
        # Nothing to embed: skip creating the table and calling table.add([]).
        return f"Stored 0 document(s) in '{collection}'."
    try:
        table = _get_table(collection)
        rows = []
        for doc in documents:
            meta = doc.get("metadata")
            if meta is None:
                meta = {}
            rows.append({
                "id": doc["id"],
                "text": doc["text"],
                "metadata": json.dumps(meta, ensure_ascii=False),
                "vector": get_embedding(doc["text"]),
            })
        table.add(rows)
        # len(rows) also works when *documents* is a generator.
        return f"Stored {len(rows)} document(s) in '{collection}'."
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def search_knowledge(collection, query, n_results=5, filter=None):
    """Run a semantic search over *collection* and format the hits as text.

    Args:
        collection: Collection name (looked up via _get_table, which creates
            it empty if it does not exist yet).
        query: Natural-language query; embedded with get_embedding().
        n_results: Maximum number of hits; ``None`` (a null tool argument)
            falls back to 5.
        filter: Optional predicate passed verbatim to LanceDB's ``where()``.
            NOTE: this is raw SQL-like text from the tool caller.

    Returns:
        Hits separated by ``---`` (text truncated to 500 chars), the string
        ``"No results found."``, or ``"Error: ..."`` on failure.
    """
    try:
        table = _get_table(collection)
        limit = 5 if n_results is None else int(n_results)

        # Build the query once; attach the filter before limiting.
        search = table.search(get_embedding(query))
        if filter:
            search = search.where(filter)
        df = search.limit(limit).to_pandas()

        if df.empty:
            return "No results found."

        out = []
        for _, row in df.iterrows():
            txt = row["text"]
            if len(txt) > 500:
                txt = txt[:500] + "..."
            out.append(f"[{row['id']}]\n text: {txt}\n metadata: {row['metadata']}")
        return "\n---\n".join(out)
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def create_collection(name, description=""):
    """Make sure the collection ``name`` exists (an empty table is created if needed).

    ``description`` is part of the tool signature, but this handler does not
    store it anywhere.

    Returns a status string; any failure comes back as ``"Error: <message>"``.
    """
    try:
        _get_table(name)
    except Exception as exc:
        return f"Error: {exc}"
    return f"Collection '{name}' is ready."
|
|
|
|
def delete_collection(name):
    """Permanently drop the collection *name* and all of its documents.

    Deletion goes through LanceDB's ``drop_table`` instead of removing
    ``<persist_dir>/<name>`` by hand. LanceDB keeps table data under its own
    on-disk naming, so the hand-built path never matched the table, and a
    tool-supplied name (e.g. containing ``..``) could make ``rmtree`` delete
    files outside the store.

    Returns:
        A status string; a missing collection or any failure is reported as
        ``"Error: ..."`` rather than raised.
    """
    try:
        db = _get_db()
        if name not in db.table_names():
            return f"Error: collection '{name}' does not exist."
        db.drop_table(name)
        return f"Deleted collection '{name}'."
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def list_collections():
    """Return a text listing of every collection and its document count.

    Returns:
        ``"No collections exist yet."``, an ``"Available collections:"``
        listing with one ``- <name> [<n> docs]`` line per table, or
        ``"Error: ..."`` on failure.
    """
    try:
        db = _get_db()
        names = db.table_names()
        if not names:
            return "No collections exist yet."

        out = ["Available collections:"]
        for name in names:
            # count_rows() asks LanceDB for the row count instead of loading
            # the whole table (embeddings included) into pandas.
            cnt = db.open_table(name).count_rows()
            out.append(f"- {name} [{cnt} docs]")
        return "\n".join(out)
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def inspect_collection(collection, sample_size=3):
    """Summarize *collection*: row count, a few sample rows, and metadata keys.

    Args:
        collection: Collection name (looked up via _get_table, which creates
            it empty if it does not exist yet).
        sample_size: Rows to show; ``None`` means 3, negative values mean 0.

    Returns:
        A human-readable report. Metadata keys are collected from the sampled
        rows only. Failures are reported as ``"Error: ..."``.
    """
    try:
        table = _get_table(collection)
        cnt = table.count_rows()
        if cnt == 0:
            return f"Collection '{collection}' is empty."

        if sample_size is None:
            sample_size = 3
        n = max(0, min(int(sample_size), cnt))
        # Fetch only the sampled rows instead of converting the whole table
        # (embeddings included) to pandas.
        sample = table.head(n).to_pandas()

        out = [f"Collection: {collection} | Total documents: {cnt}", f"Sample ({n}):"]
        for _, row in sample.iterrows():
            txt = row["text"]
            if len(txt) > 200:
                txt = txt[:200] + "..."
            meta = row["metadata"]
            out.append(f"\n [{row['id']}] text: {txt} metadata: {meta}")

        keys = set()
        for m_str in sample["metadata"]:
            try:
                parsed = json.loads(m_str)
            except (TypeError, ValueError):
                # Not JSON text (JSONDecodeError is a ValueError); nothing to learn.
                continue
            # Valid JSON that is not an object (e.g. "null") has no keys.
            if isinstance(parsed, dict):
                keys.update(parsed)
        if keys:
            out.append(f"\nMetadata keys: {', '.join(sorted(keys))}")

        return "\n".join(out)
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def _chunk_spans(n_lines, chunk_size, chunk_overlap):
    """Yield ``(start, end)`` slice bounds of line windows over ``n_lines`` lines.

    Each window holds up to ``chunk_size`` lines. Consecutive windows start
    ``chunk_size - chunk_overlap`` lines apart (at least 1; a negative overlap
    counts as 0). Iteration stops once a window reaches the last line, so no
    trailing window is merely a suffix of the previous one.
    """
    step = max(chunk_size - max(chunk_overlap, 0), 1)
    start = 0
    while start < n_lines:
        end = min(start + chunk_size, n_lines)
        yield start, end
        if end == n_lines:
            return
        start += step


def _make_rows(lines, joiner, doc_id, meta, chunk_size, chunk_overlap):
    """Build embedded LanceDB rows for one unit of text (a file or a sheet).

    With ``chunk_size <= 0`` the unit becomes one row with id ``doc_id``.
    Otherwise each line window becomes a row ``<doc_id>_chunk_<i>`` whose
    metadata gains ``chunk_index``, ``chunk_lines`` and ``chunk_start_line``
    (1-based).
    """
    if chunk_size <= 0:
        pieces = [(doc_id, joiner.join(lines), meta)]
    else:
        pieces = []
        spans = _chunk_spans(len(lines), chunk_size, chunk_overlap)
        for cid, (start, end) in enumerate(spans):
            chunk_meta = dict(
                meta,
                chunk_index=cid,
                chunk_lines=end - start,
                chunk_start_line=start + 1,
            )
            pieces.append((f"{doc_id}_chunk_{cid}", joiner.join(lines[start:end]), chunk_meta))
    return [
        {
            "id": piece_id,
            "text": text,
            "metadata": json.dumps(piece_meta, ensure_ascii=False),
            "vector": get_embedding(text),
        }
        for piece_id, text, piece_meta in pieces
    ]


def _read_units(fpath, ext):
    """Read *fpath* into a list of ``(sheet_name_or_None, lines, joiner)`` units.

    Excel workbooks (.xlsx/.xlsm) yield one unit per non-blank sheet: each row
    is a tab-separated line and lines are joined with "\\n". Any other file is
    read as UTF-8 (undecodable bytes replaced) into a single unit whose lines
    keep their own line endings (joiner ""). Blank files and files containing
    NUL characters (almost certainly binary) yield no units.

    Raises ImportError when openpyxl is missing, or whatever the reader raises
    for an unreadable or corrupt file.
    """
    if ext in (".xlsx", ".xlsm"):
        import openpyxl

        wb = openpyxl.load_workbook(fpath, read_only=True, data_only=True)
        try:
            units = []
            for sheet_name in wb.sheetnames:
                rows = [
                    "\t".join("" if cell is None else str(cell) for cell in row)
                    for row in wb[sheet_name].iter_rows(values_only=True)
                ]
                if any(r.strip() for r in rows):
                    units.append((sheet_name, rows, "\n"))
            return units
        finally:
            # Read-only workbooks hold the file open until closed explicitly.
            wb.close()

    with open(fpath, "r", encoding="utf-8", errors="replace") as f:
        lines = f.readlines()
    content = "".join(lines)
    if not content.strip() or "\x00" in content:
        return []
    return [(None, lines, "")]


def ingest_files(collection, paths, chunk_size=0, chunk_overlap=0, recursive=True):
    """Read files (paths or glob patterns) and store their text in *collection*.

    Args:
        collection: Target collection; created by _get_table if missing.
        paths: List of paths / glob patterns (a single string is accepted too).
        chunk_size: Lines per chunk; 0 (or None) stores each file/sheet whole.
        chunk_overlap: Lines shared by consecutive chunks (ignored when not
            chunking; negative or None counts as 0).
        recursive: Passed to ``glob.glob`` so ``**`` matches subdirectories.

    Files that are missing, unreadable, corrupt, blank or binary are counted
    as skipped instead of aborting the batch. Embedding errors still abort the
    batch and are reported, so a missing embedding model is not hidden.

    Returns:
        A summary string such as ``"Ingested 2 file(s) into 'kb' (5 document(s)
        total)"``, ``"No matching files found."``, or ``"Error: ..."``.
    """
    try:
        table = _get_table(collection)
        # Tool callers sometimes pass one pattern as a bare string; iterating it
        # would glob each character separately.
        if isinstance(paths, str):
            paths = [paths]
        chunk_size = chunk_size or 0
        chunk_overlap = chunk_overlap or 0

        all_data = []
        processed, skipped = 0, 0

        file_set = set()
        for p in paths or []:
            expanded = globmod.glob(p, recursive=recursive)
            if expanded:
                file_set.update(expanded)
            elif os.path.isfile(p):
                file_set.add(p)
            else:
                skipped += 1

        if not file_set:
            return "No matching files found."

        for fpath in sorted(file_set):
            if not os.path.isfile(fpath):
                skipped += 1
                continue

            ext = os.path.splitext(fpath)[1].lower()
            try:
                stat = os.stat(fpath)
                units = _read_units(fpath, ext)
            except Exception:
                # Missing openpyxl, unreadable or corrupt file: skip it rather
                # than discard everything gathered so far.
                skipped += 1
                continue
            if not units:
                skipped += 1
                continue

            base_meta = {
                "filename": os.path.basename(fpath),
                "path": os.path.relpath(fpath),
                "extension": ext,
                "size": stat.st_size,
                "mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(stat.st_mtime)),
            }
            # NOTE(review): IDs derive from the bare file name, so same-named
            # files in different directories produce colliding IDs.
            base_name = os.path.splitext(os.path.basename(fpath))[0]

            for sheet_name, lines, joiner in units:
                if sheet_name is None:
                    doc_id, meta = base_name, base_meta
                else:
                    doc_id = f"{base_name}_{sheet_name}"
                    meta = dict(base_meta, sheet=sheet_name)
                all_data.extend(_make_rows(lines, joiner, doc_id, meta, chunk_size, chunk_overlap))
            # Count files (not sheets or chunks), matching the "file(s)" wording.
            processed += 1

        if all_data:
            table.add(all_data)

        parts = [f"Ingested {processed} file(s) into '{collection}'"]
        if processed > 0:
            parts.append(f"({len(all_data)} document(s) total)")
        if skipped > 0:
            parts.append(f"({skipped} file(s) skipped)")
        return " ".join(parts)

    except Exception as e:
        return f"Error: {e}"
|