initdb.py 6.91 KB
#!/usr/bin/env python3

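'''
Initializes or updates the students database: creates the schema if needed,
inserts new users (read from SIIUE CSV files or given in the command line)
and resets the passwords of existing users.
'''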

# python standard libraries
import csv
import argparse
import re
from string import capwords
from concurrent.futures import ThreadPoolExecutor

# third party libraries
import bcrypt
import sqlalchemy as sa

# this project
from aprendizations.models import Base, Student


# ===========================================================================
#   Parse command line options
def parse_commandline_arguments():
    '''Parse the command line arguments (taken from sys.argv).

    Returns an argparse.Namespace with the attributes:
      csvfile -- list of CSV files to import (empty list if none given)
      db      -- database file name
      admin   -- True if the admin user should be inserted
      add     -- list of [uid, name] pairs to add, or None if -a is not used
      update  -- list of uids whose password should be reset
      pw      -- password for new and updated users, or None
      verbose -- True to list every student in the database
    '''

    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Insert new users into a database. Users can be imported '
                    'from CSV files in the SIIUE format or defined in the '
                    'command line. If the database does not exist, a new one'
                    ' is created.')

    # nargs='*' produces a list, so the default is also a list: this way
    # args.csvfile always has the same type, with or without files given.
    argparser.add_argument('csvfile',
                           nargs='*',
                           type=str,
                           default=[],
                           help='CSV file to import (SIIUE)')

    argparser.add_argument('--db',
                           default='students.db',
                           type=str,
                           help='database file')

    argparser.add_argument('-A', '--admin',
                           action='store_true',
                           help='insert the admin user')

    # can be repeated: each occurrence appends one [uid, name] pair
    argparser.add_argument('-a', '--add',
                           nargs=2,
                           action='append',
                           metavar=('uid', 'name'),
                           help='add new user')

    argparser.add_argument('-u', '--update',
                           nargs='+',
                           metavar='uid',
                           default=[],
                           help='users to update')

    argparser.add_argument('--pw',
                           default=None,
                           type=str,
                           help='set password for new and updated users')

    argparser.add_argument('-V', '--verbose',
                           action='store_true',
                           help='show all students in database')

    return argparser.parse_args()


# ===========================================================================
def get_students_from_csv(filename):
    '''Reads CSV file with enrolled students in SIIUE format.

    The file is expected to be encoded in ISO-8859-1, use ';' as delimiter
    and have (at least) the columns "N.º" (student number) and "Nome" (name).
    SIIUE names can have parenthesised annotations like "(TE)" and are
    sometimes capitalized. Annotations are removed and names are normalized
    to capitalized words separated by single spaces.

    Returns a list of dictionaries {'uid': ..., 'name': ...}, one per
    student. If the file cannot be read or parsed, or the required columns
    are missing, an error message is printed and an empty list is returned.
    Rows that lack one of the required fields are skipped with a warning.
    '''

    # SIIUE format for CSV files
    csv_settings = {
        'delimiter': ';',
        'quotechar': '"',
        'skipinitialspace': True,
        }
    required_columns = ('N.º', 'Nome')

    students = []
    try:
        # newline='' leaves the handling of line endings to the csv module
        with open(filename, encoding='iso-8859-1', newline='') as file:
            csvreader = csv.DictReader(file, **csv_settings)

            header = csvreader.fieldnames or []
            missing = [col for col in required_columns if col not in header]
            if missing:
                print(f'!!! Missing column(s) {missing} in "{filename}" !!!')
                return []

            for row in csvreader:
                uid, name = row['N.º'], row['Nome']
                if uid is None or name is None:
                    # row shorter than the header
                    print(f'!!! Skipped incomplete row in "{filename}" !!!')
                    continue
                # each "(...)" group is removed on its own, so that text
                # between two annotations is preserved; capwords also
                # collapses the leftover repeated spaces.
                name = capwords(re.sub(r'\([^)]*\)', '', name).strip())
                students.append({'uid': uid, 'name': name})
    except OSError:
        print(f'!!! Error reading file "{filename}" !!!')
        return []
    except csv.Error:
        print(f'!!! Error parsing CSV from "{filename}" !!!')
        return []

    return students


# ===========================================================================
def hashpw(student, passw=None):
    '''Store the bcrypt hash of the password in student['pw'] (in place).

    The plain text password is the first non-empty value among: the `passw`
    argument, the current student['pw'] and student['uid'].
    A dot is printed for each call as a progress indicator.
    '''
    print('.', end='', flush=True)
    plaintext = passw or student.get('pw') or student['uid']
    salt = bcrypt.gensalt()
    student['pw'] = bcrypt.hashpw(plaintext.encode('utf-8'), salt)


# ===========================================================================
def show_students_in_database(session, verbose=False):
    '''Print the students registered in the database and their total.

    With verbose=True every student is listed, otherwise only the first two
    and the last one (ordered by id) are shown, with a separator standing
    for the omitted ones.
    '''

    def as_row(user):
        return f'{user.id:>12}   {user.name}'

    # ids are compared as right-aligned strings, so shorter ids come first
    users = sorted(session.query(Student).all(),
                   key=lambda u: f'{u.id:>12}')
    total = len(users)

    print('\nRegistered users:')
    if not users:
        print('  -- none --')
    elif verbose:
        for user in users:
            print(as_row(user))
    else:
        rows = [as_row(user) for user in users[:2]]
        if total > 3:
            rows.append('           |   |')  # some students are omitted
        if total > 2:
            rows.append(as_row(users[-1]))
        for row in rows:
            print(row)
    print(f'Total: {total}.')


# ===========================================================================
def main():
    '''Creates/updates the database according to the command line options.

    Steps:
      1. open the SQLite database given by --db, creating the schema if
         needed;
      2. collect the candidate users from the CSV files, --admin and --add;
      3. insert the ones that are not yet in the database, with hashed
         passwords (--pw if given, otherwise the user id);
      4. reset the password of the users given by --update;
      5. print the users registered in the database.
    '''

    args = parse_commandline_arguments()

    # --- database stuff
    print(f'Using database: {args.db}')
    engine = sa.create_engine(f'sqlite:///{args.db}', echo=False)
    Base.metadata.create_all(engine)  # Creates schema if needed
    session = sa.orm.sessionmaker(bind=engine)()

    try:
        # --- build list of students to insert/update
        students = []

        for csvfile in args.csvfile:
            students.extend(get_students_from_csv(csvfile))

        if args.admin:
            students.append({'uid': '0', 'name': 'Admin'})

        for uid, name in args.add or []:  # args.add is None without -a
            students.append({'uid': uid, 'name': name})

        # --- only insert students that are not yet in the database.
        # A uid that shows up more than once in the input (e.g. in two CSV
        # files) is inserted only once, keeping the first occurrence;
        # otherwise the whole batch would abort with an integrity error.
        db_students = {user.id for user in session.query(Student).all()}
        unique_students = {}
        for student in students:
            if student['uid'] not in db_students:
                unique_students.setdefault(student['uid'], student)
        new_students = list(unique_students.values())

        if new_students:
            # --- password hashing
            print('Generating password hashes', end='')
            with ThreadPoolExecutor() as executor:
                # the results are consumed so that an exception raised in
                # a worker is propagated instead of being silently dropped
                list(executor.map(lambda s: hashpw(s, args.pw),
                                  new_students))

            print('\nAdding students:')
            for student in new_students:
                print(f'  + {student["uid"]}, {student["name"]}')

            try:
                session.add_all([Student(id=s['uid'],
                                         name=s['name'],
                                         password=s['pw'])
                                 for s in new_students])
                session.commit()
            except sa.exc.IntegrityError:
                print('!!! Integrity error. Aborted !!!\n')
                session.rollback()
            else:
                # reported only when the commit actually succeeded
                print(f'Inserted {len(new_students)} new student(s).')
        else:
            print('There are no new students to add.')

        # --- update data for student in the database
        for student_id in args.update:
            print(f'Updating password of: {student_id}')
            student = session.query(Student).get(student_id)
            if student is not None:
                # without --pw the password is reset to the student id
                passw = (args.pw or student_id).encode('utf-8')
                student.password = bcrypt.hashpw(passw, bcrypt.gensalt())
                session.commit()
            else:
                print(f'!!! Student {student_id} does not exist. Skipped!!!')

        show_students_in_database(session, args.verbose)
    finally:
        # release the database connection even if something failed above
        session.close()


# ===========================================================================
# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()