Add knowledge base RAG module for book Q&A
- Create knowledge.py handler with dynamic book selection - Support list/select/query actions for multiple books - Implement vector search with cosine similarity - Add knowledge detection to AI parser config - Cache embeddings per-book for performance
This commit is contained in:
File diff suppressed because one or more lines are too long
43
bot/bot.py
43
bot/bot.py
@@ -20,8 +20,9 @@ import pickle
|
|||||||
|
|
||||||
from bot.command_registry import get_handler, list_registered
|
from bot.command_registry import get_handler, list_registered
|
||||||
import ai.parser as ai_parser
|
import ai.parser as ai_parser
|
||||||
import bot.commands.routines # noqa: F401 - registers handler
|
import bot.commands.routines # noqa: F401 - registers handler
|
||||||
import bot.commands.medications # noqa: F401 - registers handler
|
import bot.commands.medications # noqa: F401 - registers handler
|
||||||
|
import bot.commands.knowledge # noqa: F401 - registers handler
|
||||||
|
|
||||||
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
|
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
|
||||||
API_URL = os.getenv("API_URL", "http://app:5000")
|
API_URL = os.getenv("API_URL", "http://app:5000")
|
||||||
@@ -217,7 +218,7 @@ async def checkActiveSession(session):
|
|||||||
token = session.get("token")
|
token = session.get("token")
|
||||||
if not token:
|
if not token:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
resp, status = apiRequest("get", "/api/sessions/active", token)
|
resp, status = apiRequest("get", "/api/sessions/active", token)
|
||||||
if status == 200 and "session" in resp:
|
if status == 200 and "session" in resp:
|
||||||
return resp
|
return resp
|
||||||
@@ -228,45 +229,45 @@ async def handleConfirmation(message, session):
|
|||||||
"""Handle yes/no confirmation responses. Returns True if handled."""
|
"""Handle yes/no confirmation responses. Returns True if handled."""
|
||||||
discord_id = message.author.id
|
discord_id = message.author.id
|
||||||
user_input = message.content.lower().strip()
|
user_input = message.content.lower().strip()
|
||||||
|
|
||||||
if "pending_confirmations" not in session:
|
if "pending_confirmations" not in session:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check for any pending confirmations
|
# Check for any pending confirmations
|
||||||
pending = session["pending_confirmations"]
|
pending = session["pending_confirmations"]
|
||||||
if not pending:
|
if not pending:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Get the most recent pending confirmation
|
# Get the most recent pending confirmation
|
||||||
confirmation_id = list(pending.keys())[-1]
|
confirmation_id = list(pending.keys())[-1]
|
||||||
confirmation_data = pending[confirmation_id]
|
confirmation_data = pending[confirmation_id]
|
||||||
|
|
||||||
if user_input in ("yes", "y", "yeah", "sure", "ok", "confirm"):
|
if user_input in ("yes", "y", "yeah", "sure", "ok", "confirm"):
|
||||||
# Execute the confirmed action
|
# Execute the confirmed action
|
||||||
del pending[confirmation_id]
|
del pending[confirmation_id]
|
||||||
|
|
||||||
interaction_type = confirmation_data.get("interaction_type")
|
interaction_type = confirmation_data.get("interaction_type")
|
||||||
handler = get_handler(interaction_type)
|
handler = get_handler(interaction_type)
|
||||||
|
|
||||||
if handler:
|
if handler:
|
||||||
# Create a fake parsed object for the handler
|
# Create a fake parsed object for the handler
|
||||||
fake_parsed = confirmation_data.copy()
|
fake_parsed = confirmation_data.copy()
|
||||||
fake_parsed["needs_confirmation"] = False
|
fake_parsed["needs_confirmation"] = False
|
||||||
await handler(message, session, fake_parsed)
|
await handler(message, session, fake_parsed)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
elif user_input in ("no", "n", "nah", "cancel", "abort"):
|
elif user_input in ("no", "n", "nah", "cancel", "abort"):
|
||||||
del pending[confirmation_id]
|
del pending[confirmation_id]
|
||||||
await message.channel.send("❌ Cancelled.")
|
await message.channel.send("❌ Cancelled.")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def handleActiveSessionShortcuts(message, session, active_session):
|
async def handleActiveSessionShortcuts(message, session, active_session):
|
||||||
"""Handle shortcuts like 'done', 'skip', 'next' when in active session."""
|
"""Handle shortcuts like 'done', 'skip', 'next' when in active session."""
|
||||||
user_input = message.content.lower().strip()
|
user_input = message.content.lower().strip()
|
||||||
|
|
||||||
# Map common shortcuts to actions
|
# Map common shortcuts to actions
|
||||||
shortcuts = {
|
shortcuts = {
|
||||||
"done": ("routine", "complete"),
|
"done": ("routine", "complete"),
|
||||||
@@ -283,7 +284,7 @@ async def handleActiveSessionShortcuts(message, session, active_session):
|
|||||||
"quit": ("routine", "cancel"),
|
"quit": ("routine", "cancel"),
|
||||||
"abort": ("routine", "abort"),
|
"abort": ("routine", "abort"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if user_input in shortcuts:
|
if user_input in shortcuts:
|
||||||
interaction_type, action = shortcuts[user_input]
|
interaction_type, action = shortcuts[user_input]
|
||||||
handler = get_handler(interaction_type)
|
handler = get_handler(interaction_type)
|
||||||
@@ -291,7 +292,7 @@ async def handleActiveSessionShortcuts(message, session, active_session):
|
|||||||
fake_parsed = {"action": action}
|
fake_parsed = {"action": action}
|
||||||
await handler(message, session, fake_parsed)
|
await handler(message, session, fake_parsed)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@@ -306,21 +307,23 @@ async def routeCommand(message):
|
|||||||
|
|
||||||
# Check for active session first
|
# Check for active session first
|
||||||
active_session = await checkActiveSession(session)
|
active_session = await checkActiveSession(session)
|
||||||
|
|
||||||
# Handle confirmation responses
|
# Handle confirmation responses
|
||||||
confirmation_handled = await handleConfirmation(message, session)
|
confirmation_handled = await handleConfirmation(message, session)
|
||||||
if confirmation_handled:
|
if confirmation_handled:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Handle shortcuts when in active session
|
# Handle shortcuts when in active session
|
||||||
if active_session:
|
if active_session:
|
||||||
shortcut_handled = await handleActiveSessionShortcuts(message, session, active_session)
|
shortcut_handled = await handleActiveSessionShortcuts(
|
||||||
|
message, session, active_session
|
||||||
|
)
|
||||||
if shortcut_handled:
|
if shortcut_handled:
|
||||||
return
|
return
|
||||||
|
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
history = message_history.get(discord_id, [])
|
history = message_history.get(discord_id, [])
|
||||||
|
|
||||||
# Add context about active session to help AI understand
|
# Add context about active session to help AI understand
|
||||||
context = ""
|
context = ""
|
||||||
if active_session:
|
if active_session:
|
||||||
@@ -329,8 +332,10 @@ async def routeCommand(message):
|
|||||||
current_step = session_data.get("current_step_index", 0) + 1
|
current_step = session_data.get("current_step_index", 0) + 1
|
||||||
total_steps = active_session.get("total_steps", 0)
|
total_steps = active_session.get("total_steps", 0)
|
||||||
context = f"\n[Context: User is currently in active session for '{routine_name}', on step {current_step} of {total_steps}. They can say 'done', 'skip', 'pause', 'resume', or 'stop'.]"
|
context = f"\n[Context: User is currently in active session for '{routine_name}', on step {current_step} of {total_steps}. They can say 'done', 'skip', 'pause', 'resume', or 'stop'.]"
|
||||||
|
|
||||||
parsed = ai_parser.parse(message.content + context, "command_parser", history=history)
|
parsed = ai_parser.parse(
|
||||||
|
message.content + context, "command_parser", history=history
|
||||||
|
)
|
||||||
|
|
||||||
if discord_id not in message_history:
|
if discord_id not in message_history:
|
||||||
message_history[discord_id] = []
|
message_history[discord_id] = []
|
||||||
|
|||||||
300
bot/commands/knowledge.py
Normal file
300
bot/commands/knowledge.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""
|
||||||
|
Knowledge base command handler - RAG-powered Q&A from book embeddings
|
||||||
|
Supports multiple books with user selection
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Tuple, Optional, Dict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from bot.command_registry import register_module
|
||||||
|
import ai.parser as ai_parser
|
||||||
|
from ai.parser import client
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
EPUBS_DIRECTORY = os.getenv("KNOWLEDGE_EMBEDDINGS_DIR", "../embedding-generator/epubs")
|
||||||
|
TOP_K_CHUNKS = 5
|
||||||
|
EMBEDDING_MODEL = "sentence-transformers/all-minilm-l12-l2"
|
||||||
|
CHAT_MODEL = "deepseek/deepseek-v3.2"
|
||||||
|
EMBEDDING_EXTENSION = ".embeddings.json"
|
||||||
|
|
||||||
|
# Cache for loaded embeddings: {file_path: (chunks, embeddings, metadata)}
|
||||||
|
_knowledge_cache: Dict[str, Tuple[List[str], List[List[float]], dict]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def find_embedding_files() -> List[str]:
|
||||||
|
"""Find all embedding files in the directory."""
|
||||||
|
os.makedirs(EPUBS_DIRECTORY, exist_ok=True)
|
||||||
|
pattern = os.path.join(EPUBS_DIRECTORY, f"*{EMBEDDING_EXTENSION}")
|
||||||
|
files = glob.glob(pattern)
|
||||||
|
return sorted(files)
|
||||||
|
|
||||||
|
|
||||||
|
def get_book_name(file_path: str) -> str:
|
||||||
|
"""Extract readable book name from file path."""
|
||||||
|
return (
|
||||||
|
Path(file_path).stem.replace(EMBEDDING_EXTENSION, "").replace(".", " ").title()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_knowledge_base(
|
||||||
|
file_path: str,
|
||||||
|
) -> Optional[Tuple[List[str], List[List[float]], dict]]:
|
||||||
|
"""Load and cache a specific embeddings file."""
|
||||||
|
if file_path in _knowledge_cache:
|
||||||
|
return _knowledge_cache[file_path]
|
||||||
|
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
chunks = data.get("chunks", [])
|
||||||
|
embeddings = data.get("embeddings", [])
|
||||||
|
metadata = data.get("metadata", {})
|
||||||
|
|
||||||
|
# Add file_path to metadata for reference
|
||||||
|
metadata["_file_path"] = file_path
|
||||||
|
|
||||||
|
_knowledge_cache[file_path] = (chunks, embeddings, metadata)
|
||||||
|
return _knowledge_cache[file_path]
|
||||||
|
|
||||||
|
|
||||||
|
def get_query_embedding(query: str) -> List[float]:
|
||||||
|
"""Embed the user's question via OpenRouter."""
|
||||||
|
response = client.embeddings.create(model=EMBEDDING_MODEL, input=query)
|
||||||
|
return response.data[0].embedding
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
||||||
|
"""Calculate similarity between two vectors."""
|
||||||
|
vec1 = np.array(vec1)
|
||||||
|
vec2 = np.array(vec2)
|
||||||
|
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||||
|
|
||||||
|
|
||||||
|
def search_context(
|
||||||
|
query_embedding: List[float],
|
||||||
|
chunks: List[str],
|
||||||
|
embeddings: List[List[float]],
|
||||||
|
top_k: int = 5,
|
||||||
|
) -> Tuple[List[str], List[float]]:
|
||||||
|
"""Find the most relevant chunks and return them with scores."""
|
||||||
|
scores = []
|
||||||
|
for i, emb in enumerate(embeddings):
|
||||||
|
score = cosine_similarity(query_embedding, emb)
|
||||||
|
scores.append((score, i))
|
||||||
|
|
||||||
|
scores.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
top_chunks = [chunks[i] for score, i in scores[:top_k]]
|
||||||
|
top_scores = [score for score, i in scores[:top_k]]
|
||||||
|
|
||||||
|
return top_chunks, top_scores
|
||||||
|
|
||||||
|
|
||||||
|
def generate_answer(query: str, context_chunks: List[str], book_title: str) -> str:
|
||||||
|
"""Generate answer using DeepSeek via OpenRouter."""
|
||||||
|
|
||||||
|
context_text = "\n\n---\n\n".join(context_chunks)
|
||||||
|
|
||||||
|
system_prompt = f"""You are an expert assistant answering questions about "{book_title}".
|
||||||
|
Answer based strictly on the provided context. If the answer isn't in the context, say you don't know.
|
||||||
|
Do not make up information. Provide clear, helpful answers based on the book's content.
|
||||||
|
|
||||||
|
Context from {book_title}:
|
||||||
|
{context_text}"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=CHAT_MODEL,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": query},
|
||||||
|
],
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content
|
||||||
|
except Exception as e:
|
||||||
|
return f"❌ Error generating answer: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_selected_book(session) -> Optional[str]:
|
||||||
|
"""Get the currently selected book for a user."""
|
||||||
|
return session.get("knowledge_base", {}).get("selected_book")
|
||||||
|
|
||||||
|
|
||||||
|
def set_user_selected_book(session, file_path: str):
|
||||||
|
"""Set the selected book for a user."""
|
||||||
|
if "knowledge_base" not in session:
|
||||||
|
session["knowledge_base"] = {}
|
||||||
|
session["knowledge_base"]["selected_book"] = file_path
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_knowledge(message, session, parsed):
|
||||||
|
"""Handle knowledge base queries with dynamic book selection."""
|
||||||
|
action = parsed.get("action", "query")
|
||||||
|
|
||||||
|
if action == "list":
|
||||||
|
embedding_files = find_embedding_files()
|
||||||
|
|
||||||
|
if not embedding_files:
|
||||||
|
await message.channel.send(
|
||||||
|
f"❌ No knowledge bases found in `{EPUBS_DIRECTORY}`"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
lines = [f"{i + 1}. {get_book_name(f)}" for i, f in enumerate(embedding_files)]
|
||||||
|
current = get_user_selected_book(session)
|
||||||
|
current_text = (
|
||||||
|
f"\n\n📖 Currently selected: **{get_book_name(current)}**"
|
||||||
|
if current
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
await message.channel.send(
|
||||||
|
f"📚 **Available Knowledge Bases:**\n"
|
||||||
|
+ "\n".join(lines)
|
||||||
|
+ current_text
|
||||||
|
+ "\n\nUse `ask <book number/name> <question>` or `select book <number/name>`"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "select":
|
||||||
|
book_identifier = parsed.get("book", "")
|
||||||
|
embedding_files = find_embedding_files()
|
||||||
|
|
||||||
|
if not embedding_files:
|
||||||
|
await message.channel.send("❌ No knowledge bases available.")
|
||||||
|
return
|
||||||
|
|
||||||
|
selected_file = None
|
||||||
|
|
||||||
|
# Try to parse as number
|
||||||
|
try:
|
||||||
|
book_num = int(book_identifier) - 1
|
||||||
|
if 0 <= book_num < len(embedding_files):
|
||||||
|
selected_file = embedding_files[book_num]
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Try to match by name
|
||||||
|
book_lower = book_identifier.lower()
|
||||||
|
for f in embedding_files:
|
||||||
|
if book_lower in get_book_name(f).lower() or book_lower in f.lower():
|
||||||
|
selected_file = f
|
||||||
|
break
|
||||||
|
|
||||||
|
if not selected_file:
|
||||||
|
await message.channel.send(
|
||||||
|
f"❌ Could not find book '{book_identifier}'. Use `list books` to see available options."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
set_user_selected_book(session, selected_file)
|
||||||
|
book_name = get_book_name(selected_file)
|
||||||
|
await message.channel.send(f"✅ Selected knowledge base: **{book_name}**")
|
||||||
|
|
||||||
|
elif action == "query":
|
||||||
|
query = parsed.get("query", "")
|
||||||
|
book_override = parsed.get("book", "")
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
await message.channel.send(
|
||||||
|
"What would you like to know? (e.g., 'what does the book say about time management?')"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine which book to use
|
||||||
|
selected_file = None
|
||||||
|
|
||||||
|
if book_override:
|
||||||
|
# User specified a book in the query
|
||||||
|
embedding_files = find_embedding_files()
|
||||||
|
book_lower = book_override.lower()
|
||||||
|
|
||||||
|
# Try number first
|
||||||
|
try:
|
||||||
|
book_num = int(book_override) - 1
|
||||||
|
if 0 <= book_num < len(embedding_files):
|
||||||
|
selected_file = embedding_files[book_num]
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Try name match
|
||||||
|
for f in embedding_files:
|
||||||
|
if (
|
||||||
|
book_lower in get_book_name(f).lower()
|
||||||
|
or book_lower in f.lower()
|
||||||
|
):
|
||||||
|
selected_file = f
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Use user's selected book or default to first available
|
||||||
|
selected_file = get_user_selected_book(session)
|
||||||
|
if not selected_file:
|
||||||
|
embedding_files = find_embedding_files()
|
||||||
|
if embedding_files:
|
||||||
|
selected_file = embedding_files[0]
|
||||||
|
set_user_selected_book(session, selected_file)
|
||||||
|
|
||||||
|
if not selected_file:
|
||||||
|
await message.channel.send(
|
||||||
|
"❌ No knowledge base available. Please check the embeddings directory."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load knowledge base
|
||||||
|
kb_data = load_knowledge_base(selected_file)
|
||||||
|
if kb_data is None:
|
||||||
|
await message.channel.send(
|
||||||
|
"❌ Error loading knowledge base. Please check the file path."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
chunks, embeddings, metadata = kb_data
|
||||||
|
book_title = metadata.get("title", get_book_name(selected_file))
|
||||||
|
|
||||||
|
await message.channel.send(f"🔍 Searching **{book_title}**...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get query embedding and search
|
||||||
|
query_emb = get_query_embedding(query)
|
||||||
|
relevant_chunks, scores = search_context(
|
||||||
|
query_emb, chunks, embeddings, TOP_K_CHUNKS
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate answer
|
||||||
|
answer = generate_answer(query, relevant_chunks, book_title)
|
||||||
|
|
||||||
|
# Send response
|
||||||
|
await message.channel.send(f"📚 **Answer:**\n{answer}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await message.channel.send(f"❌ Error processing query: {e}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
await message.channel.send(
|
||||||
|
f"Unknown knowledge action: {action}. Try: list, select, or ask a question."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_knowledge_json(data):
|
||||||
|
"""Validate parsed JSON for knowledge queries."""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return ["Response must be a JSON object"]
|
||||||
|
|
||||||
|
if "error" in data:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if "action" not in data:
|
||||||
|
errors.append("Missing required field: action")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
# Register the module
|
||||||
|
register_module("knowledge", handle_knowledge)
|
||||||
|
|
||||||
|
# Register the validator
|
||||||
|
ai_parser.register_validator("knowledge", validate_knowledge_json)
|
||||||
Reference in New Issue
Block a user