Dask backend

Dask provides parallel, out-of-core arrays. saiunit accepts a dask.array.Array mantissa and keeps operations lazy: building a quantity, arithmetic, and most saiunit.math / saiunit.linalg operations do not trigger a compute. Use it for arrays that don’t fit in memory or for embarrassingly parallel array work on a cluster.
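saiunit's laziness depends on dask's deferred execution. The sketch below uses only plain dask and NumPy, with no saiunit calls, so it shows the underlying mechanism rather than saiunit's own code path. Arithmetic and reductions on a dask array build a task graph, and only `.compute()` evaluates it.

```python
import numpy as np
import dask.array as da

# A chunked dask array: ten blocks of 100_000 float64 values each.
x = da.from_array(np.arange(1_000_000.0), chunks=100_000)

# Arithmetic and reductions only extend the task graph; nothing runs yet.
expr = (x + x).sum()
print(type(expr).__name__)  # Array
print(x.numblocks)          # (10,)

# .compute() executes the graph and returns a concrete NumPy result.
total = expr.compute()
print(total)
```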

Installation

pip install "saiunit[dask]"

Graceful import

import saiunit as u

try:
    import dask.array as da
    HAVE_DASK = True
except ImportError:
    HAVE_DASK = False
    print('dask not installed; install with: pip install saiunit[dask]')

Quick start: lazy by default

if HAVE_DASK:
    import numpy as np
    big = da.from_array(np.arange(1_000_000.0), chunks=100_000)
    q = u.Quantity(big, unit=u.meter)
    print('backend =', q.backend)
    print('shape   =', q.shape)        # no compute
    print('lazy add:', (q + q).backend)

What requires compute

Operations that need concrete values, such as float(q), int(q), q.tolist(), np.asarray(q), hash(q), and operator.index(q), raise BackendError on a dask-backed quantity. Call q.mantissa.compute() first to obtain an eager array, wrap it in a new Quantity, and convert that instead.

if HAVE_DASK:
    import numpy as np
    from saiunit import BackendError

    single = u.Quantity(da.from_array(np.array([42.0]), chunks=1), unit=u.meter)
    try:
        float(single)
    except BackendError as exc:
        print('expected:', exc)

    # materialize first
    eager_mantissa = single.mantissa.compute()
    print('after compute:', u.Quantity(eager_mantissa, unit=u.meter) / u.meter)
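When several concrete values are needed, calling .compute() once per value re-runs the shared part of the graph each time. Since a dask-backed mantissa is an ordinary dask.array.Array, it should be possible to pass several mantissas to dask.compute in one call (an assumption about combining it with saiunit). The sketch below uses plain dask arrays standing in for mantissas and makes no saiunit calls.

```python
import numpy as np
import dask
import dask.array as da

x = da.from_array(np.arange(10.0), chunks=5)

# Separate .compute() calls each walk the graph on their own.
lo = x.min().compute()
hi = x.max().compute()

# dask.compute evaluates several lazy results in a single pass,
# sharing the work of loading x's chunks.
lo2, hi2 = dask.compute(x.min(), x.max())
print(lo2, hi2)  # 0.0 9.0
```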

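Conversion

An eager quantity can be moved onto the dask backend with to_dask(); the chunks argument sets how the mantissa is split into blocks.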

if HAVE_DASK:
    import numpy as np
    q_np = u.Quantity(np.arange(1_000_000.0), unit=u.meter)
    q_da = q_np.to_dask(chunks=100_000)
    print(q_da.backend, q_da.mantissa.chunks)
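This page does not spell out which values to_dask's chunks argument accepts. Assuming it follows dask's usual chunk conventions, the plain-dask sketch below (no saiunit calls) shows how a chunk size partitions a 1-D array and how compute() returns the data to NumPy.

```python
import numpy as np
import dask.array as da

arr = np.arange(1_000_000.0)

# chunks=100_000 splits the 1-D array into ten equal blocks.
lazy = da.from_array(arr, chunks=100_000)
print(len(lazy.chunks[0]))  # 10

# Uneven splits are allowed: the last block holds the remainder.
uneven = da.from_array(arr, chunks=300_000)
print(uneven.chunks)  # ((300000, 300000, 300000, 100000),)

# Going back to an eager NumPy array requires an explicit compute().
eager = lazy.compute()
print(type(eager).__name__, np.array_equal(eager, arr))  # ndarray True
```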

Mixed-backend arithmetic

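When a dask quantity is combined with a non-dask quantity, the current default backend acts as the tiebreaker that decides the backend of the result. If the result lands on dask, the non-dask operand is automatically lifted to a dask array. The example sets the default with u.using_backend('dask').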

if HAVE_DASK:
    import numpy as np
    q_da = u.Quantity(da.from_array(np.array([1.0, 2.0]), chunks=1), unit=u.meter)
    q_np = u.Quantity(np.array([3.0, 4.0]), unit=u.meter)
    with u.using_backend('dask'):
        result = q_da + q_np
        print(result.backend)            # 'dask'
        print(result.mantissa.compute())
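Plain dask already lifts raw NumPy operands in a similar way: a NumPy array combined with a dask array is pulled into the task graph. The sketch below uses only dask and NumPy, so it shows this array-level behavior, not saiunit's default-backend tiebreaker itself.

```python
import numpy as np
import dask.array as da

lazy = da.from_array(np.array([1.0, 2.0]), chunks=1)
eager = np.array([3.0, 4.0])

# Adding a NumPy array to a dask array wraps the NumPy operand into the
# graph, so the result is again a lazy dask array.
mixed = lazy + eager
print(type(mixed).__name__)  # Array
print(mixed.compute())       # [4. 6.]
```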

Limitations

  • saiunit.autograd, saiunit.lax, and saiunit.sparse are JAX-only and cannot be used with dask-backed quantities.

  • Operations that need a concrete value require an explicit q.mantissa.compute() first.