initdb.py 6.82 KB
#!/usr/bin/env python3

"""
Command-line utility to initialize and update the student database
"""

# base
import csv
import argparse
import re
from string import capwords

# installed packages
from argon2 import PasswordHasher
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError

# this project
from .models import Base, Student


# ============================================================================
def parse_commandline_arguments(argv=None):
    """
    Parse command line options.

    argv: list of argument strings to parse. None (the default) makes
          argparse read sys.argv[1:], exactly as before.
    Returns the populated argparse.Namespace.
    """

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Insert new users into a database. Users can be imported "
        "from CSV files in the SIIUE format or defined in the "
        "command line. If the database does not exist, a new one "
        "is created.",
    )

    # nargs="*" always yields a list, so the default is an (empty) list too
    parser.add_argument(
        "csvfile", nargs="*", type=str, default=[], help="CSV file to import (SIIUE)"
    )

    parser.add_argument("--db", default="students.db", type=str, help="database file")

    parser.add_argument(
        "-A", "--admin", action="store_true", help='insert admin user 0 "Admin"'
    )

    # each occurrence appends one [uid, name] pair; None when never given
    parser.add_argument(
        "-a",
        "--add",
        nargs=2,
        action="append",
        metavar=("uid", "name"),
        help="add new user id and name",
    )

    parser.add_argument(
        "-u",
        "--update",
        nargs="+",
        metavar="uid",
        default=[],
        help="list of users whose password is to be updated",
    )

    parser.add_argument(
        "-U",
        "--update-all",
        action="store_true",
        help="all except admin will have the password updated",
    )

    # None means "no password given": new/updated users get an empty password
    parser.add_argument(
        "--pw", default=None, type=str, help="password for new or updated users"
    )

    parser.add_argument(
        "-V", "--verbose", action="store_true", help="show all students in database"
    )

    return parser.parse_args(argv)


# ============================================================================
def get_students_from_csv(filename: str):
    """
    Read the students listed in a SIIUE CSV export.

    The file is decoded as ISO-8859-1 and parsed as ';'-separated values
    with a header row; the "N.º" column gives the student number and the
    "Nome" column the name.

    SIIUE names have alien strings like "(TE)" and are sometimes fully
    capitalized. Every parenthesised tag is removed and the name is
    re-capitalized word by word, so that students don't keep asking what
    the tags mean.

    Returns a list of {"uid": ..., "name": ...} dicts. If the file cannot
    be read, is not valid CSV, or lacks one of the expected columns, an
    error message is printed and an empty list is returned.
    """
    csv_settings = {
        "delimiter": ";",
        "quotechar": '"',
        "skipinitialspace": True,
    }

    try:
        # newline="" lets the csv module handle line endings itself, as its
        # documentation requires (quoted fields may contain newlines).
        with open(filename, encoding="iso-8859-1", newline="") as file:
            csvreader = csv.DictReader(file, **csv_settings)
            students = [
                {
                    "uid": s["N.º"],
                    # "[^)]*" removes each "(...)" tag on its own; a greedy
                    # ".*" would also delete the text between two tags.
                    # capwords() also collapses the doubled spaces left behind.
                    "name": capwords(re.sub(r"\([^)]*\)", "", s["Nome"]).strip()),
                }
                for s in csvreader
            ]
    except OSError:
        print(f'!!! Error reading file "{filename}" !!!')
        students = []
    except csv.Error:
        print(f'!!! Error parsing CSV from "{filename}" !!!')
        students = []
    except KeyError as err:
        # header row does not contain one of the expected SIIUE columns
        print(f'!!! Missing column {err} in "{filename}" !!!')
        students = []

    return students


def insert_students_into_db(session, students) -> None:
    """
    Add all given students to the database and commit them together.

    Each element of *students* is a dict with the keys "uid", "name" and
    "pw", mapped to the Student columns id, name and password. If an
    IntegrityError is raised, a message is printed and the session is
    rolled back, discarding the whole batch.
    """
    try:
        new_rows = []
        for record in students:
            new_rows.append(
                Student(id=record["uid"], name=record["name"], password=record["pw"])
            )
        session.add_all(new_rows)
        session.commit()
    except IntegrityError:
        print("!!! Integrity error. Users already in database. Aborted !!!\n")
        session.rollback()


# ============================================================================
def show_students_in_database(session, verbose=False):
    """
    Print the users stored in the database, followed by their total count.

    Users are ordered by their id right-aligned in a 12-character field, so
    that numeric ids sort by number. With verbose=True every user is listed;
    otherwise only the first two and the last user are printed, with a "|"
    marker standing in for any users in between.
    """
    every_user = session.execute(select(Student)).scalars().all()
    how_many = len(every_user)

    def as_row(user):
        return f"{user.id:>12}   {user.name}"

    print("Registered users:")
    if how_many == 0:
        print("  -- none --")
    else:
        ordered = sorted(every_user, key=lambda user: f"{user.id:>12}")
        if verbose:
            for user in ordered:
                print(as_row(user))
        else:
            # first two users, an ellipsis if some were skipped, then the last
            for user in ordered[:2]:
                print(as_row(user))
            if how_many > 3:
                print("           |   |")
            if how_many > 2:
                print(as_row(ordered[-1]))
    print(f"Total: {how_many}.")


# ============================================================================
def main():
    """
    Insert, update and show students in the database.

    Driven by the command-line options:
      * opens the SQLite database given by --db, creating the schema if needed;
      * inserts the admin user (-A), the users read from CSV files and the
        users given with --add, all with the password from --pw (empty when
        --pw is not given);
      * resets/updates the password of every non-admin user (-U), or else of
        the users listed with -u (unknown uids are reported and skipped);
      * prints a summary of the registered users.
    """

    ph = PasswordHasher()
    args = parse_commandline_arguments()

    # --- database
    print(f"Database: {args.db}")
    engine = create_engine(f"sqlite:///{args.db}", echo=False, future=True)
    Base.metadata.create_all(engine)  # Creates schema if needed

    # The `with` block closes the session even when an exception escapes.
    # NOTE(review): `future=True` opts into 2.0-style behaviour on
    # SQLAlchemy 1.4; on 2.x it is the only behaviour — confirm the target
    # version before dropping it.
    with Session(engine, future=True) as session:
        # --- build list of new students to insert
        students = []

        if args.admin:
            print("Adding user: 0, Admin.")
            students.append({"uid": "0", "name": "Admin"})

        for csvfile in args.csvfile:
            print("Adding users from:", csvfile)
            students.extend(get_students_from_csv(csvfile))

        if args.add:
            for uid, name in args.add:
                print(f"Adding user: {uid}, {name}.")
                students.append({"uid": uid, "name": name})

        # --- insert new students
        if students:
            if args.pw is None:
                print("Set passwords to empty")
                for s in students:
                    s["pw"] = ""
            else:
                print("Generating password hashes")
                for s in students:
                    s["pw"] = ph.hash(args.pw)
                    print(".", end="", flush=True)
            print(f"\nInserting {len(students)}")
            insert_students_into_db(session, students)

        # --- update all students (the admin, uid "0", is left untouched)
        if args.update_all:
            query = select(Student).where(Student.id != "0")
            all_students = session.execute(query).scalars().all()
            if args.pw is None:
                print(f"Resetting password of {len(all_students)} users")
                for student in all_students:
                    student.password = ''
            else:
                print(f"Updating password of {len(all_students)} users")
                for student in all_students:
                    student.password = ph.hash(args.pw)
                    print(".", end="", flush=True)
                print()
            session.commit()

        # --- update only specified students
        else:
            for student_id in args.update:
                query = select(Student).where(Student.id == student_id)
                # an unknown uid must not abort the run and lose the
                # updates already made to the other users
                student = session.execute(query).scalar_one_or_none()
                if student is None:
                    print(f"!!! User {student_id} not found. Skipped !!!")
                    continue
                if args.pw is None:
                    print(f"Resetting password of user {student_id}")
                    student.password = ""
                else:
                    print(f"Updating password of user {student_id}")
                    student.password = ph.hash(args.pw)
            session.commit()

        show_students_in_database(session, args.verbose)


# ============================================================================
if __name__ == "__main__":
    # Entry point when executed directly.
    # NOTE(review): because of the relative import `from .models import ...`,
    # this presumably has to be launched as a package module
    # (python -m <package>.initdb) rather than as a plain script — confirm.
    main()