Fix issues #6, #7, #11, #12, #13: med reminders, push notifications, auth persistence, scheduling conflicts
- Fix TIME object vs string comparison in scheduler preventing adaptive med reminders from ever firing (#12, #6) - Add frequency filtering to midnight schedule creation for every_n_days meds - Require start_date and interval_days for every_n_days medications - Add refresh token support (30-day) to API and bot for persistent sessions (#13) - Add "trusted device" checkbox to frontend login for long-lived sessions (#7) - Auto-refresh expired tokens in both bot (apiRequest) and frontend (api.ts) - Restore bot sessions from cache on restart using refresh tokens - Duration-aware routine scheduling conflict detection (#11) - Add conflict check when starting routine sessions against medication times - Add diagnostic logging to notification delivery channels Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
24
api/main.py
24
api/main.py
@@ -75,11 +75,33 @@ def api_login():
|
||||
return flask.jsonify({"error": "username and password required"}), 400
|
||||
token = auth.getLoginToken(username, password)
|
||||
if token:
|
||||
return flask.jsonify({"token": token}), 200
|
||||
response = {"token": token}
|
||||
# Issue refresh token when trusted device is requested
|
||||
if data.get("trust_device"):
|
||||
import jwt as pyjwt
|
||||
payload = pyjwt.decode(token, os.getenv("JWT_SECRET"), algorithms=["HS256"])
|
||||
user_uuid = payload.get("sub")
|
||||
if user_uuid:
|
||||
response["refresh_token"] = auth.createRefreshToken(user_uuid)
|
||||
return flask.jsonify(response), 200
|
||||
else:
|
||||
return flask.jsonify({"error": "invalid credentials"}), 401
|
||||
|
||||
|
||||
@app.route("/api/refresh", methods=["POST"])
|
||||
def api_refresh():
|
||||
"""Exchange a refresh token for a new access token."""
|
||||
data = flask.request.get_json()
|
||||
refresh_token = data.get("refresh_token") if data else None
|
||||
if not refresh_token:
|
||||
return flask.jsonify({"error": "refresh_token required"}), 400
|
||||
access_token, user_uuid = auth.refreshAccessToken(refresh_token)
|
||||
if access_token:
|
||||
return flask.jsonify({"token": access_token}), 200
|
||||
else:
|
||||
return flask.jsonify({"error": "invalid or expired refresh token"}), 401
|
||||
|
||||
|
||||
# ── User Routes ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -145,6 +145,17 @@ def register(app):
|
||||
meds = postgres.select("medications", where={"user_uuid": user_uuid}, order_by="name")
|
||||
return flask.jsonify(meds), 200
|
||||
|
||||
def _time_str_to_minutes(time_str):
|
||||
"""Convert 'HH:MM' to minutes since midnight."""
|
||||
parts = time_str.split(":")
|
||||
return int(parts[0]) * 60 + int(parts[1])
|
||||
|
||||
def _get_routine_duration_minutes(routine_id):
|
||||
"""Get total duration of a routine from its steps."""
|
||||
steps = postgres.select("routine_steps", where={"routine_id": routine_id})
|
||||
total = sum(s.get("duration_minutes", 0) or 0 for s in steps)
|
||||
return max(total, 1)
|
||||
|
||||
def _check_med_schedule_conflicts(user_uuid, new_times, new_days=None, exclude_med_id=None):
|
||||
"""Check if the proposed medication schedule conflicts with existing routines or medications.
|
||||
Returns (has_conflict, conflict_message) tuple.
|
||||
@@ -152,13 +163,23 @@ def register(app):
|
||||
if not new_times:
|
||||
return False, None
|
||||
|
||||
# Check conflicts with routines
|
||||
# Check conflicts with routines (duration-aware)
|
||||
user_routines = postgres.select("routines", {"user_uuid": user_uuid})
|
||||
for r in user_routines:
|
||||
sched = postgres.select_one("routine_schedules", {"routine_id": r["id"]})
|
||||
if sched and sched.get("time") in new_times:
|
||||
routine_days = json.loads(sched.get("days", "[]"))
|
||||
if not new_days or any(d in routine_days for d in new_days):
|
||||
if not sched or not sched.get("time"):
|
||||
continue
|
||||
routine_days = sched.get("days", [])
|
||||
if isinstance(routine_days, str):
|
||||
routine_days = json.loads(routine_days)
|
||||
if new_days and not any(d in routine_days for d in new_days):
|
||||
continue
|
||||
routine_start = _time_str_to_minutes(sched["time"])
|
||||
routine_dur = _get_routine_duration_minutes(r["id"])
|
||||
for t in new_times:
|
||||
med_start = _time_str_to_minutes(t)
|
||||
# Med falls within routine time range
|
||||
if routine_start <= med_start < routine_start + routine_dur:
|
||||
return True, f"Time conflicts with routine: {r.get('name', 'Unnamed routine')}"
|
||||
|
||||
# Check conflicts with other medications
|
||||
@@ -188,6 +209,11 @@ def register(app):
|
||||
if missing:
|
||||
return flask.jsonify({"error": f"missing required fields: {', '.join(missing)}"}), 400
|
||||
|
||||
# Validate every_n_days required fields
|
||||
if data.get("frequency") == "every_n_days":
|
||||
if not data.get("start_date") or not data.get("interval_days"):
|
||||
return flask.jsonify({"error": "every_n_days frequency requires both start_date and interval_days"}), 400
|
||||
|
||||
# Check for schedule conflicts
|
||||
new_times = data.get("times", [])
|
||||
new_days = data.get("days_of_week", [])
|
||||
|
||||
@@ -7,7 +7,7 @@ Routines have ordered steps. Users start sessions to walk through them.
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import flask
|
||||
import jwt
|
||||
import core.auth as auth
|
||||
@@ -420,6 +420,31 @@ def register(app):
|
||||
return flask.jsonify(
|
||||
{"error": "already have active session", "session_id": active["id"]}
|
||||
), 409
|
||||
|
||||
# Check if starting now would conflict with medication times
|
||||
now = tz.user_now()
|
||||
current_time = now.strftime("%H:%M")
|
||||
current_day = now.strftime("%a").lower()
|
||||
routine_dur = _get_routine_duration_minutes(routine_id)
|
||||
routine_start = _time_str_to_minutes(current_time)
|
||||
|
||||
user_meds = postgres.select("medications", {"user_uuid": user_uuid, "active": True})
|
||||
for med in user_meds:
|
||||
med_times = med.get("times", [])
|
||||
if isinstance(med_times, str):
|
||||
med_times = json.loads(med_times)
|
||||
med_days = med.get("days_of_week", [])
|
||||
if isinstance(med_days, str):
|
||||
med_days = json.loads(med_days)
|
||||
if med_days and current_day not in med_days:
|
||||
continue
|
||||
for mt in med_times:
|
||||
med_start = _time_str_to_minutes(mt)
|
||||
if _ranges_overlap(routine_start, routine_dur, med_start, 1):
|
||||
return flask.jsonify(
|
||||
{"error": f"Starting now would conflict with {med.get('name', 'medication')} at {mt}"}
|
||||
), 409
|
||||
|
||||
steps = postgres.select(
|
||||
"routine_steps",
|
||||
where={"routine_id": routine_id},
|
||||
@@ -649,23 +674,54 @@ def register(app):
|
||||
)
|
||||
return flask.jsonify(result), 200
|
||||
|
||||
def _check_schedule_conflicts(user_uuid, new_days, new_time, exclude_routine_id=None):
|
||||
def _get_routine_duration_minutes(routine_id):
|
||||
"""Get total duration of a routine from its steps."""
|
||||
steps = postgres.select("routine_steps", where={"routine_id": routine_id})
|
||||
total = sum(s.get("duration_minutes", 0) or 0 for s in steps)
|
||||
return max(total, 1) # At least 1 minute
|
||||
|
||||
def _time_str_to_minutes(time_str):
|
||||
"""Convert 'HH:MM' to minutes since midnight."""
|
||||
parts = time_str.split(":")
|
||||
return int(parts[0]) * 60 + int(parts[1])
|
||||
|
||||
def _ranges_overlap(start1, dur1, start2, dur2):
|
||||
"""Check if two time ranges overlap (in minutes since midnight)."""
|
||||
end1 = start1 + dur1
|
||||
end2 = start2 + dur2
|
||||
return start1 < end2 and start2 < end1
|
||||
|
||||
def _check_schedule_conflicts(user_uuid, new_days, new_time, exclude_routine_id=None, new_routine_id=None):
|
||||
"""Check if the proposed schedule conflicts with existing routines or medications.
|
||||
Returns (has_conflict, conflict_message) tuple.
|
||||
"""
|
||||
if not new_days or not new_time:
|
||||
return False, None
|
||||
|
||||
new_start = _time_str_to_minutes(new_time)
|
||||
# Get duration of the routine being scheduled
|
||||
if new_routine_id:
|
||||
new_dur = _get_routine_duration_minutes(new_routine_id)
|
||||
else:
|
||||
new_dur = 1
|
||||
|
||||
# Check conflicts with other routines
|
||||
user_routines = postgres.select("routines", {"user_uuid": user_uuid})
|
||||
for r in user_routines:
|
||||
if r["id"] == exclude_routine_id:
|
||||
continue
|
||||
other_sched = postgres.select_one("routine_schedules", {"routine_id": r["id"]})
|
||||
if other_sched and other_sched.get("time") == new_time:
|
||||
other_days = json.loads(other_sched.get("days", "[]"))
|
||||
if any(d in other_days for d in new_days):
|
||||
return True, f"Time conflicts with routine: {r.get('name', 'Unnamed routine')}"
|
||||
if not other_sched or not other_sched.get("time"):
|
||||
continue
|
||||
other_days = other_sched.get("days", [])
|
||||
if isinstance(other_days, str):
|
||||
other_days = json.loads(other_days)
|
||||
if not any(d in other_days for d in new_days):
|
||||
continue
|
||||
other_start = _time_str_to_minutes(other_sched["time"])
|
||||
other_dur = _get_routine_duration_minutes(r["id"])
|
||||
if _ranges_overlap(new_start, new_dur, other_start, other_dur):
|
||||
return True, f"Time conflicts with routine: {r.get('name', 'Unnamed routine')}"
|
||||
|
||||
# Check conflicts with medications
|
||||
user_meds = postgres.select("medications", {"user_uuid": user_uuid, "active": True})
|
||||
@@ -673,12 +729,16 @@ def register(app):
|
||||
med_times = med.get("times", [])
|
||||
if isinstance(med_times, str):
|
||||
med_times = json.loads(med_times)
|
||||
if new_time in med_times:
|
||||
# Check if medication runs on any of the same days
|
||||
med_days = med.get("days_of_week", [])
|
||||
if isinstance(med_days, str):
|
||||
med_days = json.loads(med_days)
|
||||
if not med_days or any(d in med_days for d in new_days):
|
||||
med_days = med.get("days_of_week", [])
|
||||
if isinstance(med_days, str):
|
||||
med_days = json.loads(med_days)
|
||||
# If med has no specific days, it runs every day
|
||||
if med_days and not any(d in med_days for d in new_days):
|
||||
continue
|
||||
for mt in med_times:
|
||||
med_start = _time_str_to_minutes(mt)
|
||||
# Medication takes ~0 minutes, but check if it falls within routine window
|
||||
if _ranges_overlap(new_start, new_dur, med_start, 1):
|
||||
return True, f"Time conflicts with medication: {med.get('name', 'Unnamed medication')}"
|
||||
|
||||
return False, None
|
||||
@@ -702,7 +762,8 @@ def register(app):
|
||||
new_days = data.get("days", [])
|
||||
new_time = data.get("time")
|
||||
has_conflict, conflict_msg = _check_schedule_conflicts(
|
||||
user_uuid, new_days, new_time, exclude_routine_id=routine_id
|
||||
user_uuid, new_days, new_time, exclude_routine_id=routine_id,
|
||||
new_routine_id=routine_id,
|
||||
)
|
||||
if has_conflict:
|
||||
return flask.jsonify({"error": conflict_msg}), 409
|
||||
|
||||
200
bot/bot.py
200
bot/bot.py
@@ -116,21 +116,26 @@ class JurySystem:
|
||||
print(f"Error loading DBT knowledge base: {e}")
|
||||
raise
|
||||
|
||||
async def query(self, query_text):
|
||||
"""Query the DBT knowledge base"""
|
||||
try:
|
||||
# Get embedding
|
||||
response = self.client.embeddings.create(
|
||||
model="qwen/qwen3-embedding-8b", input=query_text
|
||||
)
|
||||
query_emb = response.data[0].embedding
|
||||
def _retrieve_sync(self, query_text, top_k=5):
|
||||
"""Embed query and search vector store. Returns list of chunk dicts."""
|
||||
response = self.client.embeddings.create(
|
||||
model="qwen/qwen3-embedding-8b", input=query_text
|
||||
)
|
||||
query_emb = response.data[0].embedding
|
||||
return self.vector_store.search(query_emb, top_k=top_k)
|
||||
|
||||
# Search
|
||||
context_chunks = self.vector_store.search(query_emb, top_k=5)
|
||||
async def retrieve(self, query_text, top_k=5):
|
||||
"""Async retrieval — returns list of {metadata, score} dicts."""
|
||||
import asyncio
|
||||
return await asyncio.to_thread(self._retrieve_sync, query_text, top_k)
|
||||
|
||||
async def query(self, query_text):
|
||||
"""Query the DBT knowledge base (legacy path, kept for compatibility)."""
|
||||
try:
|
||||
context_chunks = await self.retrieve(query_text)
|
||||
if not context_chunks:
|
||||
return "I couldn't find relevant DBT information for that query."
|
||||
|
||||
# Generate answer
|
||||
context_text = "\n\n---\n\n".join(
|
||||
[chunk["metadata"]["text"] for chunk in context_chunks]
|
||||
)
|
||||
@@ -140,20 +145,8 @@ Use the provided context from the DBT Skills Training Handouts to answer the use
|
||||
If the answer is not in the context, say you don't know based on the provided text.
|
||||
Be concise, compassionate, and practical."""
|
||||
|
||||
user_prompt = f"Context:\n{context_text}\n\nQuestion: {query_text}"
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.config.get("models", {}).get(
|
||||
"generator", "openai/gpt-4o-mini"
|
||||
),
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
from ai.jury_council import generate_rag_answer
|
||||
return await generate_rag_answer(query_text, context_text, system_prompt)
|
||||
except Exception as e:
|
||||
return f"Error querying DBT knowledge base: {e}"
|
||||
|
||||
@@ -176,13 +169,18 @@ def decodeJwtPayload(token):
|
||||
return json.loads(base64.urlsafe_b64decode(payload))
|
||||
|
||||
|
||||
def apiRequest(method, endpoint, token=None, data=None):
|
||||
def apiRequest(method, endpoint, token=None, data=None, _retried=False):
|
||||
url = f"{API_URL}{endpoint}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
try:
|
||||
resp = getattr(requests, method)(url, headers=headers, json=data, timeout=10)
|
||||
# Auto-refresh on 401 using refresh token
|
||||
if resp.status_code == 401 and not _retried:
|
||||
new_token = _try_refresh_token_for_session(token)
|
||||
if new_token:
|
||||
return apiRequest(method, endpoint, token=new_token, data=data, _retried=True)
|
||||
try:
|
||||
return resp.json(), resp.status_code
|
||||
except ValueError:
|
||||
@@ -191,6 +189,31 @@ def apiRequest(method, endpoint, token=None, data=None):
|
||||
return {"error": "API unavailable"}, 503
|
||||
|
||||
|
||||
def _try_refresh_token_for_session(expired_token):
|
||||
"""Find the discord user with this token and refresh it using their refresh token."""
|
||||
for discord_id, session in user_sessions.items():
|
||||
if session.get("token") == expired_token:
|
||||
refresh_token = session.get("refresh_token")
|
||||
if not refresh_token:
|
||||
# Check cache for refresh token
|
||||
cached = getCachedUser(discord_id)
|
||||
if cached:
|
||||
refresh_token = cached.get("refresh_token")
|
||||
if refresh_token:
|
||||
result, status = apiRequest("post", "/api/refresh",
|
||||
data={"refresh_token": refresh_token},
|
||||
_retried=True)
|
||||
if status == 200 and "token" in result:
|
||||
new_token = result["token"]
|
||||
session["token"] = new_token
|
||||
# Update cache
|
||||
cached = getCachedUser(discord_id) or {}
|
||||
cached["refresh_token"] = refresh_token
|
||||
setCachedUser(discord_id, cached)
|
||||
return new_token
|
||||
return None
|
||||
|
||||
|
||||
def loadCache():
|
||||
try:
|
||||
if os.path.exists(CACHE_FILE):
|
||||
@@ -229,14 +252,32 @@ def setCachedUser(discord_id, user_data):
|
||||
|
||||
def negotiateToken(discord_id, username, password):
|
||||
cached = getCachedUser(discord_id)
|
||||
|
||||
# Try refresh token first (avoids sending password)
|
||||
if cached and cached.get("refresh_token"):
|
||||
result, status = apiRequest(
|
||||
"post", "/api/refresh",
|
||||
data={"refresh_token": cached["refresh_token"]},
|
||||
_retried=True,
|
||||
)
|
||||
if status == 200 and "token" in result:
|
||||
token = result["token"]
|
||||
payload = decodeJwtPayload(token)
|
||||
user_uuid = payload["sub"]
|
||||
cached["user_uuid"] = user_uuid
|
||||
setCachedUser(discord_id, cached)
|
||||
return token, user_uuid
|
||||
|
||||
# Fall back to password login, always request refresh token (trust_device)
|
||||
login_data = {"username": username, "password": password, "trust_device": True}
|
||||
|
||||
if (
|
||||
cached
|
||||
and cached.get("username") == username
|
||||
and cached.get("hashed_password")
|
||||
and verifyPassword(password, cached.get("hashed_password"))
|
||||
):
|
||||
result, status = apiRequest(
|
||||
"post", "/api/login", data={"username": username, "password": password}
|
||||
)
|
||||
result, status = apiRequest("post", "/api/login", data=login_data, _retried=True)
|
||||
if status == 200 and "token" in result:
|
||||
token = result["token"]
|
||||
payload = decodeJwtPayload(token)
|
||||
@@ -247,14 +288,13 @@ def negotiateToken(discord_id, username, password):
|
||||
"hashed_password": cached["hashed_password"],
|
||||
"user_uuid": user_uuid,
|
||||
"username": username,
|
||||
"refresh_token": result.get("refresh_token"),
|
||||
},
|
||||
)
|
||||
return token, user_uuid
|
||||
return None, None
|
||||
|
||||
result, status = apiRequest(
|
||||
"post", "/api/login", data={"username": username, "password": password}
|
||||
)
|
||||
result, status = apiRequest("post", "/api/login", data=login_data, _retried=True)
|
||||
if status == 200 and "token" in result:
|
||||
token = result["token"]
|
||||
payload = decodeJwtPayload(token)
|
||||
@@ -265,6 +305,7 @@ def negotiateToken(discord_id, username, password):
|
||||
"hashed_password": hashPassword(password),
|
||||
"user_uuid": user_uuid,
|
||||
"username": username,
|
||||
"refresh_token": result.get("refresh_token"),
|
||||
},
|
||||
)
|
||||
return token, user_uuid
|
||||
@@ -428,7 +469,7 @@ async def handleActiveSessionShortcuts(message, session, active_session):
|
||||
|
||||
|
||||
async def handleDBTQuery(message):
|
||||
"""Handle DBT-related queries using JurySystem"""
|
||||
"""Handle DBT-related queries using JurySystem + jury council pipeline."""
|
||||
if not jury_system:
|
||||
return False
|
||||
|
||||
@@ -456,13 +497,66 @@ async def handleDBTQuery(message):
|
||||
user_input_lower = message.content.lower()
|
||||
is_dbt_query = any(keyword in user_input_lower for keyword in dbt_keywords)
|
||||
|
||||
if is_dbt_query:
|
||||
async with message.channel.typing():
|
||||
response = await jury_system.query(message.content)
|
||||
await message.channel.send(f"🧠 **DBT Support:**\n{response}")
|
||||
return True
|
||||
if not is_dbt_query:
|
||||
return False
|
||||
|
||||
return False
|
||||
from ai.jury_council import (
|
||||
generate_search_questions,
|
||||
run_jury_filter,
|
||||
generate_rag_answer,
|
||||
split_for_discord,
|
||||
)
|
||||
|
||||
async with message.channel.typing():
|
||||
# Step 1: Generate candidate questions via Qwen Nitro (fallback: qwen3-235b)
|
||||
candidates, gen_error = await generate_search_questions(message.content)
|
||||
if gen_error:
|
||||
await message.channel.send(f"⚠️ **Question generator failed:** {gen_error}")
|
||||
return True
|
||||
|
||||
# Step 2: Jury council filters candidates → safe question JSON list
|
||||
jury_result = await run_jury_filter(candidates, message.content)
|
||||
breakdown = jury_result.format_breakdown()
|
||||
|
||||
# Always show the jury deliberation (verbose, as requested)
|
||||
for chunk in split_for_discord(breakdown):
|
||||
await message.channel.send(chunk)
|
||||
|
||||
if jury_result.has_error:
|
||||
return True
|
||||
|
||||
if not jury_result.safe_questions:
|
||||
return True
|
||||
|
||||
await message.channel.send("🔍 Searching knowledge base with approved questions...")
|
||||
|
||||
# Step 3: Multi-query retrieval — deduplicated by chunk ID
|
||||
seen_ids = set()
|
||||
context_chunks = []
|
||||
for q in jury_result.safe_questions:
|
||||
results = await jury_system.retrieve(q)
|
||||
for r in results:
|
||||
chunk_id = r["metadata"].get("id")
|
||||
if chunk_id not in seen_ids:
|
||||
seen_ids.add(chunk_id)
|
||||
context_chunks.append(r["metadata"]["text"])
|
||||
|
||||
if not context_chunks:
|
||||
await message.channel.send("⚠️ No relevant content found in the knowledge base.")
|
||||
return True
|
||||
|
||||
context = "\n\n---\n\n".join(context_chunks)
|
||||
|
||||
# Step 4: Generate answer with qwen3-235b
|
||||
system_prompt = """You are a helpful mental health support assistant with expertise in DBT (Dialectical Behavior Therapy).
|
||||
Use the provided context to answer the user's question accurately and compassionately.
|
||||
If the answer is not in the context, say so — do not invent information.
|
||||
Be concise, practical, and supportive."""
|
||||
|
||||
answer = await generate_rag_answer(message.content, context, system_prompt)
|
||||
|
||||
await message.channel.send(f"🧠 **Response:**\n{answer}")
|
||||
return True
|
||||
|
||||
|
||||
async def routeCommand(message):
|
||||
@@ -540,10 +634,38 @@ async def routeCommand(message):
|
||||
)
|
||||
|
||||
|
||||
def _restore_sessions_from_cache():
|
||||
"""Try to restore user sessions from cached refresh tokens on startup."""
|
||||
restored = 0
|
||||
for discord_id, cached in user_cache.items():
|
||||
refresh_token = cached.get("refresh_token")
|
||||
if not refresh_token:
|
||||
continue
|
||||
result, status = apiRequest(
|
||||
"post", "/api/refresh",
|
||||
data={"refresh_token": refresh_token},
|
||||
_retried=True,
|
||||
)
|
||||
if status == 200 and "token" in result:
|
||||
token = result["token"]
|
||||
payload = decodeJwtPayload(token)
|
||||
user_uuid = payload["sub"]
|
||||
user_sessions[discord_id] = {
|
||||
"token": token,
|
||||
"user_uuid": user_uuid,
|
||||
"username": cached.get("username", ""),
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
restored += 1
|
||||
if restored:
|
||||
print(f"Restored {restored} user session(s) from cache")
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
print(f"Bot logged in as {client.user}")
|
||||
loadCache()
|
||||
_restore_sessions_from_cache()
|
||||
backgroundLoop.start()
|
||||
|
||||
|
||||
|
||||
48
core/auth.py
48
core/auth.py
@@ -7,6 +7,16 @@ import datetime
|
||||
import os
|
||||
|
||||
|
||||
REFRESH_TOKEN_SECRET = None
|
||||
|
||||
|
||||
def _get_refresh_secret():
|
||||
global REFRESH_TOKEN_SECRET
|
||||
if REFRESH_TOKEN_SECRET is None:
|
||||
REFRESH_TOKEN_SECRET = os.getenv("JWT_SECRET", "") + "_refresh"
|
||||
return REFRESH_TOKEN_SECRET
|
||||
|
||||
|
||||
def verifyLoginToken(login_token, username=False, userUUID=False):
|
||||
if username:
|
||||
userUUID = users.getUserUUID(username)
|
||||
@@ -49,6 +59,44 @@ def getLoginToken(username, password):
|
||||
return False
|
||||
|
||||
|
||||
def createRefreshToken(userUUID):
|
||||
"""Create a long-lived refresh token (30 days)."""
|
||||
payload = {
|
||||
"sub": str(userUUID),
|
||||
"type": "refresh",
|
||||
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=30),
|
||||
}
|
||||
return jwt.encode(payload, _get_refresh_secret(), algorithm="HS256")
|
||||
|
||||
|
||||
def refreshAccessToken(refresh_token):
|
||||
"""Validate a refresh token and return a new access token + user_uuid.
|
||||
Returns (access_token, user_uuid) or (None, None)."""
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
refresh_token, _get_refresh_secret(), algorithms=["HS256"]
|
||||
)
|
||||
if decoded.get("type") != "refresh":
|
||||
return None, None
|
||||
user_uuid = decoded.get("sub")
|
||||
if not user_uuid:
|
||||
return None, None
|
||||
# Verify user still exists
|
||||
user = postgres.select_one("users", {"id": user_uuid})
|
||||
if not user:
|
||||
return None, None
|
||||
# Create new access token
|
||||
payload = {
|
||||
"sub": user_uuid,
|
||||
"name": user.get("first_name", ""),
|
||||
"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1),
|
||||
}
|
||||
access_token = jwt.encode(payload, os.getenv("JWT_SECRET"), algorithm="HS256")
|
||||
return access_token, user_uuid
|
||||
except (ExpiredSignatureError, InvalidTokenError):
|
||||
return None, None
|
||||
|
||||
|
||||
def unregisterUser(userUUID, password):
|
||||
pw_hash = getUserpasswordHash(userUUID)
|
||||
if not pw_hash:
|
||||
|
||||
@@ -18,18 +18,27 @@ logger = logging.getLogger(__name__)
|
||||
def _sendToEnabledChannels(notif_settings, message, user_uuid=None):
|
||||
"""Send message to all enabled channels. Returns True if at least one succeeded."""
|
||||
sent = False
|
||||
logger.info(f"Sending notification to user {user_uuid}: {message[:80]}")
|
||||
|
||||
if notif_settings.get("discord_enabled") and notif_settings.get("discord_user_id"):
|
||||
if discord.send_dm(notif_settings["discord_user_id"], message):
|
||||
logger.debug(f"Discord DM sent to {notif_settings['discord_user_id']}")
|
||||
sent = True
|
||||
|
||||
if notif_settings.get("ntfy_enabled") and notif_settings.get("ntfy_topic"):
|
||||
if ntfy.send(notif_settings["ntfy_topic"], message):
|
||||
logger.debug(f"ntfy sent to topic {notif_settings['ntfy_topic']}")
|
||||
sent = True
|
||||
|
||||
if notif_settings.get("web_push_enabled") and user_uuid:
|
||||
if web_push.send_to_user(user_uuid, message):
|
||||
logger.debug(f"Web push sent for user {user_uuid}")
|
||||
sent = True
|
||||
else:
|
||||
logger.warning(f"Web push failed or no subscriptions for user {user_uuid}")
|
||||
|
||||
if not sent:
|
||||
logger.warning(f"No notification channels succeeded for user {user_uuid}")
|
||||
|
||||
return sent
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ Override poll_callback() with your domain-specific logic.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import time as time_module
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from datetime import datetime, timezone, timedelta, time as time_type
|
||||
|
||||
import core.postgres as postgres
|
||||
import core.notifications as notifications
|
||||
@@ -249,6 +249,12 @@ def check_adaptive_medication_reminders():
|
||||
# Use base time
|
||||
check_time = sched.get("base_time")
|
||||
|
||||
# Normalize TIME objects to "HH:MM" strings for comparison
|
||||
if isinstance(check_time, time_type):
|
||||
check_time = check_time.strftime("%H:%M")
|
||||
elif check_time is not None:
|
||||
check_time = str(check_time)[:5]
|
||||
|
||||
if check_time != current_time:
|
||||
continue
|
||||
|
||||
@@ -367,6 +373,11 @@ def check_nagging():
|
||||
display_time = sched.get("adjusted_time")
|
||||
else:
|
||||
display_time = sched.get("base_time")
|
||||
# Normalize TIME objects for display
|
||||
if isinstance(display_time, time_type):
|
||||
display_time = display_time.strftime("%H:%M")
|
||||
elif display_time is not None:
|
||||
display_time = str(display_time)[:5]
|
||||
|
||||
# Send nag notification
|
||||
user_settings = notifications.getNotificationSettings(user_uuid)
|
||||
@@ -446,6 +457,39 @@ def _get_distinct_user_uuids():
|
||||
return uuids
|
||||
|
||||
|
||||
def _is_med_due_today(med, today):
|
||||
"""Check if a medication is due on the given date based on its frequency."""
|
||||
from datetime import date as date_type
|
||||
|
||||
freq = med.get("frequency", "daily")
|
||||
|
||||
if freq == "as_needed":
|
||||
return False
|
||||
|
||||
if freq == "specific_days":
|
||||
current_day = today.strftime("%a").lower()
|
||||
med_days = med.get("days_of_week", [])
|
||||
if current_day not in med_days:
|
||||
return False
|
||||
|
||||
if freq == "every_n_days":
|
||||
start = med.get("start_date")
|
||||
interval = med.get("interval_days")
|
||||
if start and interval:
|
||||
start_d = (
|
||||
start
|
||||
if isinstance(start, date_type)
|
||||
else datetime.strptime(str(start), "%Y-%m-%d").date()
|
||||
)
|
||||
days_since = (today - start_d).days
|
||||
if days_since < 0 or days_since % interval != 0:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _check_per_user_midnight_schedules():
|
||||
"""Create daily adaptive schedules for each user when it's midnight in
|
||||
their timezone (within the poll window)."""
|
||||
@@ -453,10 +497,13 @@ def _check_per_user_midnight_schedules():
|
||||
try:
|
||||
now = _user_now_for(user_uuid)
|
||||
if now.hour == 0 and now.minute < POLL_INTERVAL / 60:
|
||||
today = now.date()
|
||||
user_meds = postgres.select(
|
||||
"medications", where={"user_uuid": user_uuid, "active": True}
|
||||
)
|
||||
for med in user_meds:
|
||||
if not _is_med_due_today(med, today):
|
||||
continue
|
||||
times = med.get("times", [])
|
||||
if times:
|
||||
adaptive_meds.create_daily_schedule(
|
||||
@@ -499,7 +546,7 @@ def daemon_loop():
|
||||
poll_callback()
|
||||
except Exception as e:
|
||||
logger.error(f"Poll callback error: {e}")
|
||||
time.sleep(POLL_INTERVAL)
|
||||
time_module.sleep(POLL_INTERVAL)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -9,6 +9,7 @@ export default function LoginPage() {
|
||||
const [isLogin, setIsLogin] = useState(true);
|
||||
const [username, setUsername] = useState('');
|
||||
const [password, setPassword] = useState('');
|
||||
const [trustDevice, setTrustDevice] = useState(false);
|
||||
const [error, setError] = useState('');
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const { login, register } = useAuth();
|
||||
@@ -21,10 +22,10 @@ export default function LoginPage() {
|
||||
|
||||
try {
|
||||
if (isLogin) {
|
||||
await login(username, password);
|
||||
await login(username, password, trustDevice);
|
||||
} else {
|
||||
await register(username, password);
|
||||
await login(username, password);
|
||||
await login(username, password, trustDevice);
|
||||
}
|
||||
router.push('/');
|
||||
} catch (err) {
|
||||
@@ -80,6 +81,18 @@ export default function LoginPage() {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isLogin && (
|
||||
<label className="flex items-center gap-2 text-sm text-gray-600 dark:text-gray-400">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={trustDevice}
|
||||
onChange={(e) => setTrustDevice(e.target.checked)}
|
||||
className="w-4 h-4 rounded border-gray-300 text-indigo-500 focus:ring-indigo-500"
|
||||
/>
|
||||
This is a trusted device
|
||||
</label>
|
||||
)}
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isLoading}
|
||||
|
||||
@@ -9,7 +9,7 @@ interface AuthContextType {
|
||||
token: string | null;
|
||||
isLoading: boolean;
|
||||
isAuthenticated: boolean;
|
||||
login: (username: string, password: string) => Promise<void>;
|
||||
login: (username: string, password: string, trustDevice?: boolean) => Promise<void>;
|
||||
register: (username: string, password: string) => Promise<void>;
|
||||
logout: () => void;
|
||||
refreshUser: () => Promise<void>;
|
||||
@@ -54,8 +54,8 @@ export function AuthProvider({ children }: { children: React.ReactNode }) {
|
||||
refreshUser();
|
||||
}, [refreshUser]);
|
||||
|
||||
const login = async (username: string, password: string) => {
|
||||
const result = await api.auth.login(username, password);
|
||||
const login = async (username: string, password: string, trustDevice = false) => {
|
||||
const result = await api.auth.login(username, password, trustDevice);
|
||||
const storedToken = api.auth.getToken();
|
||||
setToken(storedToken);
|
||||
|
||||
|
||||
@@ -11,11 +11,57 @@ function setToken(token: string): void {
|
||||
|
||||
function clearToken(): void {
|
||||
localStorage.removeItem('token');
|
||||
localStorage.removeItem('refresh_token');
|
||||
}
|
||||
|
||||
function getRefreshToken(): string | null {
|
||||
if (typeof window === 'undefined') return null;
|
||||
return localStorage.getItem('refresh_token');
|
||||
}
|
||||
|
||||
function setRefreshToken(token: string): void {
|
||||
localStorage.setItem('refresh_token', token);
|
||||
}
|
||||
|
||||
let refreshPromise: Promise<boolean> | null = null;
|
||||
|
||||
async function tryRefreshToken(): Promise<boolean> {
|
||||
// Deduplicate concurrent refresh attempts
|
||||
if (refreshPromise) return refreshPromise;
|
||||
|
||||
refreshPromise = (async () => {
|
||||
const refreshToken = getRefreshToken();
|
||||
if (!refreshToken) return false;
|
||||
try {
|
||||
const resp = await fetch(`${API_URL}/api/refresh`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refresh_token: refreshToken }),
|
||||
});
|
||||
if (resp.ok) {
|
||||
const data = await resp.json();
|
||||
if (data.token) {
|
||||
setToken(data.token);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// Refresh token is invalid/expired - clear everything
|
||||
clearToken();
|
||||
return false;
|
||||
} catch {
|
||||
return false;
|
||||
} finally {
|
||||
refreshPromise = null;
|
||||
}
|
||||
})();
|
||||
|
||||
return refreshPromise;
|
||||
}
|
||||
|
||||
async function request<T>(
|
||||
endpoint: string,
|
||||
options: RequestInit = {}
|
||||
options: RequestInit = {},
|
||||
_retried = false,
|
||||
): Promise<T> {
|
||||
const token = getToken();
|
||||
const headers: HeadersInit = {
|
||||
@@ -31,6 +77,14 @@ async function request<T>(
|
||||
headers,
|
||||
});
|
||||
|
||||
// Auto-refresh on 401
|
||||
if (response.status === 401 && !_retried) {
|
||||
const refreshed = await tryRefreshToken();
|
||||
if (refreshed) {
|
||||
return request<T>(endpoint, options, true);
|
||||
}
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const body = await response.text();
|
||||
let errorMsg = 'Request failed';
|
||||
@@ -49,12 +103,15 @@ async function request<T>(
|
||||
export const api = {
|
||||
// Auth
|
||||
auth: {
|
||||
login: async (username: string, password: string) => {
|
||||
const result = await request<{ token: string }>('/api/login', {
|
||||
login: async (username: string, password: string, trustDevice = false) => {
|
||||
const result = await request<{ token: string; refresh_token?: string }>('/api/login', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ username, password }),
|
||||
body: JSON.stringify({ username, password, trust_device: trustDevice }),
|
||||
});
|
||||
setToken(result.token);
|
||||
if (result.refresh_token) {
|
||||
setRefreshToken(result.refresh_token);
|
||||
}
|
||||
return result;
|
||||
},
|
||||
|
||||
|
||||
Reference in New Issue
Block a user