diff --git a/aprendizations/learnapp.py b/aprendizations/learnapp.py index dddc325..2d0fbf3 100644 --- a/aprendizations/learnapp.py +++ b/aprendizations/learnapp.py @@ -6,7 +6,7 @@ This is the main controller of the application. # python standard library import asyncio from collections import defaultdict -from contextlib import contextmanager # `with` statement in db sessions +# from contextlib import contextmanager # `with` statement in db sessions from datetime import datetime import logging from random import random @@ -16,8 +16,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, DefaultDict # third party libraries import bcrypt import networkx as nx -import sqlalchemy as sa -import sqlalchemy.orm as orm +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session # this project from aprendizations.models import Student, Answer, Topic, StudentTopic @@ -57,22 +57,22 @@ class LearnApp(): # ------------------------------------------------------------------------ - @contextmanager - def _db_session(self, **kw): - ''' - helper to manage db sessions using the `with` statement, for example - with self._db_session() as s: s.query(...) - ''' - session = self.Session(**kw) - try: - yield session - session.commit() - except Exception: - logger.error('!!! Database rollback !!!') - session.rollback() - raise - finally: - session.close() + # @contextmanager + # def _db_session(self, **kw): + # ''' + # helper to manage db sessions using the `with` statement, for example + # with self._db_session() as s: s.query(...) + # ''' + # session = self.Session(**kw) + # try: + # yield session + # session.commit() + # except Exception: + # logger.error('!!! Database rollback !!!') + # session.rollback() + # raise + # finally: + # session.close() # ------------------------------------------------------------------------ def __init__(self, @@ -369,30 +369,31 @@ class LearnApp(): ''' Fill db table 'Topic' with topics from the graph, if new ''' - with self._db_session() as sess: - new = [Topic(id=t) for t in topics - if (t,) not in sess.query(Topic.id)] - + with Session(self._engine) as session: + db_topics = session.execute(select(Topic.id)).scalars().all() + new = [Topic(id=t) for t in topics if t not in db_topics] if new: - sess.add_all(new) + session.add_all(new) + session.commit() logger.info('Added %d new topic(s) to the database', len(new)) # ------------------------------------------------------------------------ def _db_setup(self, database: str) -> None: - '''setup and check database contents''' + ''' + Setup and check database contents + ''' logger.info('Checking database "%s":', database) if not exists(database): raise LearnException('Database does not exist. ' 'Use "initdb-aprendizations" to create') - engine = sa.create_engine(f'sqlite:///{database}', echo=False) - self.Session = orm.sessionmaker(bind=engine) + self._engine = create_engine(f'sqlite:///{database}', future=True) try: - with self._db_session() as sess: - count_students: int = sess.query(Student).count() - count_topics: int = sess.query(Topic).count() - count_answers: int = sess.query(Answer).count() + with Session(self._engine) as session: + count_students: int = session.query(Student).count() + count_topics: int = session.query(Topic).count() + count_answers: int = session.query(Answer).count() except Exception as exc: logger.error('Database "%s" not usable!', database) raise DatabaseUnusableError() from exc diff --git a/aprendizations/models.py b/aprendizations/models.py index 9beef48..2a98717 100644 --- a/aprendizations/models.py +++ b/aprendizations/models.py @@ -1,7 +1,4 @@ -# python standard library -# from typing import Any - # third party libraries from sqlalchemy import Column, ForeignKey, Integer, Float, String from sqlalchemy.orm import declarative_base, relationship @@ -10,6 +7,7 @@ from sqlalchemy.orm import declarative_base, relationship # =========================================================================== # Declare ORM # FIXME Any is a workaround for mypy static type checking (see https://github.com/python/mypy/issues/6372) +# from typing import Any # Base: Any = declarative_base() Base = declarative_base() diff --git a/mypy.ini b/mypy.ini index b4b39f8..79c689e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,6 @@ python_version = 3.9 plugins = sqlalchemy.ext.mypy.plugin - [mypy-pygments.*] ignore_missing_imports = True -- libgit2 0.21.2