# learnapp.py 14.5 KB

# python standard library
from os import path, sys
import logging
from contextlib import contextmanager  # `with` statement in db sessions
import asyncio
from datetime import datetime

# user installed libraries
import bcrypt
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import networkx as nx

# this project
from models import Student, Answer, Topic, StudentTopic
from knowledge import StudentKnowledge
from factory import QFactory
from tools import load_yaml

# setup logger for this module
logger = logging.getLogger(__name__)

# ============================================================================
class LearnAppException(Exception):
    """Application-specific exception type for this module."""


# ============================================================================
# helper functions
# ============================================================================
async def _bcrypt_hash(a, b):
    """Hash the string `a` with bcrypt, using `b` as salt.

    The hashing is delegated to the running loop's default executor so the
    event loop is not blocked while bcrypt works.
    """
    running_loop = asyncio.get_running_loop()
    encoded = a.encode('utf-8')
    hashing = running_loop.run_in_executor(None, bcrypt.hashpw, encoded, b)
    return await hashing

async def check_password(try_pw, pw):
    """Check a plain-text password attempt against a stored bcrypt hash.

    try_pw: password attempt (str).
    pw:     stored bcrypt hash; it is also passed as the salt so that the
            attempt is hashed with the same parameters.

    Returns True if the attempt hashes to the stored value, else False.
    """
    import hmac  # local import: only this function needs it

    candidate = await _bcrypt_hash(try_pw, pw)
    # constant-time comparison, so timing does not reveal how many leading
    # bytes of the hash matched
    return hmac.compare_digest(candidate, pw)

async def bcrypt_hash_gen(new_pw):
    """Return the bcrypt hash of `new_pw` computed with a freshly generated salt."""
    salt = bcrypt.gensalt()
    hashed = await _bcrypt_hash(new_pw, salt)
    return hashed


# ============================================================================
# LearnApp - application logic
# ============================================================================
class LearnApp(object):
    # ------------------------------------------------------------------------
    # helper to manage db sessions using the `with` statement, for example
    #   with self.db_session() as s:  s.query(...)
    # ------------------------------------------------------------------------
    @contextmanager
    def db_session(self, **kw):
        session = self.Session(**kw)
        try:
            yield session
            session.commit()
        except:
            session.rollback()
            logger.error('DB rollback!!!')
        finally:
            session.close()

    # ------------------------------------------------------------------------
    def __init__(self, config_files):
        """Set up the database, the registry of online students and the topic graph.

        config_files: iterable of course configuration files; each one is
                      handed to populate_graph().
        """
        # database first: db_setup() creates the session factory
        self.db_setup()
        # students currently logged in, keyed by uid
        self.online = {}

        # one dependency graph collects the topics of all configuration files
        self.deps = nx.DiGraph()
        for conffile in config_files:
            self.populate_graph(conffile)

        # every topic in the graph gets a row in the database
        self.db_add_missing_topics(self.deps.nodes())

    # ------------------------------------------------------------------------
    # login
    # ------------------------------------------------------------------------
    async def login(self, uid, try_pw):
        """Authenticate a student and, on success, register it as online.

        uid:    student identifier; a leading 'l' is removed before lookup.
        try_pw: plain-text password attempt.

        Returns True if the password matches. Returns False if the user
        cannot be fetched from the database or the password is wrong.
        On success the levels and dates of the student's topics are loaded
        from the database and stored in ``self.online[uid]``, replacing any
        state already present for that user.
        """
        if uid.startswith('l'):  # remove prefix 'l'
            uid = uid[1:]

        with self.db_session() as s:
            try:
                name, password = s.query(Student.name, Student.password).filter_by(id=uid).one()
            except Exception:
                logger.info(f'User "{uid}" does not exist!')
                return False

        pw_ok = await check_password(try_pw, password)  # async bcrypt
        if not pw_ok:
            logger.info(f'User "{uid}" wrong password!')
            return pw_ok

        if uid in self.online:
            logger.warning(f'User "{uid}" already logged in, overwriting state')
        else:
            logger.info(f'User "{uid}" logged in successfully')

        # Run the query and copy the plain values while the session is open.
        # Iterating the query after the `with` block would execute it on a
        # session that has already been closed and is never closed again.
        with self.db_session() as s:
            rows = [(t.topic_id, t.level, t.date)
                    for t in s.query(StudentTopic).filter_by(student_id=uid)]

        state = {
            topic_id: {
                'level': level,
                'date': datetime.strptime(date, "%Y-%m-%d %H:%M:%S.%f"),
                }
            for topic_id, level, date in rows}

        self.online[uid] = {
            'number': uid,
            'name': name,
            'state': StudentKnowledge(self.deps, state=state),
            }

        return pw_ok

    # ------------------------------------------------------------------------
    # logout
    # ------------------------------------------------------------------------
    def logout(self, uid):
        del self.online[uid]
        logger.info(f'User "{uid}" logged out')

    # ------------------------------------------------------------------------
    # change_password. returns True if password is successfully changed.
    # ------------------------------------------------------------------------
    async def change_password(self, uid, pw):
        if not pw:
            return False

        pw = await bcrypt_hash_gen(pw)

        with self.db_session() as s:
            u = s.query(Student).get(uid)
            u.password = pw

        logger.info(f'User "{uid}" changed password')
        return True

    # ------------------------------------------------------------------------
    # checks answer (updating student state) and returns grade.
    # ------------------------------------------------------------------------
    async def check_answer(self, uid, answer):
        """Correct an answer, record it in the database and say what comes next.

        uid:    identifier of an online student.
        answer: the answer, passed on to the student's knowledge state.

        Returns 'finished_topic' when no question is left in the topic,
        'wrong' when the answer was not accepted and tries remain, and
        'new_question' otherwise.
        """
        state = self.online[uid]['state']
        q = await state.check_answer(answer)  # correcting also moves to next question
        logger.info(f'User "{uid}" got {q["grade"]:.2} in question "{q["ref"]}"')
        topic = state.get_current_topic()

        # the grade of every answered question is recorded
        with self.db_session() as s:
            answer_row = Answer(
                ref=q['ref'],
                grade=q['grade'],
                starttime=str(q['start_time']),
                finishtime=str(q['finish_time']),
                student_id=uid,
                topic_id=topic)
            s.add(answer_row)
            logger.debug(f'Saved "{q["ref"]}" into database')

        if state.get_current_question() is not None:
            # topic still in progress
            try_again = q['tries'] > 0 and q['grade'] <= 0.999
            return 'wrong' if try_again else 'new_question'

        # no question left: the topic is finished, record level and date
        logger.info(f'User "{uid}" finished "{topic}"')
        new_level = state.get_topic_level(topic)
        new_date = str(state.get_topic_date(topic))

        with self.db_session() as s:
            record = s.query(StudentTopic).filter_by(student_id=uid, topic_id=topic).one_or_none()
            if record is not None:
                # the student already has this topic: update it
                logger.debug('Database update studenttopic')
                record.level = new_level
                record.date = new_date
            else:
                # first time: create the association object
                logger.debug('Database insert new studenttopic')
                topic_row = s.query(Topic).get(topic)
                student_row = s.query(Student).get(uid)
                record = StudentTopic(level=new_level, date=new_date, topic=topic_row, student=student_row)
                student_row.topics.append(record)

            s.add(record)

        logger.debug(f'Saved topic "{topic}" into database')
        return 'finished_topic'

    # ------------------------------------------------------------------------
    # Start new topic
    # ------------------------------------------------------------------------
    async def start_topic(self, uid, topic):
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(None, self.online[uid]['state'].init_topic, topic)
        except KeyError as e:
            logger.warning(f'User "{uid}" tried to open nonexistent topic: "{topic}"')
            raise e
        else:
            logger.info(f'User "{uid}" started "{topic}"')

    # ------------------------------------------------------------------------
    # Fill db table 'Topic' with topics from the graph if not already there.
    # ------------------------------------------------------------------------
    def db_add_missing_topics(self, topics):
        """Insert into table 'Topic' every topic that is not there yet.

        topics: iterable of topic ids (the nodes of the graph).
        """
        with self.db_session() as s:
            # a set gives constant-time membership tests in the filter below
            known = {row[0] for row in s.query(Topic.id)}
            missing_topics = [Topic(id=t) for t in topics if t not in known]
            if missing_topics:
                s.add_all(missing_topics)
                logger.info(f'Added {len(missing_topics)} new topics to the database')

    # ------------------------------------------------------------------------
    # setup and check database
    # ------------------------------------------------------------------------
    def db_setup(self, db='students.db'):
        """Create the session factory for the sqlite database and check it.

        db: path of the sqlite database file.

        Sets ``self.Session`` and, as a sanity check, counts the students,
        topics and answers in the database and logs the totals.

        Raises the underlying exception, or LearnAppException when the
        failure was suppressed by the session context manager, if the
        database cannot be queried.
        """
        logger.info(f'Checking database "{db}":')
        engine = create_engine(f'sqlite:///{db}', echo=False)
        self.Session = sessionmaker(bind=engine)

        counts = None
        try:
            with self.db_session() as s:
                counts = (
                    s.query(Student).count(),
                    s.query(Topic).count(),
                    s.query(Answer).count(),
                    )
        except Exception:
            logger.critical(f'Database "{db}" not usable!')
            raise

        if counts is None:
            # The session context manager may roll back and suppress the
            # error; the counts are then never assigned and the failure
            # must be reported here.
            logger.critical(f'Database "{db}" not usable!')
            raise LearnAppException(f'Database "{db}" not usable!')

        n, m, q = counts
        logger.info(f'{n:6} students')
        logger.info(f'{m:6} topics')
        logger.info(f'{q:6} answers')


    # ========================================================================
    # methods that do not change state (pure functions)
    # ========================================================================


    # ------------------------------------------------------------------------
    def get_student_name(self, uid):
        return self.online[uid].get('name', '')

    # ------------------------------------------------------------------------
    def get_student_state(self, uid):
        return self.online[uid]['state'].get_knowledge_state()

    # ------------------------------------------------------------------------
    def get_student_progress(self, uid):
        return self.online[uid]['state'].get_topic_progress()

    # ------------------------------------------------------------------------
    def get_current_question(self, uid):
        return self.online[uid]['state'].get_current_question()     # dict

    # ------------------------------------------------------------------------
    def get_student_question_type(self, uid):
        return self.online[uid]['state'].get_current_question()['type']

    # ------------------------------------------------------------------------
    def get_student_topic(self, uid):
        return self.online[uid]['state'].get_current_topic()        # str

    # ------------------------------------------------------------------------
    def get_title(self):
        return self.deps.graph.get('title', '')  # FIXME

    # ------------------------------------------------------------------------
    def get_topic_name(self, ref):
        return self.deps.node[ref]['name']

    # # ------------------------------------------------------------------------
    # def get_topic_type(self, ref):
    #     return self.deps.node[ref]['type']

    # ------------------------------------------------------------------------
    def get_current_public_dir(self, uid):
        topic = self.online[uid]['state'].get_current_topic()
        p = self.deps.graph['path']
        return path.join(p, topic, 'public')



    # ============================================================================
    # Populates a digraph.
    #
    # First, topics such as `computer/mips/exceptions` are added as nodes
    # together with dependencies. Then, questions are loaded to a factory.
    #
    #   g.graph['path']         base path where topic directories are located
    #   g.graph['title']        title defined in the configuration YAML
    #   g.graph['database']     sqlite3 database file to use
    #
    # Nodes are the topic references e.g. 'my/topic'
    #   g.node['my/topic']['name']      name of the topic
    #   g.node['my/topic']['questions'] list of question refs defined in YAML
    #   g.node['my/topic']['factory']   dict with question factories
    #
    # Edges are obtained from the deps defined in the YAML file for each topic.
    # ----------------------------------------------------------------------------
    def populate_graph(self, conffile):
        """Load a course configuration file and add its topics to the graph.

        conffile: YAML course configuration, read with load_yaml().

        Each topic becomes a node whose attributes are 'type', 'name',
        'path', 'file', 'shuffle', 'forgetting_factor', 'questions',
        'choose' and 'factory'. An edge (dep, topic) is added for each
        dependency of a topic. Topics that were already defined are
        reported and skipped.
        """
        logger.info(f'Loading {conffile} and populating graph:')
        g = self.deps                   # the graph
        config = load_yaml(conffile)    # course configuration

        # set attributes of the graph; when several configuration files are
        # loaded, the first one that defines an attribute wins
        prefix = path.expanduser(config.get('path', '.'))
        g.graph.setdefault('path', prefix)
        if config.get('title'):
            g.graph.setdefault('title', config['title'])

        # default attributes that apply to the topics
        default_file = config.get('file', 'questions.yaml')
        default_shuffle = config.get('shuffle', True)
        default_choose = config.get('choose', 9999)
        default_forgetting_factor = config.get('forgetting_factor', 1.0)

        # iterate over topics and populate graph
        topics = config.get('topics') or {}
        tcount = qcount = 0   # topic and question counters
        for tref, attr in topics.items():
            attr = attr or {}  # a topic may be declared with an empty body

            # A node without attributes was created implicitly, as the
            # dependency of a topic seen earlier; it has not been defined
            # yet. Only nodes that carry attributes are real duplicates.
            if tref in g and g.nodes[tref]:
                logger.error(f'--> Topic {tref} already exists. Skipped.')
                continue

            # add topic to the graph
            g.add_node(tref)
            t = g.nodes[tref]  # current topic node

            topicpath = path.join(prefix, tref)

            t['type'] = attr.get('type', 'topic')
            t['name'] = attr.get('name', tref)
            t['path'] = topicpath
            t['file'] = attr.get('file', default_file)
            t['shuffle'] = attr.get('shuffle', default_shuffle)
            t['forgetting_factor'] = attr.get('forgetting_factor', default_forgetting_factor)
            g.add_edges_from((d, tref) for d in attr.get('deps', []))

            # load questions as list of dicts
            questions = load_yaml(path.join(topicpath, t['file']), default=[])

            # refs undefined in questions.yaml are set to topic:n
            all_refs = [q.setdefault('ref', f'{tref}:{i}') for i, q in enumerate(questions)]

            # if questions are left undefined, include all.
            t['questions'] = attr.get('questions', all_refs)

            # topic will generate a certain amount of questions
            t['choose'] = min(attr.get('choose', default_choose), len(t['questions']))

            # make questions factory (without repeating same question) FIXME move to somewhere else?
            t['factory'] = {}
            for q in questions:
                if q['ref'] in t['questions']:
                    q['path'] = topicpath  # fullpath added to each question
                    t['factory'][q['ref']] = QFactory(q)

            logger.info(f'{len(t["questions"]):6} {tref}')
            qcount += len(t["questions"])  # count total questions
            tcount += 1

        logger.info(f'Total loaded: {tcount} topics, {qcount} questions')