265 lines
7.2 KiB
Python
265 lines
7.2 KiB
Python
"""
postgres.py - Generic PostgreSQL CRUD module.

Requires: pip install psycopg2-binary

Connection settings are read from the environment:
    DB_HOST (default "localhost"), DB_PORT (default 5432),
    DB_NAME (default "app"), DB_USER (default "app"),
    DB_PASS (default empty)
"""
|
|
|
|
import os
|
|
import re
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
from contextlib import contextmanager
|
|
|
|
|
|
def _get_config():
    """Build the ``psycopg2.connect()`` keyword arguments from the environment.

    Reads DB_HOST, DB_PORT, DB_NAME, DB_USER and DB_PASS, falling back to
    "localhost", 5432, "app", "app" and "" respectively. DB_PORT is
    converted with ``int()``.
    """
    env = os.environ.get
    return dict(
        host=env("DB_HOST", "localhost"),
        port=int(env("DB_PORT", 5432)),
        dbname=env("DB_NAME", "app"),
        user=env("DB_USER", "app"),
        password=env("DB_PASS", ""),
    )
|
|
|
|
|
|
_IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def _safe_id(name):
    """Validate *name* as a bare SQL identifier and return it double-quoted.

    Only ASCII letters, digits and underscores are accepted (no leading
    digit), so the quoted result can never contain a quote character and
    cannot break out of the identifier. PostgreSQL treats double-quoted
    identifiers case-sensitively.

    Raises:
        ValueError: if *name* is not a string or is not a valid identifier.
    """
    # fullmatch rather than match(r"^...$"): "$" also matches just before a
    # trailing newline, so "users\n" would otherwise be accepted.
    if not isinstance(name, str) or not _IDENTIFIER_RE.fullmatch(name):
        raise ValueError(f"Invalid SQL identifier: {name}")
    return f'"{name}"'


# Operators accepted (upper-cased) in (op, operand) tuples by _build_where.
_WHERE_OPERATORS = frozenset(
    {"=", "!=", "<", ">", "<=", ">=", "LIKE", "ILIKE", "IN", "IS", "IS NOT"}
)


def _is_operand_sql(operand):
    """Return the SQL keyword for the right-hand side of IS / IS NOT.

    ``None`` (or the string ``"NULL"`` in any case) -> ``NULL``, ``True`` ->
    ``TRUE``, ``False`` -> ``FALSE``. Any other operand raises ``ValueError``
    rather than being silently treated as ``NULL``.
    """
    if operand is None or (isinstance(operand, str) and operand.upper() == "NULL"):
        return "NULL"
    if operand is True:
        return "TRUE"
    if operand is False:
        return "FALSE"
    raise ValueError(f"Unsupported IS operand: {operand!r}")


def _build_where(where, prefix=""):
    """Build a parameterised, AND-joined WHERE condition from a mapping.

    Each ``column: value`` entry becomes one condition:

    * ``None`` -> ``"col" IS NULL``
    * an ``(op, operand)`` 2-tuple, ``op`` being one of ``_WHERE_OPERATORS``
      (case-insensitive):
        - ``IN``: *operand* may be any iterable; one placeholder per item.
          An empty iterable yields ``FALSE`` (matches nothing).
        - ``IS`` / ``IS NOT``: operand ``None``/``"NULL"``, ``True`` or ``False``.
        - others: ``"col" <op> %(param)s``.
    * anything else -> ``"col" = %(param)s``

    Args:
        where: Mapping of column name to value or ``(op, operand)`` condition.
        prefix: Prepended to every generated parameter name so the result can
            be merged with other parameters without collisions.

    Returns:
        ``(sql, params)``: the condition text without the ``WHERE`` keyword
        (``""`` for an empty mapping) and the psycopg2 named-parameter dict.

    Raises:
        ValueError: on an invalid column name, unsupported operator or
            unsupported IS operand.
    """
    clauses = []
    params = {}
    for i, (col, val) in enumerate(where.items()):
        safe_col = _safe_id(col)
        # The entry index comes first: the digit run before the first "_"
        # identifies the entry, so names from different entries (including
        # IN's per-item names) can never collide, whatever the column names.
        param_name = f"{prefix}{i}_{col}"

        if val is None:
            clauses.append(f"{safe_col} IS NULL")
        elif isinstance(val, tuple) and len(val) == 2:
            op, operand = val
            op = str(op).upper()
            if op not in _WHERE_OPERATORS:
                raise ValueError(f"Unsupported operator: {op}")
            if op == "IN":
                # Materialise once so generators work (they can't be len()-ed).
                items = list(operand)
                if items:
                    names = [f"{param_name}_{j}" for j in range(len(items))]
                    placeholders = ", ".join(f"%({n})s" for n in names)
                    clauses.append(f"{safe_col} IN ({placeholders})")
                    params.update(zip(names, items))
                else:
                    # "IN ()" is a syntax error; an empty set matches nothing.
                    clauses.append("FALSE")
            elif op in ("IS", "IS NOT"):
                clauses.append(f"{safe_col} {op} {_is_operand_sql(operand)}")
            else:
                clauses.append(f"{safe_col} {op} %({param_name})s")
                params[param_name] = operand
        else:
            clauses.append(f"{safe_col} = %({param_name})s")
            params[param_name] = val

    return " AND ".join(clauses), params
|
|
|
|
|
|
@contextmanager
def get_connection():
    """Open a new connection and run the ``with`` body as one transaction.

    A fresh connection is created from ``_get_config()`` on every call. The
    transaction is committed when the body completes. It is rolled back when
    the body (or the commit) raises an ``Exception``, and that exception is
    re-raised. The connection is always closed afterwards.

    Raises:
        Whatever ``psycopg2.connect``, the body or the commit raises. A
        failing rollback never replaces the original exception.
    """
    conn = psycopg2.connect(**_get_config())
    try:
        yield conn
        conn.commit()
    except Exception:
        try:
            conn.rollback()
        except psycopg2.Error:
            # If the rollback itself fails (e.g. the connection is already
            # broken), keep the caller's original error; closing a connection
            # without committing discards its pending changes anyway.
            pass
        raise
    finally:
        conn.close()
|
|
|
|
|
|
@contextmanager
def get_cursor(dict_cursor=True):
    """Yield a cursor running inside a ``get_connection()`` transaction.

    With ``dict_cursor=True`` (the default) the cursor is a ``RealDictCursor``,
    so rows are dict-like and keyed by column name. Otherwise the connection's
    default cursor class is used. The cursor is closed before the enclosing
    transaction is committed or rolled back.
    """
    factory = psycopg2.extras.RealDictCursor if dict_cursor else None
    # psycopg2 cursors are context managers that close themselves on exit.
    with get_connection() as conn, conn.cursor(cursor_factory=factory) as cur:
        yield cur
|
|
|
|
|
|
def insert(table, data):
    """Insert one row into *table* and return it.

    Args:
        table: Table name (validated with ``_safe_id``).
        data: Mapping of column name to value. An empty mapping inserts a
            row consisting entirely of column defaults.

    Returns:
        The inserted row as a dict (via ``RETURNING *``), or ``None`` if no
        row came back.

    Raises:
        ValueError: if the table or a column name is not a valid identifier.
    """
    target = _safe_id(table)
    if data:
        safe_cols = ", ".join(_safe_id(col) for col in data)
        # Column names double as parameter names; _safe_id has just
        # restricted them to [A-Za-z0-9_], so they are safe inside %(...)s.
        placeholders = ", ".join(f"%({col})s" for col in data)
        query = f"""
            INSERT INTO {target}
            ({safe_cols})
            VALUES ({placeholders})
            RETURNING *
        """
        params = data
    else:
        # "() VALUES ()" is a syntax error in PostgreSQL; DEFAULT VALUES is
        # the standard way to insert a row made only of defaults.
        query = f"INSERT INTO {target} DEFAULT VALUES RETURNING *"
        params = None
    with get_cursor() as cur:
        cur.execute(query, params)
        return dict(cur.fetchone()) if cur.rowcount else None
|
|
|
|
|
|
# Modifier sequences allowed after a column name in an ORDER BY term.
_ORDER_MODIFIERS = frozenset(
    direction + nulls
    for direction in ((), ("ASC",), ("DESC",))
    for nulls in ((), ("NULLS", "FIRST"), ("NULLS", "LAST"))
)


def _build_order_by(order_by):
    """Validate an ORDER BY spec and render it with quoted column names.

    *order_by* is a string or a list/tuple of strings. Every comma-separated
    term must be a column name optionally followed by ``ASC``/``DESC`` and/or
    ``NULLS FIRST``/``NULLS LAST`` (case-insensitive), e.g.
    ``"created_at DESC, id"``.

    Raises:
        ValueError: for anything else (expressions, stray punctuation, empty
            terms), so caller-supplied sort keys cannot inject SQL.
    """
    terms = list(order_by) if isinstance(order_by, (list, tuple)) else [order_by]
    rendered = []
    for term in terms:
        if not isinstance(term, str):
            raise ValueError(f"Invalid ORDER BY term: {term!r}")
        for piece in term.split(","):
            words = piece.split()
            if not words:
                raise ValueError(f"Invalid ORDER BY term: {term!r}")
            modifiers = tuple(word.upper() for word in words[1:])
            if modifiers not in _ORDER_MODIFIERS:
                raise ValueError(f"Invalid ORDER BY term: {piece.strip()!r}")
            rendered.append(" ".join((_safe_id(words[0]),) + modifiers))
    return ", ".join(rendered)


def select(table, where=None, order_by=None, limit=None, offset=None):
    """Return rows of *table* as a list of dicts.

    Args:
        table: Table name (validated with ``_safe_id``).
        where: Optional filter mapping in the format accepted by
            ``_build_where``.
        order_by: Optional sort spec: a string such as
            ``"created_at DESC, id"`` or a list/tuple of such strings. Column
            names are validated and double-quoted (hence case-sensitive) like
            every other identifier in this module. Only ASC/DESC and
            NULLS FIRST/LAST modifiers are accepted; expressions are rejected.
        limit: Optional maximum number of rows (converted with ``int()``).
        offset: Optional number of rows to skip (converted with ``int()``).

    Raises:
        ValueError: on an invalid identifier, operator or ORDER BY term.
    """
    query = f"SELECT * FROM {_safe_id(table)}"
    params = {}

    if where:
        clauses, params = _build_where(where)
        query += f" WHERE {clauses}"
    if order_by:
        # Never paste order_by verbatim: sort keys are often user-controlled.
        query += f" ORDER BY {_build_order_by(order_by)}"
    if limit is not None:
        query += f" LIMIT {int(limit)}"
    if offset is not None:
        query += f" OFFSET {int(offset)}"

    with get_cursor() as cur:
        cur.execute(query, params)
        return [dict(row) for row in cur.fetchall()]
|
|
|
|
|
|
def select_one(table, where):
    """Fetch a single row of *table* matching *where*, or ``None``.

    Delegates to ``select(..., limit=1)`` without an ORDER BY. When several
    rows match, which one is returned is not specified.
    """
    matches = select(table, where=where, limit=1)
    return next(iter(matches), None)
|
|
|
|
|
|
def update(table, data, where):
    """Update the rows of *table* matching *where* and return them.

    Args:
        table: Table name (validated with ``_safe_id``).
        data: Mapping of column name to new value; must not be empty.
        where: Filter mapping in the format accepted by ``_build_where``;
            must not be empty, so a missing filter can never rewrite the
            whole table.

    Returns:
        List of the updated rows as dicts (via ``RETURNING *``).

    Raises:
        ValueError: if *data* or *where* is empty, or an identifier/operator
            is invalid.
    """
    if not data:
        raise ValueError("update() needs at least one column in data")
    if not where:
        raise ValueError("update() needs a non-empty where filter")

    set_clause = ", ".join(f"{_safe_id(col)} = %(set_{col})s" for col in data)
    params = {f"set_{col}": val for col, val in data.items()}

    # The "where_" prefix keeps filter parameter names disjoint from the
    # "set_" names above, even when a column is both updated and filtered on.
    where_clause, where_params = _build_where(where, prefix="where_")
    params.update(where_params)

    query = f"""
        UPDATE {_safe_id(table)}
        SET {set_clause}
        WHERE {where_clause}
        RETURNING *
    """
    with get_cursor() as cur:
        cur.execute(query, params)
        return [dict(row) for row in cur.fetchall()]
|
|
|
|
|
|
def delete(table, where):
    """Delete the rows of *table* matching *where* and return them.

    Args:
        table: Table name (validated with ``_safe_id``).
        where: Filter mapping in the format accepted by ``_build_where``;
            must not be empty. An empty filter would otherwise produce a
            dangling ``WHERE`` and an SQL error.

    Returns:
        List of the deleted rows as dicts (via ``RETURNING *``).

    Raises:
        ValueError: if *where* is empty, or an identifier/operator is invalid.
    """
    if not where:
        raise ValueError("delete() needs a non-empty where filter")

    where_clause, params = _build_where(where)
    query = f"""
        DELETE FROM {_safe_id(table)}
        WHERE {where_clause}
        RETURNING *
    """
    with get_cursor() as cur:
        cur.execute(query, params)
        return [dict(row) for row in cur.fetchall()]
|
|
|
|
|
|
def count(table, where=None):
    """Return the number of rows in *table* matching *where*.

    With no *where* (or an empty one) every row of the table is counted.
    """
    sql = f"SELECT COUNT(*) as count FROM {_safe_id(table)}"
    args = {}
    if where:
        condition, args = _build_where(where)
        sql = f"{sql} WHERE {condition}"

    with get_cursor() as cur:
        cur.execute(sql, args)
        row = cur.fetchone()
    return row["count"]
|
|
|
|
|
|
def exists(table, where):
    """Return True if at least one row of *table* matches *where*.

    Uses ``SELECT EXISTS (...)`` so PostgreSQL can stop at the first match
    instead of counting every matching row. An empty *where* checks whether
    the table has any rows at all.

    Raises:
        ValueError: on an invalid identifier or operator.
    """
    subquery = f"SELECT 1 FROM {_safe_id(table)}"
    params = {}
    if where:
        clauses, params = _build_where(where)
        subquery += f" WHERE {clauses}"

    with get_cursor() as cur:
        cur.execute(f"SELECT EXISTS ({subquery}) AS found", params)
        return cur.fetchone()["found"]
|
|
|
|
|
|
def upsert(table, data, conflict_columns):
    """Insert *data* into *table*, or update the conflicting row, and return it.

    Args:
        table: Table name (validated with ``_safe_id``).
        data: Mapping of column name to value; must not be empty.
        conflict_columns: A column name, or an iterable of names, used as the
            ON CONFLICT target (PostgreSQL requires a matching unique index or
            constraint). A plain string is treated as one column name.

    Returns:
        The inserted or updated row as a dict (via ``RETURNING *``), or
        ``None`` if no row came back.

    Raises:
        ValueError: if *data* or *conflict_columns* is empty, or an
            identifier is invalid.
    """
    if isinstance(conflict_columns, str):
        # Without this, "id" would be iterated as the two columns "i" and "d".
        conflict_columns = [conflict_columns]
    conflict_columns = list(conflict_columns)
    if not data:
        raise ValueError("upsert() needs at least one column in data")
    if not conflict_columns:
        raise ValueError("upsert() needs at least one conflict column")

    safe_table = _safe_id(table)
    columns = list(data)
    safe_cols = [_safe_id(c) for c in columns]
    placeholders = [f"%({c})s" for c in columns]
    conflict_cols = [_safe_id(c) for c in conflict_columns]

    conflict_set = set(conflict_columns)
    update_cols = [c for c in columns if c not in conflict_set]
    if update_cols:
        update_clause = ", ".join(
            f"{_safe_id(c)} = EXCLUDED.{_safe_id(c)}" for c in update_cols
        )
    else:
        # Only key columns were supplied, so there is nothing to change.
        # Re-assigning a key column to the existing row's own value is a
        # no-op, but unlike DO NOTHING it still lets RETURNING * yield the
        # existing row. (It does fire UPDATE triggers.)
        key = conflict_cols[0]
        update_clause = f"{key} = {safe_table}.{key}"

    query = f"""
        INSERT INTO {safe_table}
        ({", ".join(safe_cols)})
        VALUES ({", ".join(placeholders)})
        ON CONFLICT ({", ".join(conflict_cols)})
        DO UPDATE SET {update_clause}
        RETURNING *
    """
    with get_cursor() as cur:
        cur.execute(query, data)
        return dict(cur.fetchone()) if cur.rowcount else None
|
|
|
|
|
|
def insert_many(table, rows, page_size=100):
    """Bulk-insert *rows* into *table* and return the number of rows inserted.

    Args:
        table: Table name (validated with ``_safe_id``).
        rows: Iterable of mappings that all have exactly the same keys; the
            first row's key order defines the column order. ``None`` or an
            empty iterable inserts nothing.
        page_size: Maximum number of rows sent per INSERT statement.

    Returns:
        Total number of rows inserted across all pages.

    Raises:
        ValueError: if *page_size* < 1, a row's keys differ from the first
            row's, or an identifier is invalid.
    """
    rows = list(rows or ())
    if not rows:
        return 0
    page_size = int(page_size)
    if page_size < 1:
        raise ValueError("insert_many() page_size must be >= 1")

    columns = list(rows[0])
    safe_cols = [_safe_id(c) for c in columns]
    expected = set(columns)
    for index, row in enumerate(rows):
        # The template only references the first row's keys, so extra keys
        # would be dropped silently and missing ones fail mid-insert.
        if set(row) != expected:
            raise ValueError(
                f"insert_many(): row {index} has columns {list(row)}, "
                f"expected {columns}"
            )

    query = f"""
        INSERT INTO {_safe_id(table)}
        ({", ".join(safe_cols)})
        VALUES %s
    """
    template = "(" + ", ".join(f"%({col})s" for col in columns) + ")"

    inserted = 0
    with get_cursor() as cur:
        # execute_values sends one statement per page and cur.rowcount then
        # reflects only the last one, so page explicitly and sum the counts.
        for start in range(0, len(rows), page_size):
            page = rows[start:start + page_size]
            psycopg2.extras.execute_values(
                cur, query, page, template=template, page_size=len(page)
            )
            inserted += cur.rowcount
    return inserted
|
|
|
|
|
|
def execute(query, params=None):
    """Run an arbitrary SQL statement in its own transaction.

    Returns:
        A list of row dicts when the statement produces a result set
        (``cursor.description`` is set), otherwise ``cursor.rowcount``.

    NOTE: a falsy *params* is replaced by ``{}``, so psycopg2 always receives
    a parameter mapping and therefore always parses ``%`` placeholders. Write
    a literal ``%`` in *query* as ``%%``.
    """
    with get_cursor() as cur:
        cur.execute(query, params if params else {})
        if not cur.description:
            return cur.rowcount
        return [dict(row) for row in cur.fetchall()]
|
|
|
|
|
|
def table_exists(table, schema="public"):
    """Return True if a table (or view) named *table* exists in *schema*.

    The lookup goes through ``information_schema.tables``, which compares
    names exactly as stored. It lists views as well as tables, and only
    objects the current user has some privilege on.

    Args:
        table: Table name, matched exactly (case-sensitive).
        schema: Schema to look in; defaults to ``"public"``.
    """
    with get_cursor() as cur:
        cur.execute(
            """
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_schema = %(schema)s AND table_name = %(table)s
            )
            """,
            {"schema": schema, "table": table},
        )
        return cur.fetchone()["exists"]
|
|
|
|
|
|
def get_columns(table, schema="public"):
    """Describe the columns of *table* in *schema*, in declaration order.

    Args:
        table: Table name, matched exactly (case-sensitive).
        schema: Schema to look in; defaults to ``"public"``.

    Returns:
        A list of dicts with keys ``column_name``, ``data_type``,
        ``is_nullable`` and ``column_default`` (from
        ``information_schema.columns``), ordered by ``ordinal_position``.
        The list is empty when no such table is found.
    """
    with get_cursor() as cur:
        cur.execute(
            """
            SELECT column_name, data_type, is_nullable, column_default
            FROM information_schema.columns
            WHERE table_schema = %(schema)s AND table_name = %(table)s
            ORDER BY ordinal_position
            """,
            {"schema": schema, "table": table},
        )
        return [dict(row) for row in cur.fetchall()]
|