# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import math
import numbers
from typing import TYPE_CHECKING, Optional, Sequence, Union
import numpy as np
from ._jax_compat import Tracer as _Tracer
from ._typing import Array
if TYPE_CHECKING:
import jax
from ._base_quantity import Quantity
__all__ = [
"SparseMatrix"
]
def _same_sparsity_pattern(a, b) -> bool:
"""Check whether two index arrays describe the same sparsity pattern.
Returns ``True`` if ``a`` and ``b`` are the same Python object, or if both
are concrete arrays with equal shape and values. Under JIT tracing, falls
back to object identity since traced values aren't comparable in Python
boolean context.
"""
if a is b:
return True
if isinstance(a, _Tracer) or isinstance(b, _Tracer):
return False
a_shape = getattr(a, "shape", None)
b_shape = getattr(b, "shape", None)
if a_shape is not None and b_shape is not None and a_shape != b_shape:
return False
return bool(np.array_equal(a, b))
[docs]
class SparseMatrix:
"""
Base class for sparse matrices in ``saiunit``.
This base class defines the interface that all sparse matrix implementations
in the ``saiunit`` package should follow. Concrete subclasses must implement
the abstract methods defined here.
Attributes
----------
data : Array
The non-zero values in the sparse matrix.
Notes
-----
This class provides ``NotImplementedError`` for most operations, requiring concrete
subclasses to implement them according to their specific sparse format.
Examples
--------
``SparseMatrix`` is not instantiated directly. Use a concrete subclass such as
:class:`~saiunit.sparse.CSR`, :class:`~saiunit.sparse.CSC`, or
:class:`~saiunit.sparse.COO`.
.. code-block:: python
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.sparse as susparse
>>> dense = jnp.array([[1., 0.], [0., 2.]])
>>> csr = susparse.CSR.fromdense(dense)
>>> isinstance(csr, susparse.SparseMatrix)
True
"""
data: Array
shape: tuple[int, ...]
nse: property
dtype: property
__hash__ = None # type: ignore[assignment]
def __init__(
self,
args: tuple[Array, ...],
*,
shape: Sequence[int]
):
self.shape = tuple(int(s) for s in shape)
def __len__(self):
return self.shape[0]
@property
def size(self) -> int:
return math.prod(self.shape)
@property
def ndim(self) -> int:
return len(self.shape)
def __repr__(self):
name = self.__class__.__name__
try:
nse = self.nse
dtype = self.dtype
shape = list(self.shape)
except Exception:
repr_ = f"{name}(<invalid>)"
else:
repr_ = f"{name}({dtype}{shape}, {nse=})"
return repr_
@property
def T(self):
return self.transpose()
def block_until_ready(self):
for arg in self.tree_flatten()[0]:
arg.block_until_ready()
return self
def tree_flatten(self):
raise NotImplementedError(f"{self.__class__}.tree_flatten")
@classmethod
def tree_unflatten(cls, aux_data, children):
raise NotImplementedError(f"{cls}.tree_unflatten")
def transpose(self, axes=None):
raise NotImplementedError(f"{self.__class__}.transpose")
def todense(self):
raise NotImplementedError(f"{self.__class__}.todense")
[docs]
def with_data(
self,
data: Union[Array, np.ndarray, numbers.Number, 'Quantity']
):
"""
Create a new sparse matrix with the same sparsity structure but different data.
Parameters
----------
data : Array, numpy.ndarray, numbers.Number, or Quantity
The new non-zero values. Must have the same shape, dtype, and unit
as the current ``self.data``.
Returns
-------
SparseMatrix
A new sparse matrix of the same type with the provided data.
Raises
------
NotImplementedError
If called on the abstract base class directly.
Examples
--------
.. code-block:: python
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.sparse as susparse
>>> dense = jnp.array([[1., 0.], [0., 2.]])
>>> csr = susparse.CSR.fromdense(dense)
>>> new_csr = csr.with_data(csr.data * 3)
>>> new_csr.todense()
Array([[3., 0.],
[0., 6.]], dtype=float32)
"""
raise NotImplementedError(f"{self.__class__}.assign_data")
[docs]
def sum(self, axis: Optional[Union[int, Sequence[int]]] = None):
"""
Sum of the elements of the sparse matrix.
Parameters
----------
axis : int, sequence of int, or None, optional
Axis or axes along which the sum is computed. The default (``None``)
computes the sum of the flattened array. Currently only ``None`` is
supported.
Returns
-------
Array or Quantity
The sum of all elements in the sparse matrix.
Raises
------
NotImplementedError
If ``axis`` is not ``None``.
"""
if axis is not None:
raise NotImplementedError("CSR.sum with axis is not implemented.")
return self.data.sum()
[docs]
def yw_to_w(
self,
y_dim_arr: Union[Array, np.ndarray, 'Quantity'],
w_dim_arr: Union[Array, np.ndarray, 'Quantity']
) -> Union[Array, 'Quantity']:
"""
The protocol method to convert the product of the sparse matrix and a vector to the sparse matrix data.
This protocol method is primarily used in `brainscale <https://github.com/chaobrain/brainscale>`_.
Args:
y_dim_arr: The first vector.
w_dim_arr: The second vector.
Returns:
The outer product of the two vectors.
"""
raise NotImplementedError(f"{self.__class__}.yw_to_y is not implemented.")
def __abs__(self):
raise NotImplementedError(f"{self.__class__}.__abs__ is not implemented.")
def __neg__(self):
raise NotImplementedError(f"{self.__class__}.__neg__ is not implemented.")
def __pos__(self):
raise NotImplementedError(f"{self.__class__}.__pos__ is not implemented.")
def __matmul__(self, other):
raise NotImplementedError(f"{self.__class__}.__matmul__ is not implemented.")
def __rmatmul__(self, other):
raise NotImplementedError(f"{self.__class__}.__rmatmul__ is not implemented.")
def __mul__(self, other):
raise NotImplementedError(f"{self.__class__}.__mul__ is not implemented.")
def __rmul__(self, other):
raise NotImplementedError(f"{self.__class__}.__rmul__ is not implemented.")
def __add__(self, other):
raise NotImplementedError(f"{self.__class__}.__add__ is not implemented.")
def __radd__(self, other):
raise NotImplementedError(f"{self.__class__}.__radd__ is not implemented.")
def __sub__(self, other):
raise NotImplementedError(f"{self.__class__}.__sub__ is not implemented.")
def __rsub__(self, other):
raise NotImplementedError(f"{self.__class__}.__rsub__ is not implemented.")
def __div__(self, other):
raise NotImplementedError(f"{self.__class__}.__div__ is not implemented.")
def __rdiv__(self, other):
raise NotImplementedError(f"{self.__class__}.__rdiv__ is not implemented.")
def __truediv__(self, other):
raise NotImplementedError(f"{self.__class__}.__truediv__ is not implemented.")
def __rtruediv__(self, other):
raise NotImplementedError(f"{self.__class__}.__rtruediv__ is not implemented.")
def __floordiv__(self, other):
raise NotImplementedError(f"{self.__class__}.__floordiv__ is not implemented.")
def __rfloordiv__(self, other):
raise NotImplementedError(f"{self.__class__}.__rfloordiv__ is not implemented.")
def __mod__(self, other):
raise NotImplementedError(f"{self.__class__}.__mod__ is not implemented.")
def __rmod__(self, other):
raise NotImplementedError(f"{self.__class__}.__rmod__ is not implemented.")
def __getitem__(self, item):
raise NotImplementedError(f"{self.__class__}.__getitem__ is not implemented.")