Source code for bppy.model.sync_statement

import itertools
import random
from collections import Counter
from math import factorial, prod
from uuid import uuid4

from bppy.utils.weighted_sampling import *

# Field names of a synchronization statement. Each constant is bound to the
# string that `sync` uses as the dictionary key for that field.
request, waitFor, block = "request", "waitFor", "block"
mustFinish, priority, localReward = "mustFinish", "priority", "localReward"

class sync(dict):
    """
    A dictionary describing a single synchronization statement.

    Every keyword argument that is not ``None`` is stored under its own name
    (``"request"``, ``"waitFor"``, ``"block"``, ``"mustFinish"``,
    ``"priority"``, ``"localReward"``). Any additional keyword arguments are
    merged into the dictionary afterwards.
    """

    def __init__(self, *, request=None, waitFor=None, block=None, mustFinish=None,
                 priority=None, localReward=None, **kwargs):
        # TODO: warn if req/waitFor/block are not BEvent or choice
        named_fields = (
            ("request", request),
            ("waitFor", waitFor),
            ("block", block),
            ("mustFinish", mustFinish),
            ("priority", priority),
            ("localReward", localReward),
        )
        # Only fields that were actually supplied become keys; the order of
        # the tuple above fixes the insertion order.
        for field_name, field_value in named_fields:
            if field_value is None:
                continue
            self[field_name] = field_value
        # dict.__init__ merges into the existing content, so the keys set
        # above are kept and the extra keyword arguments are appended.
        super().__init__(kwargs)
[docs] class choice(dict): """ A class to represent a discrete choice object. Keys correspond to the possible choices, and values correspond to the probability of each choice. """ def __init__(self, data, repeat=1, replace=True, sorted=False, *args, **kwargs): if not isinstance(data, dict): raise TypeError("data must be a dict") if not replace and repeat > len(data): raise ValueError("repeat must be smaller than the number of choices") self.repeat = repeat self.replace = replace self.sorted = sorted super().__init__(data, **kwargs) self._id = uuid4() # returns iterable
[docs] def options(self): dist = [self.keys(), self.values()] if self.replace and not self.sorted: pv, pp = [itertools.product(l, repeat=self.repeat) for l in dist] combined_probs = [prod(event_probs) for event_probs in pp] elif self.replace and self.sorted: pv, pp = [itertools.combinations_with_replacement(l, r=self.repeat) for l in dist] elif not self.replace and self.sorted: pv = itertools.combinations(dist[0], r=self.repeat) pv, combined_probs = zip(*[sequence_probability_nr_s(self, perm) for perm in pv]) elif not self.replace and not self.sorted: pv, pp = [itertools.permutations(l, r=self.repeat) for l in dist] else: raise RuntimeError("Invalid combination of replace and sorted") if self.repeat == 1: return zip([v[0] for v in pv], combined_probs) return zip(pv, combined_probs)
[docs] def sample(self): if self.replace: res = random.choices(list(self.keys()), self.values(), k=self.repeat) if self.repeat == 1: return res[0] if self.sorted: sort(res) return res else: res = weighted_sample_without_replacement(list(self.keys()), self.values(), k=self.repeat) if self.repeat == 1: return res[0] if self.sorted: res.sort() return res
def __eq__(self, other): return isinstance(other, choice) and self._id == other._id