Files
ADHDbot/api.py
2025-11-11 23:11:59 -06:00

259 lines
7.5 KiB
Python

import os
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, FastAPI, HTTPException, Response
from pydantic import BaseModel, Field
from Runner import Runner
from Memory import MemoryManager
from PromptLibrary import PromptLibrary
# FastAPI application instance for this service.
app = FastAPI(title="ADHDbot API")
# Router on which every endpoint in this module is registered; it is attached
# to `app` by the include_router calls at the bottom of the file.
router = APIRouter()
class ChatTurn(BaseModel):
    """One conversation turn, used as an element of ``RunRequest.history``."""

    # Both fields are required strings.
    role: str
    content: str
class RunRequest(BaseModel):
    """Request body accepted by the ``/run`` endpoint; every field has a default."""

    # May be omitted: the endpoint then falls back to the TARGET_USER_ID
    # environment variable.
    userId: Optional[str] = None
    category: str = "general"
    promptName: str = "welcome"
    context: str = "API triggered run"
    # default_factory builds a fresh list for each instance.
    history: List[ChatTurn] = Field(default_factory=list)
    modeHint: Optional[str] = None
class RunResponse(BaseModel):
    """Response body for prompt runs: the run parameters plus the resulting message."""

    # Required field, but an explicit null is accepted.
    userId: Optional[str]
    category: str
    promptName: str
    context: str
    # Value returned by Runner.run for this request.
    message: str
class NoteCreate(BaseModel):
    """Request body for creating a note for a user."""

    # Required and must contain at least one character.
    note: str = Field(..., min_length=1, description="Raw text of the note")
    # Optional free-form mapping; defaults to a new empty dict per instance.
    metadata: Dict[str, Any] = Field(default_factory=dict)
class NotesResponse(BaseModel):
    """Response body carrying a user's notes."""

    userId: str
    # Notes are passed through as free-form dicts; no schema is enforced here.
    notes: List[Dict[str, Any]]
class MemoryResponse(BaseModel):
    """Response body carrying a user's stored summaries and notes."""

    userId: str
    # Both collections are passed through as lists of free-form dicts.
    summaries: List[Dict[str, Any]]
    notes: List[Dict[str, Any]]
class PromptCatalogResponse(BaseModel):
    """Response body exposing the prompt catalog held by ``PromptLibrary``."""

    # Two-level string mapping. NOTE(review): presumably
    # category -> prompt name -> prompt text; confirm against PromptLibrary.
    catalog: Dict[str, Dict[str, str]]
class ContextRequest(BaseModel):
    """Request body that carries only a context string."""

    # Required and must contain at least one character.
    context: str = Field(..., min_length=1)
class ActionItemRecord(BaseModel):
    """A stored action item as returned by the action endpoints."""

    id: str
    title: str
    cadence: str
    details: Optional[str] = None
    interval_minutes: Optional[int] = None
    # Timestamps are carried as plain strings; no format is enforced here.
    created_at: str
    updated_at: str
    # Free-form progress entries; defaults to a new empty list per instance.
    progress: List[Dict[str, Any]] = Field(default_factory=list)
class ActionItemCreate(BaseModel):
    """Request body for creating an action item."""

    # Required and must contain at least one character.
    title: str = Field(..., min_length=1)
    cadence: str = "daily"
    details: Optional[str] = None
    # Optional; when provided it must be zero or greater.
    interval_minutes: Optional[int] = Field(default=None, ge=0)
class ActionItemUpdate(BaseModel):
    """Request body for updating an action item; every field is optional."""

    # When provided, the title must contain at least one character.
    title: Optional[str] = Field(default=None, min_length=1)
    cadence: Optional[str] = None
    details: Optional[str] = None
    # When provided, the interval must be zero or greater.
    interval_minutes: Optional[int] = Field(default=None, ge=0)
class ActionItemProgressCreate(BaseModel):
    """Request body for recording progress on an action item."""

    # Defaults to "update"; an explicit value must be non-empty.
    status: str = Field(default="update", min_length=1)
    note: Optional[str] = None
class ActionItemsResponse(BaseModel):
    """Response body listing the action items of one user."""

    userId: str
    action_items: List[ActionItemRecord]
def defaultUserId() -> Optional[str]:
    """Return the fallback user id configured through ``TARGET_USER_ID``.

    Returns:
        The value of the ``TARGET_USER_ID`` environment variable with
        surrounding whitespace removed, or ``None`` when the variable is
        unset, empty, or contains only whitespace.
    """
    configured = os.getenv("TARGET_USER_ID")
    if configured is None:
        return None
    # Values loaded from env files often carry stray whitespace or a trailing
    # newline; a blank value must not be mistaken for a real user id.
    return configured.strip() or None
@router.get("/health")
def health():
    """Report that the service is up by returning a fixed status payload."""
    status_report = {"status": "ok"}
    return status_report
@router.post("/run", response_model=RunResponse)
def run_bot(request: RunRequest):
    """Run a prompt through ``Runner.run`` and return the resulting message.

    The user id is taken from the request body or, when it is missing, from
    the ``TARGET_USER_ID`` environment variable.

    Raises:
        HTTPException: 400 when no user id can be resolved.
    """
    resolved_user = ensureUserId(request.userId)
    # Runner is handed plain dicts rather than pydantic model instances.
    prior_turns = [turn.model_dump() for turn in request.history]
    reply = Runner.run(
        resolved_user,
        request.category,
        request.promptName,
        request.context,
        history=prior_turns,
        modeHint=request.modeHint,
    )
    return RunResponse(
        userId=resolved_user,
        category=request.category,
        promptName=request.promptName,
        context=request.context,
        message=reply,
    )
@router.get("/users/{userId}/notes", response_model=NotesResponse)
def get_user_notes(userId: str, limit: int = 10):
    """Return notes stored for a user.

    Args:
        userId: Path parameter identifying the user.
        limit: Query parameter. When positive, only the last ``limit``
            entries of the stored list are returned; zero or a negative
            value returns the whole list.
    """
    stored_notes = MemoryManager.loadUserMemory(userId).get("notes", [])
    selected = stored_notes[-limit:] if limit > 0 else stored_notes
    return NotesResponse(userId=userId, notes=selected)
@router.post("/users/{userId}/notes", response_model=NotesResponse, status_code=201)
def create_user_note(userId: str, payload: NoteCreate):
    """Record a note for a user and return the user's notes afterwards.

    The response contains the full notes list re-read from memory after the
    write, not just the newly created note.
    """
    # Falsy metadata is replaced with a new empty dict before it is recorded.
    MemoryManager.recordNote(userId, payload.note, payload.metadata or {})
    refreshed = MemoryManager.loadUserMemory(userId)
    return NotesResponse(userId=userId, notes=refreshed.get("notes", []))
@router.get("/users/{userId}/memory", response_model=MemoryResponse)
def get_user_memory(userId: str):
    """Return the notes and summaries stored for a user.

    Either collection falls back to an empty list when its key is missing
    from the loaded memory.
    """
    snapshot = MemoryManager.loadUserMemory(userId)
    stored_notes = snapshot.get("notes", [])
    stored_summaries = snapshot.get("summaries", [])
    return MemoryResponse(
        userId=userId,
        notes=stored_notes,
        summaries=stored_summaries,
    )
@router.get("/users/{userId}/actions", response_model=ActionItemsResponse)
def list_user_actions(userId: str):
    """Return every action item that ``MemoryManager`` lists for the user."""
    return ActionItemsResponse(
        userId=userId,
        action_items=MemoryManager.listActionItems(userId),
    )
@router.post("/users/{userId}/actions", response_model=ActionItemsResponse, status_code=201)
def create_user_action(userId: str, payload: ActionItemCreate):
    """Create an action item and return the user's full list of action items.

    Raises:
        HTTPException: 400 when ``MemoryManager.createActionItem`` returns a
            falsy result.
    """
    new_item = MemoryManager.createActionItem(
        userId,
        payload.title,
        cadence=payload.cadence,
        intervalMinutes=payload.interval_minutes,
        details=payload.details,
    )
    if new_item:
        # Respond with the complete list, re-read after the insert.
        return ActionItemsResponse(
            userId=userId,
            action_items=MemoryManager.listActionItems(userId),
        )
    raise HTTPException(status_code=400, detail="title is required")
@router.put("/users/{userId}/actions/{actionId}", response_model=ActionItemRecord)
def update_user_action(userId: str, actionId: str, payload: ActionItemUpdate):
    """Update an action item and return the updated record.

    Raises:
        HTTPException: 404 when ``MemoryManager.updateActionItem`` returns a
            falsy result.
    """
    # NOTE(review): fields missing from the request are forwarded as None;
    # confirm that MemoryManager.updateActionItem treats None as "unchanged".
    changes = {
        "title": payload.title,
        "details": payload.details,
        "cadence": payload.cadence,
        "interval_minutes": payload.interval_minutes,
    }
    result = MemoryManager.updateActionItem(userId, actionId, changes)
    if result:
        return ActionItemRecord.model_validate(result)
    raise HTTPException(status_code=404, detail="Action item not found")
@router.delete("/users/{userId}/actions/{actionId}", status_code=204)
def delete_user_action(userId: str, actionId: str):
    """Delete an action item and answer with an empty 204 response.

    Raises:
        HTTPException: 404 when ``MemoryManager.deleteActionItem`` returns a
            falsy result.
    """
    if MemoryManager.deleteActionItem(userId, actionId):
        return Response(status_code=204)
    raise HTTPException(status_code=404, detail="Action item not found")
@router.post(
    "/users/{userId}/actions/{actionId}/progress",
    response_model=ActionItemRecord,
    status_code=201,
)
def add_action_progress(userId: str, actionId: str, payload: ActionItemProgressCreate):
    """Record a progress entry on an action item and return that item.

    Raises:
        HTTPException: 404 when the progress could not be recorded, or when
            the item is absent from the user's action list afterwards.
    """
    was_recorded = MemoryManager.recordActionProgress(
        userId,
        actionId,
        status=payload.status,
        note=payload.note,
    )
    if not was_recorded:
        raise HTTPException(status_code=404, detail="Action item not found")
    # Look the item up again so the response reflects its stored state.
    return ActionItemRecord.model_validate(resolve_action_or_404(userId, actionId))
@router.post("/prompts/reload", response_model=PromptCatalogResponse)
def reload_prompts():
    """Ask ``PromptLibrary`` to reload its catalog, then return the catalog."""
    PromptLibrary.reloadCatalog()
    # Read the attribute only after the reload so the response is current.
    current_catalog = PromptLibrary.promptCatalog
    return PromptCatalogResponse(catalog=current_catalog)
@router.get("/prompts", response_model=PromptCatalogResponse)
def list_prompts():
    """Return the prompt catalog, loading it first when it is currently empty."""
    catalog_is_empty = not PromptLibrary.promptCatalog
    if catalog_is_empty:
        PromptLibrary.reloadCatalog()
    # Re-read the attribute in case the reload replaced the catalog object.
    return PromptCatalogResponse(catalog=PromptLibrary.promptCatalog)
@router.post("/users/{userId}/notes/test", response_model=RunResponse)
def force_note_capture(userId: str, payload: ContextRequest):
    """Helper endpoint for QA to trigger the welcome prompt with a custom context."""
    # Category and prompt are fixed; only the context comes from the caller.
    category, prompt_name = "general", "welcome"
    reply = Runner.run(userId, category, prompt_name, payload.context)
    return RunResponse(
        userId=userId,
        category=category,
        promptName=prompt_name,
        context=payload.context,
        message=reply,
    )
def ensureUserId(userId: Optional[str]) -> str:
    """Resolve the user id to act on, or reject the request.

    Args:
        userId: Id supplied by the caller; may be ``None`` or empty.

    Returns:
        ``userId`` when it is truthy, otherwise the value of
        ``defaultUserId()``.

    Raises:
        HTTPException: 400 when neither source yields a truthy id.
    """
    if userId:
        return userId
    fallback = defaultUserId()
    if fallback:
        return fallback
    raise HTTPException(status_code=400, detail="userId is required")
def resolve_action_or_404(userId: str, actionId: str) -> Dict[str, Any]:
    """Find one of a user's action items by id.

    Args:
        userId: User whose action items are searched.
        actionId: Value compared against each item's ``"id"`` key.

    Returns:
        The first listed action item whose ``"id"`` equals ``actionId``.

    Raises:
        HTTPException: 404 when no listed item matches.
    """
    found = next(
        (
            item
            for item in MemoryManager.listActionItems(userId)
            if item.get("id") == actionId
        ),
        None,
    )
    if found is None:
        raise HTTPException(status_code=404, detail="Action item not found")
    return found
# Mount the router twice so every endpoint is reachable both at the root path
# and under the "/api" prefix.
app.include_router(router)
app.include_router(router, prefix="/api")