"""HTTP API for ADHDbot.

Exposes health, prompt-run, notes/memory and action-item endpoints on a single
``APIRouter`` that is mounted twice: at the root and under ``/api``.
"""

import os
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, FastAPI, HTTPException, Response
from pydantic import BaseModel, Field

from Runner import Runner
from Memory import MemoryManager
from PromptLibrary import PromptLibrary

app = FastAPI(title="ADHDbot API")
router = APIRouter()


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------


class ChatTurn(BaseModel):
    """One prior conversation turn, forwarded to ``Runner.run`` as history."""

    role: str
    content: str


class RunRequest(BaseModel):
    """Body of ``POST /run``.

    Every field has a default, so an empty JSON object is a valid request; a
    missing ``userId`` falls back to the ``TARGET_USER_ID`` environment
    variable (see ``ensureUserId``).
    """

    userId: Optional[str] = None
    category: str = "general"
    promptName: str = "welcome"
    context: str = "API triggered run"
    history: List[ChatTurn] = Field(default_factory=list)
    modeHint: Optional[str] = None


class RunResponse(BaseModel):
    """Echo of the run parameters plus the message produced by ``Runner.run``."""

    userId: Optional[str]
    category: str
    promptName: str
    context: str
    message: str


class NoteCreate(BaseModel):
    """Body for creating a free-text note with optional metadata."""

    note: str = Field(..., min_length=1, description="Raw text of the note")
    metadata: Dict[str, Any] = Field(default_factory=dict)


class NotesResponse(BaseModel):
    """A user's notes as stored by ``MemoryManager``."""

    userId: str
    notes: List[Dict[str, Any]]


class MemoryResponse(BaseModel):
    """A user's stored summaries and notes."""

    userId: str
    summaries: List[Dict[str, Any]]
    notes: List[Dict[str, Any]]


class PromptCatalogResponse(BaseModel):
    """The loaded prompt catalog, as held in ``PromptLibrary.promptCatalog``."""

    catalog: Dict[str, Dict[str, str]]


class ContextRequest(BaseModel):
    """Body carrying a non-empty context string for a prompt run."""

    context: str = Field(..., min_length=1)


class ActionItemRecord(BaseModel):
    """Validated shape of an action-item dict returned by ``MemoryManager``."""

    id: str
    title: str
    cadence: str
    details: Optional[str] = None
    interval_minutes: Optional[int] = None
    created_at: str
    updated_at: str
    progress: List[Dict[str, Any]] = Field(default_factory=list)


class ActionItemCreate(BaseModel):
    """Body for creating an action item; ``cadence`` defaults to ``"daily"``."""

    title: str = Field(..., min_length=1)
    cadence: str = Field(default="daily")
    details: Optional[str] = None
    interval_minutes: Optional[int] = Field(default=None, ge=0)


class ActionItemUpdate(BaseModel):
    """Body for updating an action item; every field is optional."""

    title: Optional[str] = Field(default=None, min_length=1)
    cadence: Optional[str] = None
    details: Optional[str] = None
    interval_minutes: Optional[int] = Field(default=None, ge=0)


class ActionItemProgressCreate(BaseModel):
    """Body for appending a progress entry to an action item."""

    status: str = Field(default="update", min_length=1)
    note: Optional[str] = None


class ActionItemsResponse(BaseModel):
    """All action items for a user."""

    userId: str
    action_items: List[ActionItemRecord]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def defaultUserId() -> Optional[str]:
    """Return the fallback user id from ``TARGET_USER_ID`` (None if unset)."""
    return os.getenv("TARGET_USER_ID")


def ensureUserId(userId: Optional[str]) -> str:
    """Resolve the effective user id or fail with HTTP 400.

    An explicit ``userId`` wins; otherwise ``defaultUserId()`` is used. Both
    are tested for truthiness, so an empty string counts as missing.

    Raises:
        HTTPException: 400 when neither source yields a user id.
    """
    resolved = userId or defaultUserId()
    if not resolved:
        raise HTTPException(status_code=400, detail="userId is required")
    return resolved


def resolve_action_or_404(userId: str, actionId: str) -> Dict[str, Any]:
    """Return the user's action item whose ``id`` equals ``actionId``.

    Performs a linear scan over ``MemoryManager.listActionItems(userId)``.

    Raises:
        HTTPException: 404 when no item matches.
    """
    actions = MemoryManager.listActionItems(userId)
    for action in actions:
        if action.get("id") == actionId:
            return action
    raise HTTPException(status_code=404, detail="Action item not found")


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------


@router.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@router.post("/run", response_model=RunResponse)
def run_bot(request: RunRequest):
    """Run the named prompt for a user and return the generated message."""
    userId = ensureUserId(request.userId)
    message = Runner.run(
        userId,
        request.category,
        request.promptName,
        request.context,
        # Runner receives plain dicts rather than Pydantic models.
        history=[turn.model_dump() for turn in request.history],
        modeHint=request.modeHint,
    )
    return RunResponse(
        userId=userId,
        category=request.category,
        promptName=request.promptName,
        context=request.context,
        message=message,
    )


@router.get("/users/{userId}/notes", response_model=NotesResponse)
def get_user_notes(userId: str, limit: int = 10):
    """Return the last ``limit`` stored notes for a user.

    A ``limit`` <= 0 returns every note. "Last" means the tail of the stored
    list; this presumably holds the most recent notes if ``MemoryManager``
    appends them in chronological order. TODO: confirm.
    """
    memory = MemoryManager.loadUserMemory(userId)
    notes = memory.get("notes", [])
    if limit > 0:
        notes = notes[-limit:]
    return NotesResponse(userId=userId, notes=notes)


@router.post("/users/{userId}/notes", response_model=NotesResponse, status_code=201)
def create_user_note(userId: str, payload: NoteCreate):
    """Record a note and return the user's full, reloaded note list."""
    cleaned_metadata = payload.metadata or {}
    MemoryManager.recordNote(userId, payload.note, cleaned_metadata)
    updated = MemoryManager.loadUserMemory(userId)
    return NotesResponse(userId=userId, notes=updated.get("notes", []))


@router.get("/users/{userId}/memory", response_model=MemoryResponse)
def get_user_memory(userId: str):
    """Return a user's stored notes and summaries."""
    memory = MemoryManager.loadUserMemory(userId)
    return MemoryResponse(
        userId=userId,
        notes=memory.get("notes", []),
        summaries=memory.get("summaries", []),
    )


@router.get("/users/{userId}/actions", response_model=ActionItemsResponse)
def list_user_actions(userId: str):
    """List every action item for a user."""
    actions = MemoryManager.listActionItems(userId)
    return ActionItemsResponse(userId=userId, action_items=actions)


@router.post(
    "/users/{userId}/actions",
    response_model=ActionItemsResponse,
    status_code=201,
)
def create_user_action(userId: str, payload: ActionItemCreate):
    """Create an action item and return the user's updated action list.

    Raises:
        HTTPException: 400 when ``MemoryManager.createActionItem`` returns a
            falsy value.
    """
    created = MemoryManager.createActionItem(
        userId,
        payload.title,
        cadence=payload.cadence,
        intervalMinutes=payload.interval_minutes,
        details=payload.details,
    )
    if not created:
        raise HTTPException(status_code=400, detail="title is required")
    actions = MemoryManager.listActionItems(userId)
    return ActionItemsResponse(userId=userId, action_items=actions)


@router.put("/users/{userId}/actions/{actionId}", response_model=ActionItemRecord)
def update_user_action(userId: str, actionId: str, payload: ActionItemUpdate):
    """Apply field updates to an action item and return the updated record.

    Raises:
        HTTPException: 404 when ``MemoryManager.updateActionItem`` returns a
            falsy value.
    """
    # NOTE(review): fields the client omitted are forwarded as None. This is
    # only a partial update if MemoryManager.updateActionItem skips None
    # values. Confirm this; otherwise a PUT carrying just ``title`` clears
    # the other fields.
    updated = MemoryManager.updateActionItem(
        userId,
        actionId,
        {
            "title": payload.title,
            "details": payload.details,
            "cadence": payload.cadence,
            "interval_minutes": payload.interval_minutes,
        },
    )
    if not updated:
        raise HTTPException(status_code=404, detail="Action item not found")
    return ActionItemRecord.model_validate(updated)


@router.delete("/users/{userId}/actions/{actionId}", status_code=204)
def delete_user_action(userId: str, actionId: str):
    """Delete an action item; responds 204 with an empty body.

    Raises:
        HTTPException: 404 when ``MemoryManager.deleteActionItem`` returns a
            falsy value.
    """
    deleted = MemoryManager.deleteActionItem(userId, actionId)
    if not deleted:
        raise HTTPException(status_code=404, detail="Action item not found")
    # An explicit Response guarantees no JSON body is sent with the 204.
    return Response(status_code=204)


@router.post(
    "/users/{userId}/actions/{actionId}/progress",
    response_model=ActionItemRecord,
    status_code=201,
)
def add_action_progress(userId: str, actionId: str, payload: ActionItemProgressCreate):
    """Record a progress entry on an action item and return the item.

    Raises:
        HTTPException: 404 when recording fails or the item cannot be found
            afterwards.
    """
    recorded = MemoryManager.recordActionProgress(
        userId,
        actionId,
        status=payload.status,
        note=payload.note,
    )
    if not recorded:
        raise HTTPException(status_code=404, detail="Action item not found")
    # Re-read the item so the response reflects what is now stored (presumably
    # including the progress entry just recorded).
    action = resolve_action_or_404(userId, actionId)
    return ActionItemRecord.model_validate(action)


@router.post("/prompts/reload", response_model=PromptCatalogResponse)
def reload_prompts():
    """Force a reload of the prompt catalog and return it."""
    PromptLibrary.reloadCatalog()
    return PromptCatalogResponse(catalog=PromptLibrary.promptCatalog)


@router.get("/prompts", response_model=PromptCatalogResponse)
def list_prompts():
    """Return the prompt catalog, loading it first if it is still empty."""
    if not PromptLibrary.promptCatalog:
        PromptLibrary.reloadCatalog()
    return PromptCatalogResponse(catalog=PromptLibrary.promptCatalog)


@router.post("/users/{userId}/notes/test", response_model=RunResponse)
def force_note_capture(userId: str, payload: ContextRequest):
    """Helper endpoint for QA to trigger the welcome prompt with a custom context."""
    message = Runner.run(userId, "general", "welcome", payload.context)
    return RunResponse(
        userId=userId,
        category="general",
        promptName="welcome",
        context=payload.context,
        message=message,
    )


# Serve every route both at the root and under the /api prefix.
app.include_router(router)
app.include_router(router, prefix="/api")