Fix issues #6, #7, #11, #12, #13: med reminders, push notifications, auth persistence, scheduling conflicts

- Fix TIME object vs string comparison in scheduler preventing adaptive med
  reminders from ever firing (#12, #6)
- Add frequency filtering to midnight schedule creation for every_n_days meds
- Require start_date and interval_days for every_n_days medications
- Add refresh token support (30-day) to API and bot for persistent sessions (#13)
- Add "trusted device" checkbox to frontend login for long-lived sessions (#7)
- Auto-refresh expired tokens in both bot (apiRequest) and frontend (api.ts)
- Restore bot sessions from cache on restart using refresh tokens
- Duration-aware routine scheduling conflict detection (#11)
- Add conflict check when starting routine sessions against medication times
- Add diagnostic logging to notification delivery channels

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-19 13:05:48 -06:00
parent 6850abf7d2
commit d4adbde3df
10 changed files with 474 additions and 69 deletions

View File

@@ -75,11 +75,33 @@ def api_login():
return flask.jsonify({"error": "username and password required"}), 400
token = auth.getLoginToken(username, password)
if token:
return flask.jsonify({"token": token}), 200
response = {"token": token}
# Issue refresh token when trusted device is requested
if data.get("trust_device"):
import jwt as pyjwt
payload = pyjwt.decode(token, os.getenv("JWT_SECRET"), algorithms=["HS256"])
user_uuid = payload.get("sub")
if user_uuid:
response["refresh_token"] = auth.createRefreshToken(user_uuid)
return flask.jsonify(response), 200
else:
return flask.jsonify({"error": "invalid credentials"}), 401
@app.route("/api/refresh", methods=["POST"])
def api_refresh():
"""Exchange a refresh token for a new access token."""
data = flask.request.get_json()
refresh_token = data.get("refresh_token") if data else None
if not refresh_token:
return flask.jsonify({"error": "refresh_token required"}), 400
access_token, user_uuid = auth.refreshAccessToken(refresh_token)
if access_token:
return flask.jsonify({"token": access_token}), 200
else:
return flask.jsonify({"error": "invalid or expired refresh token"}), 401
# ── User Routes ────────────────────────────────────────────────────

View File

@@ -145,6 +145,17 @@ def register(app):
meds = postgres.select("medications", where={"user_uuid": user_uuid}, order_by="name")
return flask.jsonify(meds), 200
def _time_str_to_minutes(time_str):
"""Convert 'HH:MM' to minutes since midnight."""
parts = time_str.split(":")
return int(parts[0]) * 60 + int(parts[1])
def _get_routine_duration_minutes(routine_id):
"""Get total duration of a routine from its steps."""
steps = postgres.select("routine_steps", where={"routine_id": routine_id})
total = sum(s.get("duration_minutes", 0) or 0 for s in steps)
return max(total, 1)
def _check_med_schedule_conflicts(user_uuid, new_times, new_days=None, exclude_med_id=None):
"""Check if the proposed medication schedule conflicts with existing routines or medications.
Returns (has_conflict, conflict_message) tuple.
@@ -152,13 +163,23 @@ def register(app):
if not new_times:
return False, None
# Check conflicts with routines
# Check conflicts with routines (duration-aware)
user_routines = postgres.select("routines", {"user_uuid": user_uuid})
for r in user_routines:
sched = postgres.select_one("routine_schedules", {"routine_id": r["id"]})
if sched and sched.get("time") in new_times:
routine_days = json.loads(sched.get("days", "[]"))
if not new_days or any(d in routine_days for d in new_days):
if not sched or not sched.get("time"):
continue
routine_days = sched.get("days", [])
if isinstance(routine_days, str):
routine_days = json.loads(routine_days)
if new_days and not any(d in routine_days for d in new_days):
continue
routine_start = _time_str_to_minutes(sched["time"])
routine_dur = _get_routine_duration_minutes(r["id"])
for t in new_times:
med_start = _time_str_to_minutes(t)
# Med falls within routine time range
if routine_start <= med_start < routine_start + routine_dur:
return True, f"Time conflicts with routine: {r.get('name', 'Unnamed routine')}"
# Check conflicts with other medications
@@ -188,6 +209,11 @@ def register(app):
if missing:
return flask.jsonify({"error": f"missing required fields: {', '.join(missing)}"}), 400
# Validate every_n_days required fields
if data.get("frequency") == "every_n_days":
if not data.get("start_date") or not data.get("interval_days"):
return flask.jsonify({"error": "every_n_days frequency requires both start_date and interval_days"}), 400
# Check for schedule conflicts
new_times = data.get("times", [])
new_days = data.get("days_of_week", [])

View File

@@ -7,7 +7,7 @@ Routines have ordered steps. Users start sessions to walk through them.
import os
import uuid
import json
from datetime import datetime
from datetime import datetime, timedelta
import flask
import jwt
import core.auth as auth
@@ -420,6 +420,31 @@ def register(app):
return flask.jsonify(
{"error": "already have active session", "session_id": active["id"]}
), 409
# Check if starting now would conflict with medication times
now = tz.user_now()
current_time = now.strftime("%H:%M")
current_day = now.strftime("%a").lower()
routine_dur = _get_routine_duration_minutes(routine_id)
routine_start = _time_str_to_minutes(current_time)
user_meds = postgres.select("medications", {"user_uuid": user_uuid, "active": True})
for med in user_meds:
med_times = med.get("times", [])
if isinstance(med_times, str):
med_times = json.loads(med_times)
med_days = med.get("days_of_week", [])
if isinstance(med_days, str):
med_days = json.loads(med_days)
if med_days and current_day not in med_days:
continue
for mt in med_times:
med_start = _time_str_to_minutes(mt)
if _ranges_overlap(routine_start, routine_dur, med_start, 1):
return flask.jsonify(
{"error": f"Starting now would conflict with {med.get('name', 'medication')} at {mt}"}
), 409
steps = postgres.select(
"routine_steps",
where={"routine_id": routine_id},
@@ -649,23 +674,54 @@ def register(app):
)
return flask.jsonify(result), 200
def _check_schedule_conflicts(user_uuid, new_days, new_time, exclude_routine_id=None):
def _get_routine_duration_minutes(routine_id):
"""Get total duration of a routine from its steps."""
steps = postgres.select("routine_steps", where={"routine_id": routine_id})
total = sum(s.get("duration_minutes", 0) or 0 for s in steps)
return max(total, 1) # At least 1 minute
def _time_str_to_minutes(time_str):
"""Convert 'HH:MM' to minutes since midnight."""
parts = time_str.split(":")
return int(parts[0]) * 60 + int(parts[1])
def _ranges_overlap(start1, dur1, start2, dur2):
"""Check if two time ranges overlap (in minutes since midnight)."""
end1 = start1 + dur1
end2 = start2 + dur2
return start1 < end2 and start2 < end1
def _check_schedule_conflicts(user_uuid, new_days, new_time, exclude_routine_id=None, new_routine_id=None):
"""Check if the proposed schedule conflicts with existing routines or medications.
Returns (has_conflict, conflict_message) tuple.
"""
if not new_days or not new_time:
return False, None
new_start = _time_str_to_minutes(new_time)
# Get duration of the routine being scheduled
if new_routine_id:
new_dur = _get_routine_duration_minutes(new_routine_id)
else:
new_dur = 1
# Check conflicts with other routines
user_routines = postgres.select("routines", {"user_uuid": user_uuid})
for r in user_routines:
if r["id"] == exclude_routine_id:
continue
other_sched = postgres.select_one("routine_schedules", {"routine_id": r["id"]})
if other_sched and other_sched.get("time") == new_time:
other_days = json.loads(other_sched.get("days", "[]"))
if any(d in other_days for d in new_days):
return True, f"Time conflicts with routine: {r.get('name', 'Unnamed routine')}"
if not other_sched or not other_sched.get("time"):
continue
other_days = other_sched.get("days", [])
if isinstance(other_days, str):
other_days = json.loads(other_days)
if not any(d in other_days for d in new_days):
continue
other_start = _time_str_to_minutes(other_sched["time"])
other_dur = _get_routine_duration_minutes(r["id"])
if _ranges_overlap(new_start, new_dur, other_start, other_dur):
return True, f"Time conflicts with routine: {r.get('name', 'Unnamed routine')}"
# Check conflicts with medications
user_meds = postgres.select("medications", {"user_uuid": user_uuid, "active": True})
@@ -673,12 +729,16 @@ def register(app):
med_times = med.get("times", [])
if isinstance(med_times, str):
med_times = json.loads(med_times)
if new_time in med_times:
# Check if medication runs on any of the same days
med_days = med.get("days_of_week", [])
if isinstance(med_days, str):
med_days = json.loads(med_days)
if not med_days or any(d in med_days for d in new_days):
med_days = med.get("days_of_week", [])
if isinstance(med_days, str):
med_days = json.loads(med_days)
# If med has no specific days, it runs every day
if med_days and not any(d in med_days for d in new_days):
continue
for mt in med_times:
med_start = _time_str_to_minutes(mt)
# Medication takes ~0 minutes, but check if it falls within routine window
if _ranges_overlap(new_start, new_dur, med_start, 1):
return True, f"Time conflicts with medication: {med.get('name', 'Unnamed medication')}"
return False, None
@@ -702,7 +762,8 @@ def register(app):
new_days = data.get("days", [])
new_time = data.get("time")
has_conflict, conflict_msg = _check_schedule_conflicts(
user_uuid, new_days, new_time, exclude_routine_id=routine_id
user_uuid, new_days, new_time, exclude_routine_id=routine_id,
new_routine_id=routine_id,
)
if has_conflict:
return flask.jsonify({"error": conflict_msg}), 409

View File

@@ -116,21 +116,26 @@ class JurySystem:
print(f"Error loading DBT knowledge base: {e}")
raise
async def query(self, query_text):
"""Query the DBT knowledge base"""
try:
# Get embedding
response = self.client.embeddings.create(
model="qwen/qwen3-embedding-8b", input=query_text
)
query_emb = response.data[0].embedding
def _retrieve_sync(self, query_text, top_k=5):
"""Embed query and search vector store. Returns list of chunk dicts."""
response = self.client.embeddings.create(
model="qwen/qwen3-embedding-8b", input=query_text
)
query_emb = response.data[0].embedding
return self.vector_store.search(query_emb, top_k=top_k)
# Search
context_chunks = self.vector_store.search(query_emb, top_k=5)
async def retrieve(self, query_text, top_k=5):
"""Async retrieval — returns list of {metadata, score} dicts."""
import asyncio
return await asyncio.to_thread(self._retrieve_sync, query_text, top_k)
async def query(self, query_text):
"""Query the DBT knowledge base (legacy path, kept for compatibility)."""
try:
context_chunks = await self.retrieve(query_text)
if not context_chunks:
return "I couldn't find relevant DBT information for that query."
# Generate answer
context_text = "\n\n---\n\n".join(
[chunk["metadata"]["text"] for chunk in context_chunks]
)
@@ -140,20 +145,8 @@ Use the provided context from the DBT Skills Training Handouts to answer the use
If the answer is not in the context, say you don't know based on the provided text.
Be concise, compassionate, and practical."""
user_prompt = f"Context:\n{context_text}\n\nQuestion: {query_text}"
response = self.client.chat.completions.create(
model=self.config.get("models", {}).get(
"generator", "openai/gpt-4o-mini"
),
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=0.7,
)
return response.choices[0].message.content
from ai.jury_council import generate_rag_answer
return await generate_rag_answer(query_text, context_text, system_prompt)
except Exception as e:
return f"Error querying DBT knowledge base: {e}"
@@ -176,13 +169,18 @@ def decodeJwtPayload(token):
return json.loads(base64.urlsafe_b64decode(payload))
def apiRequest(method, endpoint, token=None, data=None):
def apiRequest(method, endpoint, token=None, data=None, _retried=False):
url = f"{API_URL}{endpoint}"
headers = {"Content-Type": "application/json"}
if token:
headers["Authorization"] = f"Bearer {token}"
try:
resp = getattr(requests, method)(url, headers=headers, json=data, timeout=10)
# Auto-refresh on 401 using refresh token
if resp.status_code == 401 and not _retried:
new_token = _try_refresh_token_for_session(token)
if new_token:
return apiRequest(method, endpoint, token=new_token, data=data, _retried=True)
try:
return resp.json(), resp.status_code
except ValueError:
@@ -191,6 +189,31 @@ def apiRequest(method, endpoint, token=None, data=None):
return {"error": "API unavailable"}, 503
def _try_refresh_token_for_session(expired_token):
"""Find the discord user with this token and refresh it using their refresh token."""
for discord_id, session in user_sessions.items():
if session.get("token") == expired_token:
refresh_token = session.get("refresh_token")
if not refresh_token:
# Check cache for refresh token
cached = getCachedUser(discord_id)
if cached:
refresh_token = cached.get("refresh_token")
if refresh_token:
result, status = apiRequest("post", "/api/refresh",
data={"refresh_token": refresh_token},
_retried=True)
if status == 200 and "token" in result:
new_token = result["token"]
session["token"] = new_token
# Update cache
cached = getCachedUser(discord_id) or {}
cached["refresh_token"] = refresh_token
setCachedUser(discord_id, cached)
return new_token
return None
def loadCache():
try:
if os.path.exists(CACHE_FILE):
@@ -229,14 +252,32 @@ def setCachedUser(discord_id, user_data):
def negotiateToken(discord_id, username, password):
cached = getCachedUser(discord_id)
# Try refresh token first (avoids sending password)
if cached and cached.get("refresh_token"):
result, status = apiRequest(
"post", "/api/refresh",
data={"refresh_token": cached["refresh_token"]},
_retried=True,
)
if status == 200 and "token" in result:
token = result["token"]
payload = decodeJwtPayload(token)
user_uuid = payload["sub"]
cached["user_uuid"] = user_uuid
setCachedUser(discord_id, cached)
return token, user_uuid
# Fall back to password login, always request refresh token (trust_device)
login_data = {"username": username, "password": password, "trust_device": True}
if (
cached
and cached.get("username") == username
and cached.get("hashed_password")
and verifyPassword(password, cached.get("hashed_password"))
):
result, status = apiRequest(
"post", "/api/login", data={"username": username, "password": password}
)
result, status = apiRequest("post", "/api/login", data=login_data, _retried=True)
if status == 200 and "token" in result:
token = result["token"]
payload = decodeJwtPayload(token)
@@ -247,14 +288,13 @@ def negotiateToken(discord_id, username, password):
"hashed_password": cached["hashed_password"],
"user_uuid": user_uuid,
"username": username,
"refresh_token": result.get("refresh_token"),
},
)
return token, user_uuid
return None, None
result, status = apiRequest(
"post", "/api/login", data={"username": username, "password": password}
)
result, status = apiRequest("post", "/api/login", data=login_data, _retried=True)
if status == 200 and "token" in result:
token = result["token"]
payload = decodeJwtPayload(token)
@@ -265,6 +305,7 @@ def negotiateToken(discord_id, username, password):
"hashed_password": hashPassword(password),
"user_uuid": user_uuid,
"username": username,
"refresh_token": result.get("refresh_token"),
},
)
return token, user_uuid
@@ -428,7 +469,7 @@ async def handleActiveSessionShortcuts(message, session, active_session):
async def handleDBTQuery(message):
"""Handle DBT-related queries using JurySystem"""
"""Handle DBT-related queries using JurySystem + jury council pipeline."""
if not jury_system:
return False
@@ -456,13 +497,66 @@ async def handleDBTQuery(message):
user_input_lower = message.content.lower()
is_dbt_query = any(keyword in user_input_lower for keyword in dbt_keywords)
if is_dbt_query:
async with message.channel.typing():
response = await jury_system.query(message.content)
await message.channel.send(f"🧠 **DBT Support:**\n{response}")
return True
if not is_dbt_query:
return False
return False
from ai.jury_council import (
generate_search_questions,
run_jury_filter,
generate_rag_answer,
split_for_discord,
)
async with message.channel.typing():
# Step 1: Generate candidate questions via Qwen Nitro (fallback: qwen3-235b)
candidates, gen_error = await generate_search_questions(message.content)
if gen_error:
await message.channel.send(f"⚠️ **Question generator failed:** {gen_error}")
return True
# Step 2: Jury council filters candidates → safe question JSON list
jury_result = await run_jury_filter(candidates, message.content)
breakdown = jury_result.format_breakdown()
# Always show the jury deliberation (verbose, as requested)
for chunk in split_for_discord(breakdown):
await message.channel.send(chunk)
if jury_result.has_error:
return True
if not jury_result.safe_questions:
return True
await message.channel.send("🔍 Searching knowledge base with approved questions...")
# Step 3: Multi-query retrieval — deduplicated by chunk ID
seen_ids = set()
context_chunks = []
for q in jury_result.safe_questions:
results = await jury_system.retrieve(q)
for r in results:
chunk_id = r["metadata"].get("id")
if chunk_id not in seen_ids:
seen_ids.add(chunk_id)
context_chunks.append(r["metadata"]["text"])
if not context_chunks:
await message.channel.send("⚠️ No relevant content found in the knowledge base.")
return True
context = "\n\n---\n\n".join(context_chunks)
# Step 4: Generate answer with qwen3-235b
system_prompt = """You are a helpful mental health support assistant with expertise in DBT (Dialectical Behavior Therapy).
Use the provided context to answer the user's question accurately and compassionately.
If the answer is not in the context, say so — do not invent information.
Be concise, practical, and supportive."""
answer = await generate_rag_answer(message.content, context, system_prompt)
await message.channel.send(f"🧠 **Response:**\n{answer}")
return True
async def routeCommand(message):
@@ -540,10 +634,38 @@ async def routeCommand(message):
)
def _restore_sessions_from_cache():
"""Try to restore user sessions from cached refresh tokens on startup."""
restored = 0
for discord_id, cached in user_cache.items():
refresh_token = cached.get("refresh_token")
if not refresh_token:
continue
result, status = apiRequest(
"post", "/api/refresh",
data={"refresh_token": refresh_token},
_retried=True,
)
if status == 200 and "token" in result:
token = result["token"]
payload = decodeJwtPayload(token)
user_uuid = payload["sub"]
user_sessions[discord_id] = {
"token": token,
"user_uuid": user_uuid,
"username": cached.get("username", ""),
"refresh_token": refresh_token,
}
restored += 1
if restored:
print(f"Restored {restored} user session(s) from cache")
@client.event
async def on_ready():
print(f"Bot logged in as {client.user}")
loadCache()
_restore_sessions_from_cache()
backgroundLoop.start()

View File

@@ -7,6 +7,16 @@ import datetime
import os
REFRESH_TOKEN_SECRET = None
def _get_refresh_secret():
global REFRESH_TOKEN_SECRET
if REFRESH_TOKEN_SECRET is None:
REFRESH_TOKEN_SECRET = os.getenv("JWT_SECRET", "") + "_refresh"
return REFRESH_TOKEN_SECRET
def verifyLoginToken(login_token, username=False, userUUID=False):
if username:
userUUID = users.getUserUUID(username)
@@ -49,6 +59,44 @@ def getLoginToken(username, password):
return False
def createRefreshToken(userUUID):
"""Create a long-lived refresh token (30 days)."""
payload = {
"sub": str(userUUID),
"type": "refresh",
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=30),
}
return jwt.encode(payload, _get_refresh_secret(), algorithm="HS256")
def refreshAccessToken(refresh_token):
"""Validate a refresh token and return a new access token + user_uuid.
Returns (access_token, user_uuid) or (None, None)."""
try:
decoded = jwt.decode(
refresh_token, _get_refresh_secret(), algorithms=["HS256"]
)
if decoded.get("type") != "refresh":
return None, None
user_uuid = decoded.get("sub")
if not user_uuid:
return None, None
# Verify user still exists
user = postgres.select_one("users", {"id": user_uuid})
if not user:
return None, None
# Create new access token
payload = {
"sub": user_uuid,
"name": user.get("first_name", ""),
"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1),
}
access_token = jwt.encode(payload, os.getenv("JWT_SECRET"), algorithm="HS256")
return access_token, user_uuid
except (ExpiredSignatureError, InvalidTokenError):
return None, None
def unregisterUser(userUUID, password):
pw_hash = getUserpasswordHash(userUUID)
if not pw_hash:

View File

@@ -18,18 +18,27 @@ logger = logging.getLogger(__name__)
def _sendToEnabledChannels(notif_settings, message, user_uuid=None):
"""Send message to all enabled channels. Returns True if at least one succeeded."""
sent = False
logger.info(f"Sending notification to user {user_uuid}: {message[:80]}")
if notif_settings.get("discord_enabled") and notif_settings.get("discord_user_id"):
if discord.send_dm(notif_settings["discord_user_id"], message):
logger.debug(f"Discord DM sent to {notif_settings['discord_user_id']}")
sent = True
if notif_settings.get("ntfy_enabled") and notif_settings.get("ntfy_topic"):
if ntfy.send(notif_settings["ntfy_topic"], message):
logger.debug(f"ntfy sent to topic {notif_settings['ntfy_topic']}")
sent = True
if notif_settings.get("web_push_enabled") and user_uuid:
if web_push.send_to_user(user_uuid, message):
logger.debug(f"Web push sent for user {user_uuid}")
sent = True
else:
logger.warning(f"Web push failed or no subscriptions for user {user_uuid}")
if not sent:
logger.warning(f"No notification channels succeeded for user {user_uuid}")
return sent

View File

@@ -5,9 +5,9 @@ Override poll_callback() with your domain-specific logic.
"""
import os
import time
import time as time_module
import logging
from datetime import datetime, timezone, timedelta
from datetime import datetime, timezone, timedelta, time as time_type
import core.postgres as postgres
import core.notifications as notifications
@@ -249,6 +249,12 @@ def check_adaptive_medication_reminders():
# Use base time
check_time = sched.get("base_time")
# Normalize TIME objects to "HH:MM" strings for comparison
if isinstance(check_time, time_type):
check_time = check_time.strftime("%H:%M")
elif check_time is not None:
check_time = str(check_time)[:5]
if check_time != current_time:
continue
@@ -367,6 +373,11 @@ def check_nagging():
display_time = sched.get("adjusted_time")
else:
display_time = sched.get("base_time")
# Normalize TIME objects for display
if isinstance(display_time, time_type):
display_time = display_time.strftime("%H:%M")
elif display_time is not None:
display_time = str(display_time)[:5]
# Send nag notification
user_settings = notifications.getNotificationSettings(user_uuid)
@@ -446,6 +457,39 @@ def _get_distinct_user_uuids():
return uuids
def _is_med_due_today(med, today):
"""Check if a medication is due on the given date based on its frequency."""
from datetime import date as date_type
freq = med.get("frequency", "daily")
if freq == "as_needed":
return False
if freq == "specific_days":
current_day = today.strftime("%a").lower()
med_days = med.get("days_of_week", [])
if current_day not in med_days:
return False
if freq == "every_n_days":
start = med.get("start_date")
interval = med.get("interval_days")
if start and interval:
start_d = (
start
if isinstance(start, date_type)
else datetime.strptime(str(start), "%Y-%m-%d").date()
)
days_since = (today - start_d).days
if days_since < 0 or days_since % interval != 0:
return False
else:
return False
return True
def _check_per_user_midnight_schedules():
"""Create daily adaptive schedules for each user when it's midnight in
their timezone (within the poll window)."""
@@ -453,10 +497,13 @@ def _check_per_user_midnight_schedules():
try:
now = _user_now_for(user_uuid)
if now.hour == 0 and now.minute < POLL_INTERVAL / 60:
today = now.date()
user_meds = postgres.select(
"medications", where={"user_uuid": user_uuid, "active": True}
)
for med in user_meds:
if not _is_med_due_today(med, today):
continue
times = med.get("times", [])
if times:
adaptive_meds.create_daily_schedule(
@@ -499,7 +546,7 @@ def daemon_loop():
poll_callback()
except Exception as e:
logger.error(f"Poll callback error: {e}")
time.sleep(POLL_INTERVAL)
time_module.sleep(POLL_INTERVAL)
if __name__ == "__main__":

View File

@@ -9,6 +9,7 @@ export default function LoginPage() {
const [isLogin, setIsLogin] = useState(true);
const [username, setUsername] = useState('');
const [password, setPassword] = useState('');
const [trustDevice, setTrustDevice] = useState(false);
const [error, setError] = useState('');
const [isLoading, setIsLoading] = useState(false);
const { login, register } = useAuth();
@@ -21,10 +22,10 @@ export default function LoginPage() {
try {
if (isLogin) {
await login(username, password);
await login(username, password, trustDevice);
} else {
await register(username, password);
await login(username, password);
await login(username, password, trustDevice);
}
router.push('/');
} catch (err) {
@@ -80,6 +81,18 @@ export default function LoginPage() {
/>
</div>
{isLogin && (
<label className="flex items-center gap-2 text-sm text-gray-600 dark:text-gray-400">
<input
type="checkbox"
checked={trustDevice}
onChange={(e) => setTrustDevice(e.target.checked)}
className="w-4 h-4 rounded border-gray-300 text-indigo-500 focus:ring-indigo-500"
/>
This is a trusted device
</label>
)}
<button
type="submit"
disabled={isLoading}

View File

@@ -9,7 +9,7 @@ interface AuthContextType {
token: string | null;
isLoading: boolean;
isAuthenticated: boolean;
login: (username: string, password: string) => Promise<void>;
login: (username: string, password: string, trustDevice?: boolean) => Promise<void>;
register: (username: string, password: string) => Promise<void>;
logout: () => void;
refreshUser: () => Promise<void>;
@@ -54,8 +54,8 @@ export function AuthProvider({ children }: { children: React.ReactNode }) {
refreshUser();
}, [refreshUser]);
const login = async (username: string, password: string) => {
const result = await api.auth.login(username, password);
const login = async (username: string, password: string, trustDevice = false) => {
const result = await api.auth.login(username, password, trustDevice);
const storedToken = api.auth.getToken();
setToken(storedToken);

View File

@@ -11,11 +11,57 @@ function setToken(token: string): void {
function clearToken(): void {
localStorage.removeItem('token');
localStorage.removeItem('refresh_token');
}
function getRefreshToken(): string | null {
if (typeof window === 'undefined') return null;
return localStorage.getItem('refresh_token');
}
function setRefreshToken(token: string): void {
localStorage.setItem('refresh_token', token);
}
let refreshPromise: Promise<boolean> | null = null;
async function tryRefreshToken(): Promise<boolean> {
// Deduplicate concurrent refresh attempts
if (refreshPromise) return refreshPromise;
refreshPromise = (async () => {
const refreshToken = getRefreshToken();
if (!refreshToken) return false;
try {
const resp = await fetch(`${API_URL}/api/refresh`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refresh_token: refreshToken }),
});
if (resp.ok) {
const data = await resp.json();
if (data.token) {
setToken(data.token);
return true;
}
}
// Refresh token is invalid/expired - clear everything
clearToken();
return false;
} catch {
return false;
} finally {
refreshPromise = null;
}
})();
return refreshPromise;
}
async function request<T>(
endpoint: string,
options: RequestInit = {}
options: RequestInit = {},
_retried = false,
): Promise<T> {
const token = getToken();
const headers: HeadersInit = {
@@ -31,6 +77,14 @@ async function request<T>(
headers,
});
// Auto-refresh on 401
if (response.status === 401 && !_retried) {
const refreshed = await tryRefreshToken();
if (refreshed) {
return request<T>(endpoint, options, true);
}
}
if (!response.ok) {
const body = await response.text();
let errorMsg = 'Request failed';
@@ -49,12 +103,15 @@ async function request<T>(
export const api = {
// Auth
auth: {
login: async (username: string, password: string) => {
const result = await request<{ token: string }>('/api/login', {
login: async (username: string, password: string, trustDevice = false) => {
const result = await request<{ token: string; refresh_token?: string }>('/api/login', {
method: 'POST',
body: JSON.stringify({ username, password }),
body: JSON.stringify({ username, password, trust_device: trustDevice }),
});
setToken(result.token);
if (result.refresh_token) {
setRefreshToken(result.refresh_token);
}
return result;
},