Source code for univention.management.console.session_db

#!/usr/bin/python3
#
# Univention Management Console
#  UMC server
#
# SPDX-FileCopyrightText: 2024-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

import os
import time
from collections.abc import Generator
from contextlib import contextmanager

from sqlalchemy import BigInteger, Column, String, create_engine, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from tornado import ioloop

from univention.management.console.config import (
    SQL_CONNECTION_ENV_VAR, SQL_MAX_OVERFLOW_ENV_VAR, SQL_POOL_RECYCLE_ENV_VAR, SQL_POOL_SIZE_ENV_VAR,
    SQL_POOL_TIMEOUT_ENV_VAR,
)
from univention.management.console.log import CORE
from univention.management.console.sse import logout_notifiers


class DBDisabledException(Exception):
    """Raised when a database session is requested but no database is configured."""
class PostgresListenNotifyUnsupported(Exception):
    """Raised when the configured database cannot be used for the LISTEN/NOTIFY logout mechanism.

    Only the ``postgresql+psycopg2`` SQLAlchemy dialect is accepted.
    """
# Declarative base class for this module's ORM models (DBSession derives from it);
# DBRegistry passes its metadata to create_all() to create the tables on first use.
Base = declarative_base()
class DBRegistry:
    """Lazily configured holder of the SQLAlchemy engine and session factory.

    The database is set up on the first call to :meth:`get` or :meth:`enabled`
    from environment variables (connection URI and pool settings). If no
    connection URI is configured, the registry stays disabled.
    """

    __engine = None
    __registry = None
    __init = False
    _enabled = False

    @classmethod
    def get(cls):
        """Return a new ORM session from the configured session factory.

        :raises DBDisabledException: if no session factory has been configured
            (previously this surfaced as "'NoneType' object is not callable").
        """
        if not cls.__init:
            cls.__create()
        if cls.__registry is None:
            raise DBDisabledException
        return cls.__registry()

    @classmethod
    def enabled(cls):
        """Return whether the database set-up completed successfully."""
        if not cls.__init:
            cls.__create()
        return cls._enabled

    @staticmethod
    def __int_env(name):
        """Return environment variable *name* as ``int``, or ``None`` if it is unset or empty."""
        value = os.environ.get(name)
        if not value:
            return None
        return int(value)

    @classmethod
    def __create(cls):
        """Create the engine, session factory, tables and logout listener.

        The registry is marked as initialised before any work is done, so a
        failing set-up is not retried by later calls.
        """
        cls.__init = True
        connection_uri = os.environ.get(SQL_CONNECTION_ENV_VAR, None)
        if connection_uri is None:
            return

        opts = {'pool_pre_ping': True}
        pool_settings = {
            'pool_size': SQL_POOL_SIZE_ENV_VAR,
            'max_overflow': SQL_MAX_OVERFLOW_ENV_VAR,
            'pool_timeout': SQL_POOL_TIMEOUT_ENV_VAR,
            'pool_recycle': SQL_POOL_RECYCLE_ENV_VAR,
        }
        for option, env_var in pool_settings.items():
            # int(None) used to crash here; unset values now fall back to SQLAlchemy's defaults.
            value = cls.__int_env(env_var)
            if value is not None:
                opts[option] = value

        url = make_url(connection_uri)
        params = ', '.join(f'{k}={v}' for k, v in opts.items())
        # Log only driver/host/database so credentials in the URI are not written to the log.
        CORE.process(
            "Connecting to database %s://%s/%s with parameters %s",
            url.drivername, url.host, url.database, params,
        )
        engine = create_engine(connection_uri, **opts)
        cls.__engine = engine
        cls.__registry = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        Base.metadata.create_all(cls.__engine)
        try:
            CORE.debug("Starting the PostgresListener")
            PostgresListener(engine).listen()
        except PostgresListenNotifyUnsupported as e:
            CORE.warning('The configured database is not Postgres. The automatic portal refresh will not work!\n%s', e)
        cls._enabled = True
class PostgresListener:
    """Receive PostgreSQL ``NOTIFY logout`` messages and trigger local logout notifiers.

    A dedicated connection is opened on construction. :meth:`listen` issues
    ``LISTEN logout`` on it and registers its socket with the current Tornado
    IOLoop's asyncio loop.
    """

    def __init__(self, engine: Engine) -> None:
        self.engine = engine
        self.conn = engine.connect()

    def verify_postgres(self, engine: Engine | None = None):
        """Ensure *engine* (default: ``self.engine``) uses the ``postgresql+psycopg2`` dialect.

        :raises PostgresListenNotifyUnsupported: for any other dialect.
        """
        if engine is None:
            engine = self.engine
        dialect = engine.dialect.dialect_description
        if dialect != 'postgresql+psycopg2':
            raise PostgresListenNotifyUnsupported(f"Expected sqlalchemy dialect 'postgresql+psycopg2' but got {dialect}")

    def listen(self):
        """Start listening on the ``logout`` channel.

        :raises PostgresListenNotifyUnsupported: if the engine is not psycopg2/PostgreSQL;
            the connection opened in ``__init__`` is closed in that case.
        """
        try:
            self.verify_postgres()
        except PostgresListenNotifyUnsupported:
            # Otherwise this connection stays checked out of the pool for the process lifetime.
            self.conn.close()
            raise
        CORE.debug("Executing 'LISTEN logout'")
        self.conn.execution_options(autocommit=True).execute(text("LISTEN logout"))
        ioloop.IOLoop.current().asyncio_loop.add_reader(self.conn.connection, self.handle_postgres_notify)

    def handle_postgres_notify(self):
        """Drain pending notifications and ``set()`` the notifier registered for each payload."""
        if self.conn is None:
            return
        self.conn.connection.poll()
        while self.conn.connection.notifies:
            notify = self.conn.connection.notifies.pop()
            payload = notify.payload
            notifier = logout_notifiers.get(payload)
            if notifier is not None:
                CORE.debug('Got a logout notifier for session %s', payload)
                notifier.set()

    @classmethod
    def notify(cls, conn: Connection | Session, session_id: str):
        """Send ``NOTIFY logout`` with *session_id* as payload over *conn*.

        :raises PostgresListenNotifyUnsupported: if the connection's engine is not psycopg2/PostgreSQL.
        """
        connection = conn.connection() if isinstance(conn, Session) else conn
        # verify_postgres does not read instance state when an engine is passed explicitly.
        cls.verify_postgres(None, connection.engine)
        connection.execution_options(autocommit=True).execute(text("NOTIFY logout, :session_id;").bindparams(session_id=session_id))
@contextmanager
def get_session(auto_commit=True) -> Generator[Session, None, None]:
    """Provide an ORM session for the duration of a ``with`` block.

    :param auto_commit: commit the session when the ``with`` block exits normally.
    :raises DBDisabledException: if no database is configured.

    If the ``with`` block raises, the session is rolled back instead of
    committed. The session is always closed, even if committing fails.
    """
    if not DBRegistry.enabled():
        raise DBDisabledException
    session = DBRegistry.get()
    try:
        yield session
    except BaseException:
        # Do not persist a partially applied unit of work.
        session.rollback()
        raise
    else:
        if auto_commit:
            session.commit()
    finally:
        session.close()
class DBSession(Base):
    """ORM model of the ``sessions`` table.

    Each row stores a UMC session id and its wall-clock expiry time. For
    sessions created with OIDC, it also stores the ``sid``/``sub``/``iss`` claims.
    """

    __tablename__ = 'sessions'

    session_id = Column(String(256), primary_key=True)
    # Seconds since the epoch, as computed by calculate_session_end_time().
    expire_time = Column(BigInteger)
    oidc_sid = Column(String(256))
    oidc_sub = Column(String(256))
    oidc_iss = Column(String(256))

    # NOTE(review): not used within this module — confirm whether other code relies on it.
    sessions = {}

    def __repr__(self):
        return f'<Session(session_id={self.session_id}, expire_time={self.expire_time}, oidc_sid={self.oidc_sid}, oidc_sub={self.oidc_sub}, oidc_iss={self.oidc_iss})>'

    @classmethod
    def get(cls, db_session, session_id):
        """Return the row for *session_id*, or ``None`` if there is none."""
        return db_session.query(cls).filter(cls.session_id == session_id).first()

    @classmethod
    def delete(cls, db_session: Session, session_id: str, send_postgres_logout_notify: bool = False):
        """Delete the row for *session_id*.

        :param send_postgres_logout_notify: also send a PostgreSQL ``NOTIFY logout``
            for the session first; silently skipped on non-PostgreSQL databases.
        """
        if send_postgres_logout_notify:
            try:
                CORE.debug("Deleting a session that is not ours. Sending postgres notify")
                PostgresListener.notify(db_session, session_id)
            except PostgresListenNotifyUnsupported:
                # Best effort: without LISTEN/NOTIFY support there is nobody to inform.
                pass
        db_session.query(cls).filter(cls.session_id == session_id).delete()

    @classmethod
    def update(cls, db_session, session_id, umc_session):
        """Refresh the stored expiry time of *session_id* from *umc_session*."""
        expire_time = cls.calculate_session_end_time(umc_session)
        db_session.query(cls).filter(cls.session_id == session_id).update({'expire_time': expire_time})

    @classmethod
    def create(cls, db_session, session_id, umc_session):
        """Insert a row for *session_id*, including OIDC claims if present, and commit."""
        oidc_params = {}
        if umc_session.oidc:
            oidc_params['oidc_sid'] = umc_session.oidc.claims.get('sid')
            oidc_params['oidc_sub'] = umc_session.oidc.claims.get('sub')
            oidc_params['oidc_iss'] = umc_session.oidc.claims.get('iss')
        db_session.add(cls(session_id=session_id, expire_time=cls.calculate_session_end_time(umc_session), **oidc_params))
        db_session.commit()

    @classmethod
    def calculate_session_end_time(cls, umc_session):
        """Convert the session's remaining lifetime into a wall-clock ``time.time()`` timestamp.

        Assumes ``umc_session.session_end_time`` is on the ``time.monotonic()`` clock,
        as implied by the subtraction below.
        """
        session_valid_in_seconds = umc_session.session_end_time - time.monotonic()
        real_session_end_time = time.time() + session_valid_in_seconds
        return real_session_end_time

    @classmethod
    def get_by_oidc(cls, db_session, claims):
        """Yield stored sessions matching the OIDC *claims*.

        A session with the same issuer and ``sid`` takes precedence and is the only
        one yielded. Otherwise, all sessions with the same issuer and ``sub`` are yielded.
        Missing claims are never used as filters, because ``column == None``
        becomes ``IS NULL`` and would match unrelated (e.g. non-OIDC) sessions.
        """
        iss = claims.get('iss')
        if iss is None:
            return
        sid = claims.get('sid')
        if sid is not None:
            oidc_session_by_sid = db_session.query(cls).filter(cls.oidc_iss == iss, cls.oidc_sid == sid).first()
            if oidc_session_by_sid:
                yield oidc_session_by_sid
                return
        sub = claims.get('sub')
        if sub is not None:
            yield from db_session.query(cls).filter(cls.oidc_iss == iss, cls.oidc_sub == sub).all()