"""Definition of a grid object, needed to define spatial grids as well as time grids.
"""
from collections.abc import Iterable
from functools import singledispatchmethod
from math import prod
import numpy as np
from ..tools.parameter import strictly_positive
class Coordinate1D:
    """Coordinate for an axis (one-dimensional grid).

    Wraps the integer position on the axis. Arithmetic operators return new
    ``Coordinate1D`` objects; only the in-place ``*=`` modifies the instance.
    An instance hashes like its raw value and compares equal both to that raw
    value and to another ``Coordinate1D`` holding the same value.
    """

    __slots__ = ("value",)

    def __init__(self, coordinate: int):
        """
        :param coordinate: integer corresponding to the position on the axis
        """
        self.value = coordinate

    @staticmethod
    def _raw(other):
        """Return ``other.value`` if ``other`` is a ``Coordinate1D``, else ``other`` unchanged."""
        return other.value if isinstance(other, Coordinate1D) else other

    def __repr__(self):
        return "Coordinate1D(" + str(self.value) + ")"

    def __neg__(self):
        return Coordinate1D(-self.value)

    def __add__(self, other: "int | Coordinate1D"):
        # Unwrap so that adding two coordinates works (it used to raise TypeError).
        return Coordinate1D(self.value + self._raw(other))

    def __sub__(self, other: "int | Coordinate1D"):
        return Coordinate1D(self.value - self._raw(other))

    def __iter__(self):
        # Lets a 1D coordinate be unpacked/zipped like a one-element tuple.
        yield self.value

    def __eq__(self, other):
        return self.value == self._raw(other)

    def __mul__(self, other):
        return Coordinate1D(self.value * other)

    def __rmul__(self, other):
        # ``k * c`` must not modify ``c``: build a new coordinate instead of
        # delegating to the in-place ``__imul__`` (which changed ``c`` and its hash).
        return Coordinate1D(other * self.value)

    def __imul__(self, other):
        # In-place on purpose for ``c *= k``. This changes ``hash(c)``, so avoid it
        # on a coordinate stored in a set or used as a dict key.
        self.value *= other
        return self

    def __hash__(self):
        return hash(self.value)
class CoordinateND:
    """Coordinate for a n-dimensional grid.

    Stores one integer position per axis as a tuple in ``value``. Addition and
    subtraction with another iterable are element-wise and require both sides to
    have the same number of dimensions. An instance hashes like its tuple and
    compares equal to any iterable holding exactly the same positions.
    """

    __slots__ = ("value",)

    def __init__(self, coordinates: Iterable[int]):
        """
        :param coordinates: list of integers corresponding to the coordinates (positions) on each axis
        """
        self.value = tuple(coordinates)

    def _pairs(self, other: Iterable[int]):
        """Pair each position with the matching entry of ``other``.

        :raises ValueError: if ``other`` does not have one entry per dimension
        """
        other = tuple(other)
        # zip() alone would silently truncate to the shorter operand.
        if len(other) != len(self.value):
            raise ValueError(
                f"dimension mismatch: {len(self.value)}D coordinate combined with {len(other)} values"
            )
        return zip(self.value, other)

    def __repr__(self):
        return "Coordinate" + str(len(self.value)) + "D" + repr(self.value)

    def __neg__(self):
        return CoordinateND(-u for u in self.value)

    def __add__(self, other: Iterable[int]):
        return CoordinateND(u + v for u, v in self._pairs(other))

    def __sub__(self, other: Iterable[int]):
        return CoordinateND(u - v for u, v in self._pairs(other))

    def __iter__(self):
        yield from self.value

    def __eq__(self, other):
        # Compare whole sequences: the previous zip-based check reported
        # (1, 2) == (1, 2, 3), which also broke consistency with __hash__.
        if not isinstance(other, Iterable):
            return NotImplemented
        return self.value == tuple(other)

    def __mul__(self, other):
        return CoordinateND(v * other for v in self.value)

    def __rmul__(self, other):
        # ``k * c`` returns a new coordinate instead of mutating ``c`` in place.
        return CoordinateND(other * v for v in self.value)

    def __imul__(self, other):
        # In-place on purpose for ``c *= k``. This changes ``hash(c)``, so avoid it
        # on a coordinate stored in a set or used as a dict key.
        self.value = tuple(val * other for val in self.value)
        return self

    def __getitem__(self, item):
        return self.value[item]

    def __hash__(self):
        return hash(self.value)
class Coordinates:
    """Factory that builds the coordinate object matching its input.

    ``Coordinates(x)`` never returns a ``Coordinates`` instance. An iterable
    ``x`` produces a ``CoordinateND``; any other value produces a ``Coordinate1D``.
    """

    def __new__(cls, coordinates):
        if not isinstance(coordinates, Iterable):
            # Scalar position: single axis.
            return Coordinate1D(coordinates)
        return CoordinateND(coordinates)
class Grid:
    """A grid is a set of axes, each of them given as an array of points [a_0, a_1, ..., a_K];
    the integer i is the position (coordinate) of the i-th element a_i on that axis.

    :Example:

    For a 2d-grid specified by the axes [a_0,a_1,...,a_K] and [b_0,b_1,...,b_L], the point of coordinates (i,j)
    has value (a_i, b_j)
    """

    def __init__(self, axes: list[np.ndarray]):
        """
        :param axes: list of numpy arrays, one per axis
        :raises ValueError: if ``axes`` is not a list
        """
        if not isinstance(axes, list):
            raise ValueError("the axes input should be a list of np.arrays")
        self.axes = axes
        self.dimension = len(axes)

    def number_of_points(self):
        """Total number of points in the grid (product of the sizes of all axes)."""
        return prod(axis.size for axis in self.axes)

    @singledispatchmethod
    def __getitem__(self, coordinates) -> float:
        """Return the value of the point at ``coordinates``.

        Default case (e.g. a ``Coordinate1D``): the entry ``coordinates.value`` of the
        first axis. A ``CoordinateND`` returns one value per axis.
        """
        return self.axes[0][coordinates.value]

    @__getitem__.register
    def _(self, coordinates: CoordinateND) -> tuple[float]:
        # (a_i, b_j, ...) for coordinates (i, j, ...).
        return tuple(self.axes[k][c] for k, c in enumerate(coordinates))

    @singledispatchmethod
    def __setitem__(self, coordinates, value):
        """Assign ``value`` at ``coordinates``.

        A ``Coordinate1D`` writes one entry of the first axis. A ``CoordinateND`` writes
        one entry per axis, taken from the iterable ``value``. Any other key indexes the
        list of axes directly, so ``grid[k] = array`` replaces axis ``k``.
        """
        self.axes[coordinates] = value

    # The registered overloads must NOT be named ``__setitem__``:
    # ``register`` returns the plain function, which would overwrite the
    # dispatcher in the class namespace and break all dispatching.
    @__setitem__.register
    def _(self, coordinates: Coordinate1D, value):
        # Mirror of the default __getitem__ path.
        self.axes[0][coordinates.value] = value

    @__setitem__.register
    def _(self, coordinates: CoordinateND, value):
        for k, (coordinate, val) in enumerate(zip(coordinates, value)):
            self.axes[k][coordinate] = val