# event_lattice.py (5.24 KB)
import math
from functools import cache
from itertools import accumulate
import operator


    
def uniform_op(x):
    """Return the uniform weight 1/n for the n elements of *x*.

    *x* may be any iterable (it is consumed). An empty input yields 1.0.
    """
    count = sum(1 for _ in x)
    if count == 0:
        return 1.0
    return 1.0 / count


def max_op(x):
    """Combine values by keeping the largest one.

    Raises ValueError when *x* is empty (behavior of the built-in ``max``).
    """
    largest = max(x)
    return largest


def min_op(x):
    """Combine values by keeping the smallest one.

    Raises ValueError when *x* is empty (behavior of the built-in ``min``).
    """
    smallest = min(x)
    return smallest


def sum_op(x):
    """Combine values by adding them; an empty input yields 0."""
    total = sum(x)
    return total


def stableprod_op(x):
    """Return the product of the values in *x*, computed in log space.

    Summing logarithms (with ``math.fsum``) instead of multiplying directly
    avoids intermediate underflow/overflow when many very small or very large
    factors are combined.

    ``math.log`` is undefined for 0 and negative numbers, so those are
    handled explicitly:

    - any zero factor makes the product 0.0;
    - negative factors contribute their magnitude, with the sign tracked
      separately.

    An empty input yields 1.0 (the empty product).
    """
    log_terms = []
    negative = False
    for value in x:
        if value == 0:
            # The product is zero regardless of the remaining factors.
            return 0.0
        if value < 0:
            negative = not negative
        log_terms.append(math.log(abs(value)))
    magnitude = math.exp(math.fsum(log_terms))
    return -magnitude if negative else magnitude


def prod_op(x):
    """Return the product of the values in *x*.

    Uses ``math.prod``, so an empty input yields 1 (the empty product)
    instead of raising IndexError. No intermediate list of partial products
    is built.
    """
    return math.prod(x)


class Event:
    """An event: an immutable set of single-character literals.

    Negation of a literal is encoded by letter case (``a`` vs ``A``); see
    ``co``. Events are hashable and compare by their literal sets. The
    ordering operators mean subset/superset.
    """

    @staticmethod
    def parse(text):
        """Return the frozenset of characters of *text* (one literal per char)."""
        return frozenset(text)

    @staticmethod
    def from_str(text):
        """Build an Event whose literals are the characters of *text*."""
        return Event(Event.parse(text))

    def __init__(self, literals):
        """Instantiate from an iterable of literals.

        For example: ``e = Event(frozenset("abc"))``. The literals are always
        stored as a frozenset, so the event stays hashable and re-iterable
        even when a set, list or generator is passed in.
        """
        self._literals = frozenset(literals)

    def literals(self):
        """Return the frozenset of literals of this event."""
        return self._literals

    def __iter__(self):
        return iter(self._literals)

    def co(self):
        """Return the event with every literal negated.

        Negation is case based: A = not a; a = not A.
        """
        return Event(frozenset(x.swapcase() for x in self._literals))

    def invert(self):
        """Alias of ``co``."""
        return self.co()

    def is_consistent(self):
        """Return True if no literal occurs together with its negation."""
        # Deliberately not memoized with functools.cache: caching an instance
        # method keeps every Event alive for the cache's lifetime, and this
        # check is only a linear scan.
        return all(x.swapcase() not in self._literals for x in self._literals)

    def __hash__(self) -> int:
        return hash(self._literals)

    def __eq__(self, other):
        if not isinstance(other, Event):
            return NotImplemented
        return self._literals == other._literals

    def __ne__(self, other):
        if not isinstance(other, Event):
            return NotImplemented
        return self._literals != other._literals

    def __or__(self, other):
        """Return the event containing the literals of both operands."""
        if not isinstance(other, Event):
            return NotImplemented
        return Event(self._literals | other._literals)

    def __le__(self, other):
        """Subset test on the literal sets."""
        if not isinstance(other, Event):
            return NotImplemented
        return self._literals <= other._literals

    def __lt__(self, other):
        """Proper-subset test on the literal sets."""
        if not isinstance(other, Event):
            return NotImplemented
        return self._literals < other._literals

    def __ge__(self, other):
        """Superset test on the literal sets."""
        if not isinstance(other, Event):
            return NotImplemented
        return self._literals >= other._literals

    def __gt__(self, other):
        """Proper-superset test on the literal sets."""
        if not isinstance(other, Event):
            return NotImplemented
        return self._literals > other._literals

    def __repr__(self) -> str:
        # Sorted so the textual form is independent of set iteration order.
        return ''.join(str(x) for x in sorted(self._literals))


class Lattice:
    """Base of an events lattice built from weighted stable models.

    The lattice wraps a mapping from stable model to value; presumably the
    keys are Events as produced by ``Lattice.parse`` (verify against callers).
    The lower bound of an event is the set of stable models that are subsets
    of it. The upper bound is the set of stable models that are supersets of
    it.
    """

    @staticmethod
    def parse(d):
        """Return a new dict with every string key of *d* turned into an Event."""
        return {Event.from_str(k): v for k, v in d.items()}

    @staticmethod
    def close_literals(smodels):
        """Return every literal occurring in *smodels* plus its case-swapped negation.

        An empty collection of models yields an empty set.
        """
        lits = set()
        for sm in smodels:
            for x in sm:
                lits.add(x)
                lits.add(x.swapcase())
        return lits

    def __init__(self, smodels_dict):
        """Create base for Events Lattice.

        *smodels_dict* maps each stable model to its value. It is stored by
        reference, not copied.
        """
        self._smodels = smodels_dict
        self._literals = Lattice.close_literals(self._smodels.keys())
        # Per-instance memo tables for lower_bound / upper_bound, keyed by event.
        self._lower_cache = {}
        self._upper_cache = {}

    def literals(self):
        """Return the closed set of literals (each literal and its negation)."""
        return self._literals

    def stable_models(self):
        """Return each stable model as a plain set of literals."""
        return [set(sm) for sm in self._smodels]

    def _bound(self, event, memo, in_bound):
        """Return the stable models accepted by *in_bound*, memoized in *memo*."""
        try:
            found = memo[event]
        except KeyError:
            # Dict iteration order keeps the result deterministic. Models are
            # already unique as dict keys, so no set() is needed.
            found = memo[event] = tuple(sm for sm in self._smodels if in_bound(sm))
        # Hand out a fresh list so callers cannot corrupt the memo.
        return list(found)

    def lower_bound(self, event):
        """Return the stable models that are subsets of *event*.

        Models are listed in the insertion order of the models dict.
        """
        return self._bound(event, self._lower_cache, lambda sm: sm <= event)

    def upper_bound(self, event):
        """Return the stable models that are supersets of *event*.

        Models are listed in the insertion order of the models dict.
        """
        return self._bound(event, self._upper_cache, lambda sm: event <= sm)

    def event_class(self, event):
        """Return the EventsClass (upper and lower bounds) that *event* belongs to."""
        return EventsClass(
            self.upper_bound(event),
            self.lower_bound(event),
            self)

    def related(self, u, v):
        """Return True if events *u* and *v* fall in the same class.

        - Two consistent events are related when they share both bounds.
        - Any two inconsistent events are related.
        - A consistent and an inconsistent event are never related.
        """
        u_consistent = u.is_consistent()
        if u_consistent != v.is_consistent():
            return False
        if not u_consistent:
            return True
        return (self.lower_bound(u) == self.lower_bound(v)
                and self.upper_bound(u) == self.upper_bound(v))

    def factors(self, event):
        """Return ``[lower_bound(event), upper_bound(event)]``."""
        return [self.lower_bound(event), self.upper_bound(event)]

    def _combine(self, models, op):
        """Return the value of a single model, or *op* over several models' values."""
        if len(models) == 1:
            return self._smodels[models[0]]
        return op(self._smodels[sm] for sm in models)

    def propagated_value(self, event: Event,
            lower_op=sum_op,
            upper_op=prod_op):
        """Return the value of *event* propagated from the stable models.

        Resolution order:

        - An inconsistent event gets 0.
        - Otherwise, if the lower bound is non-empty, use it: a single model's
          value as-is, several models' values combined with *lower_op*.
        - Otherwise, use the upper bound the same way, with *upper_op*.
        - If both bounds are empty the result is 0.
        """
        if not event.is_consistent():
            return 0

        lb = self.lower_bound(event)
        if lb:
            return self._combine(lb, lower_op)

        ub = self.upper_bound(event)
        if ub:
            return self._combine(ub, upper_op)

        return 0

    def __repr__(self):
        smodels_repr = ',\n\t\t'.join(f"{Event.from_str(k)}: {v:>5}" for k,v in self._smodels.items())
        lits_repr = ','.join(sorted(self._literals))

        return "{\n" +\
            f"\t'stable_models': {{\n\t\t {smodels_repr} \n\t}}\n" +\
            f"\t'literals': {{ {lits_repr} }} \n" +\
            "}"



class EventsClass:
    """A class of events sharing the same upper and lower bounds in a lattice.

    Stores the two bounds (lists of stable models) and the lattice they were
    computed from.
    """

    def __init__(self, upper, lower, lattice: "Lattice"):
        self._upper = upper
        self._lower = lower
        self._lattice = lattice

    def __repr__(self):
        """Render as ``<upper|lower>`` with comma-separated models on each side."""
        upper_repr = ",".join(str(x) for x in self._upper)
        lower_repr = ",".join(str(x) for x in self._lower)
        return f"<{upper_repr}|{lower_repr}>"

    def __contains__(self, event: "Event"):
        """Return True if *event* has exactly this class's bounds in the lattice."""
        # Bounds are compared as sets, so the answer does not depend on the
        # order in which the lattice lists the models.
        return (set(self._lattice.lower_bound(event)) == set(self._lower)
                and set(self._lattice.upper_bound(event)) == set(self._upper))