#!/usr/bin/python3
#
# Univention Management Console
# UMC server
#
# SPDX-FileCopyrightText: 2024-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only
import os
import time
from collections.abc import Generator
from contextlib import contextmanager
from sqlalchemy import BigInteger, Column, String, create_engine, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from tornado import ioloop
from univention.management.console.config import (
SQL_CONNECTION_ENV_VAR, SQL_MAX_OVERFLOW_ENV_VAR, SQL_POOL_RECYCLE_ENV_VAR, SQL_POOL_SIZE_ENV_VAR,
SQL_POOL_TIMEOUT_ENV_VAR,
)
from univention.management.console.log import CORE
from univention.management.console.sse import logout_notifiers
class DBDisabledException(Exception):
    """Raised when a database session is requested but the database is not enabled."""
class PostgresListenNotifyUnsupported(Exception):
    """Raised when LISTEN/NOTIFY is requested on an engine that is not PostgreSQL with psycopg2."""
# Declarative base for the ORM models of this module; the tables collected in
# Base.metadata are created by DBRegistry on first use.
Base = declarative_base()
class DBRegistry:
    """Lazily initialised holder of the SQLAlchemy engine and session factory.

    The connection is configured through environment variables. The first
    call to :meth:`get` or :meth:`enabled` triggers the one-time set-up.
    """

    __engine = None
    __registry = None
    __init = False
    _enabled = False

    @classmethod
    def get(cls):
        """Return a new session created by the session factory.

        :raises DBDisabledException: if no session factory is available,
            e.g. because no connection URI is configured.
        """
        if not cls.__init:
            cls.__create()
        if cls.__registry is None:
            # Without this guard the call below fails with an opaque
            # "'NoneType' object is not callable".
            raise DBDisabledException
        return cls.__registry()

    @classmethod
    def enabled(cls):
        """Return whether the database was configured and set up successfully."""
        if not cls.__init:
            cls.__create()
        return cls._enabled

    @classmethod
    def __create(cls):
        """Create the engine, the session factory and the tables, then start the listener.

        Does nothing besides marking the registry as initialised when the
        connection URI environment variable is not set.
        """
        # Set first: a failed initialisation is not retried on later calls.
        cls.__init = True
        connection_uri = os.environ.get(SQL_CONNECTION_ENV_VAR, None)
        if connection_uri is None:
            return

        pool_env_vars = {
            'pool_size': SQL_POOL_SIZE_ENV_VAR,
            'max_overflow': SQL_MAX_OVERFLOW_ENV_VAR,
            'pool_timeout': SQL_POOL_TIMEOUT_ENV_VAR,
            'pool_recycle': SQL_POOL_RECYCLE_ENV_VAR,
        }
        opts = {'pool_pre_ping': True}
        for option, env_var in pool_env_vars.items():
            raw_value = os.environ.get(env_var)
            # An unset variable is left out so that the SQLAlchemy default
            # applies instead of int(None) raising a TypeError.
            if raw_value is not None:
                opts[option] = int(raw_value)

        url = make_url(connection_uri)
        # Only driver, host and database are logged; credentials in the URI are not.
        parameters = ', '.join(f'{k}={v}' for k, v in opts.items())
        msg = f"Connecting to database {url.drivername}://{url.host}/{url.database} with parameters {parameters}"
        CORE.process("Connecting to database: %s", msg)

        engine = create_engine(connection_uri, **opts)
        cls.__engine = engine
        cls.__registry = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        Base.metadata.create_all(cls.__engine)

        try:
            CORE.debug("Starting the PostgresListener")
            PostgresListener(engine).listen()
        except PostgresListenNotifyUnsupported as e:
            # The database itself stays usable, only the notification channel is missing.
            CORE.warning('The configured database is not Postgres. The automatic portal refresh will not work!\n%s', e)
        cls._enabled = True
class PostgresListener:
    """Forward Postgres notifications on the ``logout`` channel to the in-process logout notifiers."""

    # The only dialect/driver combination accepted for LISTEN/NOTIFY.
    _SUPPORTED_DIALECT = 'postgresql+psycopg2'

    def __init__(self, engine: Engine) -> None:
        """Remember *engine* and open the connection later used for ``LISTEN``."""
        self.engine = engine
        self.conn = engine.connect()

    @staticmethod
    def _require_postgres(engine: Engine) -> None:
        """Raise PostgresListenNotifyUnsupported unless *engine* uses the supported dialect."""
        dialect = engine.dialect.dialect_description
        if dialect != PostgresListener._SUPPORTED_DIALECT:
            raise PostgresListenNotifyUnsupported(f"Expected sqlalchemy dialect '{PostgresListener._SUPPORTED_DIALECT}' but got {dialect}")

    def verify_postgres(self, engine: Engine | None = None):
        """Check that *engine* (default: the listener's own engine) supports LISTEN/NOTIFY.

        :raises PostgresListenNotifyUnsupported: if the dialect is not
            ``postgresql+psycopg2``.
        """
        if engine is None:
            engine = self.engine
        # Called through the class rather than through self so the check does
        # not depend on what kind of object was bound as self.
        PostgresListener._require_postgres(engine)

    def listen(self):
        """Issue ``LISTEN logout`` and register the connection with the current IOLoop.

        :raises PostgresListenNotifyUnsupported: if the engine is not
            supported; the connection opened in ``__init__`` is closed then.
        """
        try:
            self.verify_postgres()
        except PostgresListenNotifyUnsupported:
            # The connection will never be used: close it instead of keeping it open.
            self.conn.close()
            self.conn = None
            raise
        CORE.debug("Executing 'LISTEN logout'")
        self.conn.execution_options(autocommit=True).execute(text("LISTEN logout"))
        ioloop.IOLoop.current().asyncio_loop.add_reader(self.conn.connection, self.handle_postgres_notify)

    def handle_postgres_notify(self):
        """Reader callback: poll the connection and set the notifier of every notified session."""
        if self.conn is None:
            return
        self.conn.connection.poll()
        while self.conn.connection.notifies:
            # pop() takes the newest notification first; the order is irrelevant
            # because each payload only sets its own notifier.
            notify = self.conn.connection.notifies.pop()
            payload = notify.payload
            notifier = logout_notifiers.get(payload)
            if notifier is not None:
                CORE.debug('Got a logout notifier for session %s', payload)
                notifier.set()

    @classmethod
    def notify(cls, conn: Connection | Session, session_id: str):
        """Send ``NOTIFY logout`` with *session_id* as payload.

        :param conn: a connection, or a session whose connection is used.
        :param session_id: the payload of the notification.
        :raises PostgresListenNotifyUnsupported: if the connection's engine
            is not supported.
        """
        connection = conn.connection() if isinstance(conn, Session) else conn
        cls._require_postgres(connection.engine)
        connection.execution_options(autocommit=True).execute(text("NOTIFY logout, :session_id;").bindparams(session_id=session_id))
@contextmanager
def get_session(auto_commit=True) -> Generator[Session, None, None]:
    """Provide a database session that is closed when the ``with`` block ends.

    :param auto_commit: commit the session after the ``with`` body finished
        without raising.
    :raises DBDisabledException: if the database is not enabled.
    """
    if not DBRegistry.enabled():
        raise DBDisabledException

    session = DBRegistry.get()
    try:
        yield session
        # Only reached when the body did not raise: a failed unit of work
        # must not be committed half-way.
        if auto_commit:
            session.commit()
    finally:
        # Runs even if the body or the commit raised, so the session is
        # always closed.
        session.close()
class DBSession(Base):
    """ORM model of one row in the ``sessions`` table."""

    __tablename__ = 'sessions'

    # Primary key: the identifier of the session.
    session_id = Column(String(256), primary_key=True)
    # End of the session as computed by calculate_session_end_time().
    expire_time = Column(BigInteger)
    # 'sid', 'sub' and 'iss' claims of the OIDC login; unset for sessions without OIDC.
    oidc_sid = Column(String(256))
    oidc_sub = Column(String(256))
    oidc_iss = Column(String(256))

    sessions = {}

    def __repr__(self):
        return f'<Session(session_id={self.session_id}, expire_time={self.expire_time}, oidc_sid={self.oidc_sid}, oidc_sub={self.oidc_sub}, oidc_iss={self.oidc_iss})>'

    @classmethod
    def get(cls, db_session, session_id):
        """Return the row with the given *session_id* or ``None``."""
        return db_session.query(cls).filter(cls.session_id == session_id).first()

    @classmethod
    def delete(cls, db_session: Session, session_id: str, send_postgres_logout_notify: bool = False):
        """Delete the row with the given *session_id*.

        :param send_postgres_logout_notify: additionally send a Postgres
            ``NOTIFY`` for the session before deleting it. This is best
            effort: an unsupported database does not prevent the deletion.
        """
        if send_postgres_logout_notify:
            try:
                CORE.debug("Deleting a session that is not ours. Sending postgres notify")
                PostgresListener.notify(db_session, session_id)
            except PostgresListenNotifyUnsupported:
                CORE.debug("Postgres NOTIFY is not supported by the configured database. Skipping the logout notification")
        db_session.query(cls).filter(cls.session_id == session_id).delete()

    @classmethod
    def update(cls, db_session, session_id, umc_session):
        """Store the current end time of *umc_session* in the row of *session_id*."""
        expire_time = cls.calculate_session_end_time(umc_session)
        db_session.query(cls).filter(cls.session_id == session_id).update({'expire_time': expire_time})

    @classmethod
    def create(cls, db_session, session_id, umc_session):
        """Add a row for *session_id* and commit it.

        The OIDC columns are only filled when ``umc_session.oidc`` is set.
        """
        oidc_params = {}
        if umc_session.oidc:
            claims = umc_session.oidc.claims
            oidc_params = {
                'oidc_sid': claims.get('sid'),
                'oidc_sub': claims.get('sub'),
                'oidc_iss': claims.get('iss'),
            }
        db_session.add(cls(session_id=session_id, expire_time=cls.calculate_session_end_time(umc_session), **oidc_params))
        db_session.commit()

    @classmethod
    def calculate_session_end_time(cls, umc_session):
        """Convert the monotonic-clock end time of *umc_session* into a ``time.time()`` based timestamp."""
        # NOTE(review): the result is a float while expire_time is a BigInteger
        # column - confirm that the implicit conversion by the database is intended.
        session_valid_in_seconds = umc_session.session_end_time - time.monotonic()
        return time.time() + session_valid_in_seconds

    @classmethod
    def get_by_oidc(cls, db_session, claims):
        """Yield the sessions that belong to the given OIDC *claims*.

        A session matching issuer and ``sid`` is preferred. If there is none,
        all sessions matching issuer and ``sub`` are yielded.
        """
        issuer = claims.get('iss')
        sid = claims.get('sid')
        sub = claims.get('sub')

        # A missing claim must not be compared against the column: "== None"
        # is rendered as "IS NULL" and would match unrelated sessions.
        if sid is not None:
            session_by_sid = db_session.query(cls).filter(cls.oidc_iss == issuer, cls.oidc_sid == sid).first()
            if session_by_sid:
                yield session_by_sid
                return
        if sub is not None:
            yield from db_session.query(cls).filter(cls.oidc_iss == issuer, cls.oidc_sub == sub).all()