"""Huffman tree method for sampling from a discrete probability distribution.
"""
import bisect
from collections.abc import Callable
from operator import attrgetter
from typing import Any
import numpy as np
from numpy.typing import NDArray
from ..sampling import Sampling
from ..univariate.uniform import Uniform
class HuffmanTree(Sampling):
    """Sampler for a finite discrete distribution based on a Huffman tree.

    A Huffman tree is built once from ``probabilities``. Each draw then takes
    one uniform variate and walks from the root to a leaf with
    ``sample_with_u``. The number of internal nodes visited per draw is added
    to ``self.sampling_cost``. That attribute is not set here; it is
    presumably initialised by ``Sampling`` (NOTE(review): confirm in the base
    class).

    Args:
        probabilities: One probability per state. State ``i`` is the index of
            ``probabilities[i]``. Converted with ``np.array`` before the tree
            is built.
        states: Maps a sampled state index (an ``int``) to the value stored in
            the output. Its result is written into an integer array, so it must
            return something convertible to ``int``.
    """

    def __init__(
        self,
        probabilities: "NDArray[np.floating] | list[float]",
        states: Callable[[int], Any],
    ):
        super().__init__()
        # Index of the last state (number of states minus one).
        self.K = len(probabilities) - 1
        self.head = create_huffman_tree(np.array(probabilities))
        # NOTE(review): sample_with_u assumes draws lie in [0, sum(probabilities));
        # presumably Uniform() samples on [0, 1) -- confirm.
        self.uniform = Uniform()
        self.states = states

    def sample(self, size: int = 1) -> NDArray[np.int_]:
        """Draw ``size`` samples.

        Args:
            size: Number of samples to draw.

        Returns:
            Integer array of length ``size`` holding ``self.states(i)`` for
            each sampled state index ``i``.
        """
        us = self.uniform.sample(size=size)
        res_states = np.empty(shape=size, dtype=int)
        head = self.head  # local alias avoids an attribute lookup per draw
        for k in range(size):
            state, sampling_cost = sample_with_u(us[k], head)
            self.sampling_cost += sampling_cost
            res_states[k] = self.states(state)
        return res_states
def sample_with_u(u: float, head: "Node") -> tuple[int, int]:
    """Descend from ``head`` to a leaf, steering with the uniform draw ``u``.

    At every internal node the left child's value acts as a threshold.
    - If ``u`` is below it, the walk continues left.
    - Otherwise the threshold is subtracted from ``u`` and the walk continues
      right.

    For the result to follow the leaf values, ``u`` should presumably be
    uniform on [0, head.value). That is assumed, not checked.

    Returns:
        A pair ``(state, cost)``: the reached leaf's ``state`` and how many
        internal nodes were traversed (one comparison each).
    """
    node, steps = head, 0
    while not node.is_leaf:
        left_child = node.left_node
        threshold = left_child.value
        steps += 1
        if u < threshold:
            node = left_child
            continue
        u -= threshold
        node = node.right_node
    return node.state, steps
class Node:
    """Base tree node: a numeric ``value`` plus a flag telling whether it is a leaf."""

    __slots__ = ("value", "is_leaf")

    def __init__(self, value: float, is_leaf: bool):
        self.value, self.is_leaf = value, is_leaf
class InternalNode(Node):
    """Non-leaf node of the tree; it always holds both a left and a right child.

    ``value`` is taken as given. ``create_huffman_tree`` passes the sum of the
    two children's values.
    """

    __slots__ = ("left_node", "right_node")

    def __init__(self, value: float, left_node: Node, right_node: Node):
        super().__init__(value, False)
        self.left_node, self.right_node = left_node, right_node
class Leaf(Node):
    """Terminal node with no children.

    It records the integer ``state`` that is returned when sampling reaches
    it.
    """

    __slots__ = ("state",)

    def __init__(self, value: float, state: int):
        super().__init__(value, True)
        self.state = state
class Heap:
    """Sorted list of tree nodes from which the smallest-``value`` node is popped first.

    ``nodes`` is kept in decreasing order of ``value``. The smallest node
    therefore sits at the end of the list and ``pop`` is O(1).
    """

    def __init__(self, nodes: "list[Node]"):
        # Sort a copy so the caller's list is not reordered in place.
        self.nodes = sorted(nodes, key=attrgetter("value"), reverse=True)
        # Negated values, aligned index-for-index with ``self.nodes``. Because
        # ``nodes`` is decreasing in value, this list is increasing. ``bisect``
        # can then find insertion points that are valid indices into ``nodes``
        # directly, with no mirroring arithmetic.
        self._neg_values = [-node.value for node in self.nodes]

    def pop(self) -> "Node":
        """Remove and return the node with the smallest value."""
        # Both lists end with the smallest node, so popping the tails keeps
        # them in sync.
        self._neg_values.pop()
        return self.nodes.pop()

    def insert(self, node: "Node"):
        """Insert ``node``, keeping ``nodes`` in decreasing order of value.

        A node equal in value to existing ones is placed before them. Nodes
        already in the heap are therefore popped before the new one.
        """
        key = -node.value
        index = bisect.bisect_left(self._neg_values, key)
        self._neg_values.insert(index, key)
        self.nodes.insert(index, node)
def create_huffman_tree(probabilities) -> "Node":
    """Build a Huffman tree over ``probabilities`` and return its root.

    Each ``probabilities[i]`` becomes a ``Leaf`` with ``state=i``. The build
    then repeats ``len(probabilities) - 1`` times:
    - pop two nodes from a ``Heap`` (intended to be the two smallest);
    - merge them into an ``InternalNode`` whose value is the sum of theirs.

    One root then remains. With a single probability, the root is that leaf.

    Raises:
        ValueError: If ``probabilities`` is empty.
    """
    n_states = len(probabilities)
    if n_states == 0:
        # Without this check the failure would be an IndexError on heap.nodes[0].
        raise ValueError("probabilities must contain at least one entry")
    heap = Heap([Leaf(p, state) for state, p in enumerate(probabilities)])
    # Each merge removes two nodes and adds one, so n_states - 1 merges leave one.
    for _ in range(n_states - 1):
        first = heap.pop()
        second = heap.pop()
        heap.insert(InternalNode(first.value + second.value, first, second))
    return heap.nodes[0]