"""
Knowledge base command handler - RAG-powered Q&A from book embeddings.

Supports multiple books with user selection.
"""
|
|
|
|
import os
|
|
import json
|
|
import glob
|
|
import numpy as np
|
|
from typing import List, Tuple, Optional, Dict
|
|
from pathlib import Path
|
|
|
|
from bot.command_registry import register_module
|
|
import ai.parser as ai_parser
|
|
from ai.parser import client
|
|
|
|
# Configuration (env var KNOWLEDGE_EMBEDDINGS_DIR overrides the default data dir)
EPUBS_DIRECTORY = os.getenv("KNOWLEDGE_EMBEDDINGS_DIR", "./bot/data")
# Number of top-scoring chunks fed to the chat model as context
TOP_K_CHUNKS = 5
# OpenRouter model id used for answer generation
CHAT_MODEL = "deepseek/deepseek-v3.2"
# Every knowledge-base file is named <book>.embeddings.json
EMBEDDING_EXTENSION = ".embeddings.json"

# Map embedding dimensions to the model that produced them, so a query is
# embedded with the same model that embedded the book's chunks.
EMBEDDING_MODELS_BY_DIM = {
    384: "sentence-transformers/all-minilm-l12-v2",
    4096: "qwen/qwen3-embedding-8b",
}
# Fallback when the stored vectors' dimension is not in the map above
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-minilm-l12-v2"

# Cache for loaded embeddings: {file_path: (chunks, embeddings, metadata)}
_knowledge_cache: Dict[str, Tuple[List[str], List[List[float]], dict]] = {}
|
|
|
|
|
|
def find_embedding_files() -> List[str]:
    """Return the sorted paths of every embeddings file in the data directory.

    Creates the directory on first use so a fresh deployment does not crash.
    """
    os.makedirs(EPUBS_DIRECTORY, exist_ok=True)
    search_pattern = os.path.join(EPUBS_DIRECTORY, f"*{EMBEDDING_EXTENSION}")
    matches = glob.glob(search_pattern)
    matches.sort()
    return matches
|
|
|
|
|
|
def get_book_name(file_path: str) -> str:
    """Extract a readable book name from an embeddings file path.

    Bug fix: the original used ``Path(file_path).stem``, which strips only the
    final ``.json`` — the subsequent ``.replace(EMBEDDING_EXTENSION, "")`` then
    never matched, so titles came out as e.g. "Foo Embeddings" instead of
    "Foo".  Strip the full compound extension from the filename instead.
    """
    name = Path(file_path).name
    if name.endswith(EMBEDDING_EXTENSION):
        name = name[: -len(EMBEDDING_EXTENSION)]
    else:
        # Unexpected extension: fall back to the plain stem.
        name = Path(file_path).stem
    return name.replace(".", " ").title()
|
|
|
|
|
|
def load_knowledge_base(
    file_path: str,
) -> Optional[Tuple[List[str], List[List[float]], dict]]:
    """Load and cache a specific embeddings file.

    Returns a ``(chunks, embeddings, metadata)`` tuple, or ``None`` when the
    file is missing, unreadable, or in an unrecognized format.  Results are
    cached in ``_knowledge_cache`` keyed by path, so each file is parsed once.

    Fixes over the original:
      * ``open()`` now pins ``encoding="utf-8"`` instead of relying on the
        platform default.
      * A corrupt or unreadable file no longer raises out of this function
        (callers only handle ``None``); ``OSError``/``JSONDecodeError`` are
        caught and reported as a failed load.
    """
    if file_path in _knowledge_cache:
        return _knowledge_cache[file_path]

    if not os.path.exists(file_path):
        return None

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError):
        # Treat unreadable/corrupt files as "not loadable" rather than crashing.
        return None

    # Handle both dict format {"chunks": [...], "embeddings": [...], "metadata": {...}}
    # and legacy list format where data is just the chunks
    if isinstance(data, dict):
        chunks = data.get("chunks", [])
        embeddings = data.get("embeddings", [])
        metadata = data.get("metadata", {})
    elif isinstance(data, list):
        # Legacy format: list of {"text": ..., "embedding": ...} records
        if data and isinstance(data[0], dict) and "text" in data[0]:
            # Format: [{"text": "...", "embedding": [...]}, ...]
            chunks = [item.get("text", "") for item in data]
            embeddings = [item.get("embedding", []) for item in data]
            metadata = {"format": "legacy_list_of_dicts"}
        else:
            # Unknown list format - can't process
            return None
    else:
        return None

    # Add file_path to metadata for reference
    metadata["_file_path"] = file_path

    _knowledge_cache[file_path] = (chunks, embeddings, metadata)
    return _knowledge_cache[file_path]
|
|
|
|
|
|
def get_embedding_model_for_dim(dim: int) -> str:
    """Resolve which embedding model matches vectors of dimension *dim*.

    Falls back to the default model for unknown dimensions.
    """
    try:
        return EMBEDDING_MODELS_BY_DIM[dim]
    except KeyError:
        return DEFAULT_EMBEDDING_MODEL
|
|
|
|
|
|
def get_query_embedding(query: str, model: str = DEFAULT_EMBEDDING_MODEL) -> List[float]:
    """Embed the user's question via the shared OpenRouter client."""
    result = client.embeddings.create(input=query, model=model)
    # The API returns one embedding per input; we always send exactly one.
    return result.data[0].embedding
|
|
|
|
|
|
def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
    """Return the cosine similarity of two vectors.

    Fix over the original: a zero-magnitude vector made the denominator 0,
    producing NaN plus a NumPy RuntimeWarning; such pairs now score 0.0.
    Also avoids rebinding the List-typed parameters to arrays.
    """
    a = np.asarray(vec1, dtype=float)
    b = np.asarray(vec2, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)
|
|
|
|
|
|
def search_context(
    query_embedding: List[float],
    chunks: List[str],
    embeddings: List[List[float]],
    top_k: int = 5,
) -> Tuple[List[str], List[float]]:
    """Rank every chunk by cosine similarity to the query.

    Returns the ``top_k`` best chunks together with their similarity scores,
    highest first.
    """
    query_vec = np.array(query_embedding)
    query_norm = np.linalg.norm(query_vec)

    ranked = []
    for idx, emb in enumerate(embeddings):
        chunk_vec = np.array(emb)
        # Cosine similarity: dot product over the product of magnitudes.
        similarity = np.dot(query_vec, chunk_vec) / (
            query_norm * np.linalg.norm(chunk_vec)
        )
        ranked.append((similarity, idx))

    ranked.sort(key=lambda pair: pair[0], reverse=True)
    best = ranked[:top_k]

    return [chunks[idx] for _, idx in best], [sim for sim, _ in best]
|
|
|
|
|
|
def generate_answer(query: str, context_chunks: List[str], book_title: str) -> str:
    """Answer *query* from the given book context via DeepSeek on OpenRouter.

    Returns the model's answer, or an error string if the API call fails.
    """
    context_text = "\n\n---\n\n".join(context_chunks)

    system_prompt = f"""You are an expert assistant answering questions about "{book_title}".
Answer based strictly on the provided context. If the answer isn't in the context, say you don't know.
Do not make up information. Provide clear, helpful answers based on the book's content.

Context from {book_title}:
{context_text}"""

    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]

    try:
        completion = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=chat_messages,
            temperature=0.1,
        )
    except Exception as err:
        return f"❌ Error generating answer: {err}"
    return completion.choices[0].message.content
|
|
|
|
|
|
def get_user_selected_book(session) -> Optional[str]:
    """Return the path of the user's currently selected book, or None."""
    kb_state = session.get("knowledge_base")
    if not kb_state:
        return None
    return kb_state.get("selected_book")
|
|
|
|
|
|
def set_user_selected_book(session, file_path: str):
    """Record *file_path* as the user's active knowledge base in the session."""
    session.setdefault("knowledge_base", {})["selected_book"] = file_path
|
|
|
|
|
|
def _resolve_book_file(identifier: str, files: List[str]) -> Optional[str]:
    """Resolve a user-supplied book identifier to an embeddings file path.

    The identifier is first tried as a 1-based index into *files*; failing
    that, it is matched case-insensitively as a substring of each book's
    display name or file path.  Returns None when nothing matches.

    Bug fix: the original inlined this logic with name matching inside the
    ``except (ValueError, TypeError)`` arm, so an identifier that parsed as a
    number but was out of range (e.g. a book titled "1984") never fell back
    to name matching.  Here an out-of-range number falls through to the
    substring search.
    """
    try:
        index = int(identifier) - 1
    except (ValueError, TypeError):
        index = None
    if index is not None and 0 <= index < len(files):
        return files[index]

    needle = identifier.lower()
    for f in files:
        if needle in get_book_name(f).lower() or needle in f.lower():
            return f
    return None


async def handle_knowledge(message, session, parsed):
    """Handle knowledge base queries with dynamic book selection.

    Dispatches on ``parsed["action"]``:
      * ``list``   - show the available knowledge bases
      * ``select`` - remember the user's chosen book in the session
      * ``query``  - RAG-answer a question against the selected book
      * ``dbt_evaluate_advice`` - LLM-only DBT evaluation of advice text
    """
    action = parsed.get("action", "query")

    if action == "list":
        embedding_files = find_embedding_files()

        if not embedding_files:
            await message.channel.send(
                f"❌ No knowledge bases found in `{EPUBS_DIRECTORY}`"
            )
            return

        lines = [f"{i + 1}. {get_book_name(f)}" for i, f in enumerate(embedding_files)]
        current = get_user_selected_book(session)
        current_text = (
            f"\n\n📖 Currently selected: **{get_book_name(current)}**"
            if current
            else ""
        )

        await message.channel.send(
            f"📚 **Available Knowledge Bases:**\n"
            + "\n".join(lines)
            + current_text
            + "\n\nUse `ask <book number/name> <question>` or `select book <number/name>`"
        )

    elif action == "select":
        book_identifier = parsed.get("book", "")
        embedding_files = find_embedding_files()

        if not embedding_files:
            await message.channel.send("❌ No knowledge bases available.")
            return

        selected_file = _resolve_book_file(book_identifier, embedding_files)

        if not selected_file:
            await message.channel.send(
                f"❌ Could not find book '{book_identifier}'. Use `list books` to see available options."
            )
            return

        set_user_selected_book(session, selected_file)
        book_name = get_book_name(selected_file)
        await message.channel.send(f"✅ Selected knowledge base: **{book_name}**")

    elif action == "query":
        query = parsed.get("query", "")
        book_override = parsed.get("book", "")

        if not query:
            await message.channel.send(
                "What would you like to know? (e.g., 'what does the book say about time management?')"
            )
            return

        # Determine which book to use.
        if book_override:
            # User named a specific book in the query.
            selected_file = _resolve_book_file(book_override, find_embedding_files())
        else:
            # Fall back to the session's selection, then the first available book.
            selected_file = get_user_selected_book(session)
            if not selected_file:
                embedding_files = find_embedding_files()
                if embedding_files:
                    selected_file = embedding_files[0]
                    set_user_selected_book(session, selected_file)

        if not selected_file:
            await message.channel.send(
                "❌ No knowledge base available. Please check the embeddings directory."
            )
            return

        # Load knowledge base (cached after the first load).
        kb_data = load_knowledge_base(selected_file)
        if kb_data is None:
            await message.channel.send(
                "❌ Error loading knowledge base. Please check the file path."
            )
            return

        chunks, embeddings, metadata = kb_data
        book_title = metadata.get("title", get_book_name(selected_file))

        await message.channel.send(f"🔍 Searching **{book_title}**...")

        try:
            # Embed the query with the model that matches the stored
            # vectors' dimension, otherwise similarities are meaningless.
            emb_dim = len(embeddings[0]) if embeddings else 384
            embedding_model = get_embedding_model_for_dim(emb_dim)

            # Get query embedding and retrieve the most relevant chunks.
            query_emb = get_query_embedding(query, model=embedding_model)
            relevant_chunks, scores = search_context(
                query_emb, chunks, embeddings, TOP_K_CHUNKS
            )

            # Generate and send the answer.
            answer = generate_answer(query, relevant_chunks, book_title)
            await message.channel.send(f"📚 **Answer:**\n{answer}")

        except Exception as e:
            await message.channel.send(f"❌ Error processing query: {e}")

    elif action == "dbt_evaluate_advice":
        advice = parsed.get("advice", "")
        if not advice:
            await message.channel.send("Please provide the advice you want to evaluate.")
            return

        await message.channel.send("Processing your request for DBT advice evaluation. This may take a minute...")

        system_prompt = """You are an expert in Dialectical Behavior Therapy (DBT).
Your task is to evaluate the provided advice against DBT principles.
Focus on whether the advice aligns with DBT skills, such as mindfulness, distress tolerance, emotion regulation, and interpersonal effectiveness.
Provide a clear "cleared" or "not cleared" judgment, followed by a brief explanation of why, referencing specific DBT principles where applicable.

Example of good advice evaluation:
CLEARED: This advice encourages mindfulness by suggesting to observe thoughts without judgment, which is a core DBT skill.

Example of bad advice evaluation:
NOT CLEARED: This advice promotes suppressing emotions, which is contrary to DBT's emphasis on emotion regulation through healthy expression and understanding.

Evaluate the following advice:
"""
        try:
            response = client.chat.completions.create(
                model=CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": advice},
                ],
                temperature=0.2,  # Slightly higher temperature for more varied explanations, but still grounded
            )
            evaluation = response.choices[0].message.content
            await message.channel.send(f"**DBT Advice Jury Says:**\n{evaluation}")
        except Exception as e:
            await message.channel.send(f"❌ Error evaluating advice: {e}")

    else:
        await message.channel.send(
            f"Unknown knowledge action: {action}. Try: list, select, query, or dbt_evaluate_advice."
        )
|
|
|
|
|
|
def validate_knowledge_json(data):
    """Validate parsed JSON for knowledge queries.

    Returns a list of human-readable problems; an empty list means valid.
    A payload carrying an "error" key is passed through unvalidated.
    """
    if not isinstance(data, dict):
        return ["Response must be a JSON object"]

    # Error payloads are accepted as-is so the parser can surface them.
    if "error" in data:
        return []

    problems = []
    if "action" not in data:
        problems.append("Missing required field: action")

    if data.get("action") == "dbt_evaluate_advice" and "advice" not in data:
        problems.append("Missing required field for dbt_evaluate_advice: advice")

    return problems
|
|
|
|
|
|
# Register the command handler with the bot's command dispatch table.
register_module("knowledge", handle_knowledge)

# Register the validator so the AI parser can vet parsed "knowledge" payloads.
ai_parser.register_validator("knowledge", validate_knowledge_json)
|