initdb.py 6.44 KB
#!/usr/bin/env python3

'''
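Initializes or updates the students database: creates the schema if needed,
inserts new users and resets the passwords of existing ones.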
'''

# python standard libraries
import csv
import argparse
import re
from string import capwords

# third party libraries
import bcrypt
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError, NoResultFound

# this project
from aprendizations.models import Base, Student


# ===========================================================================
#   Parse command line options
def parse_commandline_arguments():
    '''Parse command line arguments.

    Returns the argparse namespace. ``csvfile`` and ``update`` are always
    lists (possibly empty); ``add`` is None unless ``-a`` was given, in which
    case it is a list of ``[uid, name]`` pairs.'''

    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Insert new users into a database. Users can be imported '
                    'from CSV files in the SIIUE format or defined in the '
                    'command line. If the database does not exist, a new one'
                    ' is created.')

    # nargs='*' yields a list of filenames, so the default must be a list
    # too (a '' default made the attribute's type depend on the input).
    argparser.add_argument('csvfile',
                           nargs='*',
                           type=str,
                           default=[],
                           help='CSV file to import (SIIUE)')

    argparser.add_argument('--db',
                           default='students.db',
                           type=str,
                           help='database file')

    argparser.add_argument('-A', '--admin',
                           action='store_true',
                           help='insert the admin user')

    argparser.add_argument('-a', '--add',
                           nargs=2,
                           action='append',
                           metavar=('uid', 'name'),
                           help='add new user')

    argparser.add_argument('-u', '--update',
                           nargs='+',
                           metavar='uid',
                           default=[],
                           help='users to update')

    # when omitted, callers fall back to using each user's uid as password
    argparser.add_argument('--pw',
                           default=None,
                           type=str,
                           help='set password for new and updated users')

    argparser.add_argument('-V', '--verbose',
                           action='store_true',
                           help='show all students in database')

    return argparser.parse_args()


# ===========================================================================
def get_students_from_csv(filename):
    '''Reads a CSV file with enrolled students in the SIIUE format.

    The file is read as ISO-8859-1, ``;``-delimited, and must contain the
    columns ``N.º`` (student number) and ``Nome`` (name). SIIUE names can
    have parenthesized suffixes like "(TE)" and are sometimes in upper case:
    every parenthesized group is removed and each word is re-capitalized
    with ``string.capwords``.

    Returns a list of ``{'uid': ..., 'name': ...}`` dicts. If the file cannot
    be read, is not valid CSV, or lacks a required column, an error message
    is printed and an empty list is returned.'''

    # SIIUE format for CSV files
    csv_settings = {
        'delimiter': ';',
        'quotechar': '"',
        'skipinitialspace': True,
        }

    students = []
    try:
        # newline='' lets the csv module handle line endings itself, as the
        # csv documentation requires for file objects passed to a reader.
        with open(filename, encoding='iso-8859-1', newline='') as file:
            csvreader = csv.DictReader(file, **csv_settings)
            students = [{
                'uid': s['N.º'],
                # [^)]* instead of greedy .* so that "A (x) B (y)" keeps B
                'name': capwords(re.sub(r'\([^)]*\)', '', s['Nome']).strip())
                } for s in csvreader]
    except OSError:
        print(f'!!! Error reading file "{filename}" !!!')
    except csv.Error:
        print(f'!!! Error parsing CSV from "{filename}" !!!')
    except KeyError as err:
        # a required column is missing: report it instead of crashing
        print(f'!!! Missing column {err} in "{filename}" !!!')

    return students


# ===========================================================================
def show_students_in_database(session, verbose=False) -> None:
    '''Prints the users registered in the database.

    Users are ordered by their id formatted right-aligned in 12 columns. With
    ``verbose`` every user is printed; otherwise only the first two and the
    last one are shown, with a "|" marker line when users are omitted in
    between. The total number of users is always printed last.'''

    def as_line(entry):
        # one row: id right-aligned in 12 columns, then the name
        return f'{entry.id:>12}   {entry.name}'

    registered = sorted(session.execute(select(Student)).scalars().all(),
                        key=lambda entry: f'{entry.id:>12}')
    count = len(registered)

    print('\nRegistered users:')
    if verbose:
        for entry in registered:
            print(as_line(entry))
    elif count:
        for entry in registered[:2]:  # first one, plus the second if any
            print(as_line(entry))
        if count > 3:
            print('           |   |')
        if count > 2:
            print(as_line(registered[-1]))
    print(f'Total: {count}.')


# ===========================================================================
def _collect_students(args):
    '''Builds the candidate students from CSV files, -A and -a options.

    Returns a list of {'uid', 'name'} dicts in input order, de-duplicated by
    uid (the first occurrence of each uid wins).'''
    students = []

    for csvfile in args.csvfile:
        students += get_students_from_csv(csvfile)

    if args.admin:
        students.append({'uid': '0', 'name': 'Admin'})

    if args.add is not None:
        for uid, name in args.add:
            students.append({'uid': uid, 'name': name})

    # The same uid may come from several CSV files or repeated options.
    # Adding it twice would make the insert commit fail with IntegrityError
    # (presumably a unique/primary-key violation on Student.id -- confirm in
    # models) and roll back the whole batch, so keep only the first one.
    unique = {}
    for student in students:
        unique.setdefault(student['uid'], student)
    return list(unique.values())


def _hash_password(password):
    '''Returns the bcrypt hash (bytes) of the str ``password``, freshly salted.'''
    return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())


def _insert_new_students(session, students, password=None):
    '''Adds the students whose uid is not yet in the database and commits.

    Each new student's password is ``password`` or, when that is falsy, the
    student's own uid. On IntegrityError the insertion is rolled back.'''
    print('\nInserting new students:')
    # build the lookup set once, not once per candidate student
    existing = set(session.execute(select(Student.id)).scalars().all())
    new_students = [s for s in students if s['uid'] not in existing]
    for student in new_students:
        print(f'  {student["uid"]}, {student["name"]}')
        session.add(Student(id=student['uid'],
                            name=student['name'],
                            password=_hash_password(password or student['uid'])))

    try:
        session.commit()
    except IntegrityError:
        print('!!! Integrity error. Aborted !!!\n')
        session.rollback()
    else:
        print(f'Total {len(new_students)} new student(s).')


def _update_passwords(session, uids, password=None):
    '''Resets the passwords of the given existing students and commits.

    The new password is ``password`` or, when that is falsy, the uid itself.
    Unknown uids are reported and skipped; repeated uids are handled once.'''
    print('\nUpdating passwords of students:')
    count = 0
    for sid in dict.fromkeys(uids):  # de-duplicate, keeping the given order
        try:
            query = select(Student).filter_by(id=sid)
            student = session.execute(query).scalar_one()
        except NoResultFound:
            print(f'  -> student {sid} does not exist!')
            continue
        count += 1
        print(f'  {sid}, {student.name}')
        student.password = _hash_password(password or sid)

    session.commit()
    print(f'Total {count} password(s) updated.')


def main():
    '''Performs the main functions.

    Parses the command line, creates the schema if needed, inserts new
    students, optionally resets passwords (-u), and prints the registered
    users.'''

    args = parse_commandline_arguments()

    # --- database stuff
    print(f'Database: {args.db}')
    engine = create_engine(f'sqlite:///{args.db}', echo=False, future=True)
    Base.metadata.create_all(engine)  # Creates schema if needed

    # the context manager closes the session even if an exception escapes
    with Session(engine, future=True) as session:
        _insert_new_students(session, _collect_students(args), args.pw)

        if args.update:
            _update_passwords(session, args.update, args.pw)

        show_students_in_database(session, args.verbose)

# ===========================================================================
if __name__ == '__main__':
    # run only when executed as a script, not when imported as a module
    main()