Assertions

assert_axis_dimension(tensor, axis, expected)

Checks that tensor.shape[axis] == expected.

assert_axis_dimension_comparator(tensor, ...)

Asserts that pass_fn(tensor.shape[axis]) passes.

assert_axis_dimension_gt(tensor, axis, val)

Checks that tensor.shape[axis] > val.

assert_axis_dimension_gteq(tensor, axis, val)

Checks that tensor.shape[axis] >= val.

assert_axis_dimension_lt(tensor, axis, val)

Checks that tensor.shape[axis] < val.

assert_axis_dimension_lteq(tensor, axis, val)

Checks that tensor.shape[axis] <= val.

assert_devices_available(n, devtype[, ...])

Checks that n devices of a given type are available.

assert_equal(first, second)

Checks that the two objects are equal as determined by the == operator.

assert_equal_rank(inputs)

Checks that all arrays have the same rank.

assert_equal_size(inputs)

Checks that all arrays have the same size.

assert_equal_shape(inputs, *[, dims])

Checks that all arrays have the same shape.

assert_equal_shape_prefix(inputs, prefix_len)

Checks that the leading prefix_len dims of all inputs have same shape.

assert_equal_shape_suffix(inputs, suffix_len)

Checks that the final suffix_len dims of all inputs have same shape.

assert_exactly_one_is_none(first, second)

Checks that one and only one of the arguments is None.

assert_gpu_available([backend])

Checks that at least one GPU device is available.

assert_is_broadcastable(shape_a, shape_b)

Checks that an array of shape_a is broadcastable to one of shape_b.

assert_is_divisible(numerator, denominator)

Checks that numerator is divisible by denominator.

assert_max_traces([fn, n])

Checks that a function is traced at most n times (inclusively).

assert_not_both_none(first, second)

Checks that at least one of the arguments is not None.

assert_numerical_grads(f, f_args, order[, atol])

Checks that autodiff and numerical gradients of a function match.

assert_rank(inputs, expected_ranks)

Checks that the rank of all inputs matches specified expected_ranks.

assert_scalar(x)

Checks that x is a scalar, as defined in pytypes.py (int or float).

assert_scalar_in(x, min_, max_[, included])

Checks that the argument is a scalar within a segment (borders included by default).

assert_scalar_negative(x)

Checks that a scalar is negative.

assert_scalar_non_negative(x)

Checks that a scalar is non-negative.

assert_scalar_positive(x)

Checks that a scalar is positive.

assert_size(inputs, expected_sizes)

Checks that the size of all inputs matches specified expected_sizes.

assert_shape(inputs, expected_shapes)

Checks that the shape of all inputs matches specified expected_shapes.

assert_tpu_available([backend])

Checks that at least one TPU device is available.

assert_tree_all_finite(tree_like)

Checks that all leaves in a tree are finite.

assert_tree_has_only_ndarrays(tree)

Checks that all tree's leaves are n-dimensional arrays (tensors).

assert_tree_is_on_device(tree, *[, ...])

Checks that all leaves are ndarrays residing in device memory (in HBM).

assert_tree_is_on_host(tree, *[, ...])

Checks that all leaves are ndarrays residing in the host memory (on CPU).

assert_tree_is_sharded(tree, *, devices)

Checks that all leaves are ndarrays sharded across the specified devices.

assert_tree_no_nones(tree)

Checks that a tree does not contain None.

assert_tree_shape_prefix(tree, shape_prefix)

Checks that all tree leaves' shapes have the same prefix.

assert_tree_shape_suffix(tree, shape_suffix)

Checks that all tree leaves' shapes have the same suffix.

assert_trees_all_close(*trees[, rtol, atol])

Checks that all trees have leaves with approximately equal values.

assert_trees_all_close_ulp(*trees[, maxulp])

Checks that tree leaves differ by at most maxulp Units in the Last Place.

assert_trees_all_equal(*trees[, strict])

Checks that all trees have leaves with exactly equal values.

assert_trees_all_equal_comparator(...)

Checks that all trees are equal as per the custom comparator for leaves.

assert_trees_all_equal_dtypes(*trees)

Checks that trees' leaves have the same dtype.

assert_trees_all_equal_sizes(*trees)

Checks that trees have the same structure and leaves' sizes.

assert_trees_all_equal_shapes(*trees)

Checks that trees have the same structure and leaves' shapes.

assert_trees_all_equal_shapes_and_dtypes(*trees)

Checks that trees' leaves have the same shape and dtype.

assert_trees_all_equal_structs(*trees)

Checks that trees have the same structure.

assert_type(inputs, expected_types)

Checks that the type of all inputs matches specified expected_types.

chexify(fn[, async_check, errors])

Wraps a transformed function fn to enable Chex value assertions.

ChexifyChecks

A set of checks imported from checkify.

with_jittable_assertions(fn[, async_check])

An alias for chexify (see the docs).

block_until_chexify_assertions_complete()

Waits until all asynchronous checks complete.

Dimensions(**dim_sizes)

A lightweight utility that maps strings to shape tuples.

disable_asserts()

Disables all Chex assertions.

enable_asserts()

Enables Chex assertions.

clear_trace_counter()

Clears Chex traces' counter for assert_max_traces checks.

if_args_not_none(fn, *args, **kwargs)

Wrap chex assertion to only be evaluated if positional args not None.

JAX Assertions

chex.assert_max_traces(fn: Callable[[...], Any] | int | None = None, n: Callable[[...], Any] | int | None = None)[source]#

Checks that a function is traced at most n times (inclusively).

JAX re-traces jitted functions every time the structure of passed arguments changes. Often this behaviour is inadvertent and leads to a significant performance drop which is hard to debug. This wrapper checks that the function is re-traced at most n times during program execution.

Examples:

@jax.jit
@chex.assert_max_traces(n=1)
def fn_sum_jitted(x, y):
  return x + y

def fn_sub(x, y):
  return x - y

fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

More about tracing:

https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

Parameters:
  • fn – A pure Python function to wrap (i.e. it must not be a jitted function).

  • n – The maximum allowed number of retraces (non-negative).

Returns:

A decorated function that raises an exception when it is traced for the (n+1)-th time.

Raises:

ValueError – If fn has already been jitted.

chex.assert_devices_available(n: int, devtype: str, backend: str | None = None, not_less_than: bool = False) None[source]#

Checks that n devices of a given type are available.

Parameters:
  • n – A required number of devices of the given type.

  • devtype – A type of devices, one of {'cpu', 'gpu', 'tpu'}.

  • backend – A type of backend to use (uses JAX default if not provided).

  • not_less_than – Whether to check if the number of devices is not less than n, instead of precise comparison.

Raises:

AssertionError – If the number of available devices of the given type is not equal to n, or is less than n when not_less_than is True.

chex.assert_gpu_available(backend: str | None = None) None[source]#

Checks that at least one GPU device is available.

Parameters:

backend – A type of backend to use (uses JAX default if not provided).

Raises:

AssertionError – If no GPU device is available.

chex.assert_tpu_available(backend: str | None = None) None[source]#

Checks that at least one TPU device is available.

Parameters:

backend – A type of backend to use (uses JAX default if not provided).

Raises:

AssertionError – If no TPU device is available.

Value (Runtime) Assertions

chex.chexify(fn: Callable[..., Any], async_check: bool = True, errors: FrozenSet[checkify.ErrorCategory] = frozenset({<class 'jax._src.checkify.FailedCheckError'>})) Callable[..., Any][source]#

Wraps a transformed function fn to enable Chex value assertions.

Chex value/runtime assertions access concrete values of tensors (e.g. assert_tree_all_finite) which are not available during JAX tracing, see https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html and https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError.

This wrapper enables them in jitted/pmapped functions by performing a specifically designed JAX transformation https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation and calling functionalised checks https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html

Example:

@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
  chex.assert_tree_all_finite(x)
  return jnp.log(jnp.abs(x) + 1)

logp1_abs_safe(jnp.ones(2))  # OK
logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS
logp1_abs_safe.wait_checks()

Note 1: This wrapper allows identifying the first failed assertion in a jitted code by printing a pointer to the line where the failed assertion was invoked. For getting verbose messages (including concrete tensor values), an unjitted version of the code will need to be executed with the same input values. Chex does not currently provide tools to help with this.

Note 2: This wrapper fully supports asynchronous executions (see https://jax.readthedocs.io/en/latest/async_dispatch.html). To block program execution until asynchronous checks for a chexified function fn complete, call fn.wait_checks(). Similarly, chex.block_until_chexify_assertions_complete() will block program execution until all asynchronous checks complete.

Note 3: Chex automatically selects the backend for executing its assertions (i.e. CPU or device accelerator) depending on the program context.

Note 4: Value assertions can have impact on the performance of a function, see https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations

Note 5: static assertions, such as assert_shape or assert_trees_all_equal_dtypes, can be called from a jitted function without chexify wrapper (since they do not access concrete values, only shapes and/or dtypes which are available during JAX tracing).

More examples can be found at deepmind/chex

Parameters:
  • fn – A transformed function to wrap.

  • async_check – Whether to check errors in the async dispatch mode. See https://jax.readthedocs.io/en/latest/async_dispatch.html.

  • errors – A set of checkify.ErrorCategory values which defines the set of enabled checks. By default only explicit checks are enabled (user, i.e. checkify.user_checks). You can also, for example, enable NaN and division-by-zero errors by passing the float set (checkify.float_checks), or combine multiple sets through set operations (float | user).

Returns:

A chexified function, i.e. the one with enabled value assertions. The returned function has a wait_checks() method that blocks the caller until all pending async checks complete.

ChexifyChecks

A set of checks imported from checkify.

chex.with_jittable_assertions(fn: Callable[[...], Any], async_check: bool = True) Callable[[...], Any][source]#

An alias for chexify (see the docs).

chex.block_until_chexify_assertions_complete() None[source]#

Waits until all asynchronous checks complete.

See chexify for more detail.

Tree Assertions

chex.assert_tree_all_finite(tree_like: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that all leaves in a tree are finite.

Parameters:

tree_like – A pytree with array leaves.

Raises:

AssertionError – If any leaf in tree_like is non-finite.

chex.assert_tree_has_only_ndarrays(tree: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that all tree’s leaves are n-dimensional arrays (tensors).

Parameters:

tree – A tree to assert.

Raises:

AssertionError – If the tree contains an object which is not an ndarray.

chex.assert_tree_is_on_device(tree: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], *, platform: Sequence[str] | str = ('gpu', 'tpu'), device: Device | None = None) None[source]#

Checks that all leaves are ndarrays residing in device memory (in HBM).

Sharded DeviceArrays are disallowed.

Parameters:
  • tree – A tree to assert.

  • platform – A platform or a list of platforms where the leaves are expected to reside. Ignored if device is specified.

  • device – An optional device where the tree’s arrays are expected to reside. Any device (except CPU) is accepted if not specified.

Raises:

AssertionError – If the tree contains a leaf that is not an ndarray or does not reside on the specified device or platform.

chex.assert_tree_is_on_host(tree: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], *, allow_cpu_device: bool = True, allow_sharded_arrays: bool = False) None[source]#

Checks that all leaves are ndarrays residing in the host memory (on CPU).

This assertion only accepts trees consisting of ndarrays.

Parameters:
  • tree – A tree to assert.

  • allow_cpu_device – Whether to allow JAX arrays that reside on a CPU device.

  • allow_sharded_arrays – Whether to allow sharded JAX arrays. Sharded arrays are considered “on host” only if they are sharded across CPU devices and allow_cpu_device is True.

Raises:

AssertionError – If the tree contains a leaf that is not an ndarray or does not reside on host.

chex.assert_tree_is_sharded(tree: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], *, devices: Sequence[Device]) None[source]#

Checks that all leaves are ndarrays sharded across the specified devices.

Parameters:
  • tree – A tree to assert.

  • devices – A list of devices which the tree’s leaves are expected to be sharded across. This list is order-sensitive.

Raises:

AssertionError – If the tree contains a leaf that is not a device array sharded across the specified devices.

chex.assert_tree_no_nones(tree: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that a tree does not contain None.

Parameters:

tree – A tree to assert.

Raises:

AssertionError – If the tree contains at least one None.

chex.assert_tree_shape_prefix(tree: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], shape_prefix: Sequence[int]) None[source]#

Checks that all tree leaves’ shapes have the same prefix.

Parameters:
  • tree – A tree to check.

  • shape_prefix – An expected shape prefix.

Raises:

AssertionError – If some leaf’s shape doesn’t start with shape_prefix.

chex.assert_tree_shape_suffix(tree: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], shape_suffix: Sequence[int]) None[source]#

Checks that all tree leaves’ shapes have the same suffix.

Parameters:
  • tree – A tree to check.

  • shape_suffix – An expected shape suffix.

Raises:

AssertionError – If some leaf’s shape doesn’t end with shape_suffix.

chex.assert_trees_all_close(*trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], rtol: float = 1e-06, atol: float = 0.0) None[source]#

Checks that all trees have leaves with approximately equal values.

This checks that the difference between the actual and desired values does not exceed atol + rtol * abs(desired).

Parameters:
  • *trees – A sequence of (at least 2) trees with array leaves.

  • rtol – A relative tolerance.

  • atol – An absolute tolerance.

Raises:

AssertionError – If actual and desired values are not equal up to specified tolerance.
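The tolerance rule atol + rtol * abs(desired) can be worked through in plain NumPy. The helper within_tolerance below is hypothetical and only mirrors the formula; it is not the Chex implementation:

```python
import numpy as np


def within_tolerance(actual, desired, rtol=1e-6, atol=0.0):
  """Element-wise |actual - desired| <= atol + rtol * |desired|."""
  actual = np.asarray(actual, dtype=np.float64)
  desired = np.asarray(desired, dtype=np.float64)
  return np.abs(actual - desired) <= atol + rtol * np.abs(desired)


desired = np.array([1.0, 1000.0])
actual = desired + 5e-7

# The allowed error scales with |desired|, so both entries pass.
scaled_ok = within_tolerance(actual, desired)

# With desired == 0 the relative term vanishes, so atol must be set.
zero_default = within_tolerance([1e-9], [0.0])
zero_with_atol = within_tolerance([1e-9], [0.0], atol=1e-8)

print(scaled_ok, zero_default, zero_with_atol)
```

The second case is why a non-zero atol is usually needed when the expected values can be exactly zero.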

chex.assert_trees_all_close_ulp(*trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], maxulp: int = 1) None[source]#

Checks that tree leaves differ by at most maxulp Units in the Last Place.

This is the Chex version of np.testing.assert_array_max_ulp.

Assertions on floating point values are tricky because the precision varies depending on the value. For example, with float32, the precision at 1 is np.spacing(np.float32(1.0)) ≈ 1e-7, but the precision at 5,000,000 is only np.spacing(np.float32(5e6)) = 0.5. This makes it hard to predict ahead of time what tolerance to use when checking whether two numbers are equal: a difference of only a couple of bits can equate to an arbitrarily large absolute difference.

Assertions based on _relative_ differences are one solution to this problem, but have the disadvantage that it’s hard to choose the tolerance. If you want to verify that two calculations produce _exactly_ the same result modulo the inherent non-determinism of floating point operations, do you set the tolerance to…0.01? 0.001? It’s hard to be sure you’ve set it low enough that you won’t miss one of your computations being slightly wrong.

Assertions based on ‘units in the last place’ (ULP) instead solve this problem by letting you specify tolerances in terms of the precision actually available at the current scale of your values. The ULP at some value x is essentially the spacing between the floating point numbers actually representable in the vicinity of x, equivalent to the ‘precision’ discussed above. With a tolerance of, say, maxulp=5, you’re saying that two values are within 5 actually-representable-numbers of each other: a strong guarantee that two computations are as close as possible to identical, while still allowing reasonable wiggle room for small differences due to e.g. different operator orderings.

Note that this function is not currently supported within JIT contexts, and does not currently support bfloat16 dtypes.

Parameters:
  • *trees – A sequence of (at least 2) trees with array leaves.

  • maxulp – The maximum number of ULPs by which leaves may differ.

Raises:

AssertionError – If actual and desired values are not equal up to specified tolerance.
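The float32 figures quoted above can be reproduced with NumPy alone. This sketch uses np.testing.assert_array_max_ulp, the NumPy function that the Chex assertion is modelled on; the variable names are illustrative:

```python
import numpy as np

# Spacing between adjacent float32 values grows with magnitude.
ulp_at_1 = np.spacing(np.float32(1.0))    # 2**-23, roughly 1.19e-7
ulp_at_5e6 = np.spacing(np.float32(5e6))  # exactly 0.5

a = np.array([5e6], dtype=np.float32)
b = a + np.float32(0.5)  # The next representable float32 after 5e6.

# An absolute difference of 0.5, yet the values are only 1 ULP apart.
np.testing.assert_array_max_ulp(a, b, maxulp=1)

# 1.0 and 1.5 differ by the same 0.5 but are millions of ULPs apart.
c = np.array([1.0], dtype=np.float32)
d = np.array([1.5], dtype=np.float32)
try:
  np.testing.assert_array_max_ulp(c, d, maxulp=1)
  far_apart_rejected = False
except AssertionError:
  far_apart_rejected = True

print(ulp_at_1, ulp_at_5e6, far_apart_rejected)
```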

chex.assert_trees_all_equal(*trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], strict: bool = False) None[source]#

Checks that all trees have leaves with exactly equal values.

If you are comparing floating point numbers, an exact equality check may not be appropriate; consider using assert_trees_all_close.

Parameters:
  • *trees – A sequence of (at least 2) trees with array leaves.

  • strict – If True, disable special scalar handling as described in np.testing.assert_array_equal notes section.

Raises:

AssertionError – If the leaf values actual and desired are not exactly equal.

chex.assert_trees_all_equal_comparator(equality_comparator: Callable[[Any, Any], bool], error_msg_fn: Callable[[Any, Any], str], *trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that all trees are equal as per the custom comparator for leaves.

Parameters:
  • equality_comparator – A custom function that accepts two leaves and checks whether they are equal. Expected to be transitive.

  • error_msg_fn – A function that accepts two leaves which are unequal as per equality_comparator and returns an error message.

  • *trees – A sequence of (at least 2) trees to check on equality as per equality_comparator.

Raises:
  • ValueError – If trees does not contain at least 2 elements.

  • AssertionError – If equality_comparator returns False for any pair of trees from trees.

chex.assert_trees_all_equal_dtypes(*trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that trees’ leaves have the same dtype.

Parameters:

*trees – A sequence of (at least 2) trees to check.

Raises:

AssertionError – If leaves’ dtypes for any two trees differ.

chex.assert_trees_all_equal_sizes(*trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that trees have the same structure and leaves’ sizes.

Parameters:

*trees – A sequence of (at least 2) trees with array leaves.

Raises:

AssertionError – If trees’ structures or leaves’ sizes are different.

chex.assert_trees_all_equal_shapes(*trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that trees have the same structure and leaves’ shapes.

Parameters:

*trees – A sequence of (at least 2) trees with array leaves.

Raises:

AssertionError – If trees’ structures or leaves’ shapes are different.

chex.assert_trees_all_equal_shapes_and_dtypes(*trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that trees’ leaves have the same shape and dtype.

Parameters:

*trees – A sequence of (at least 2) trees to check.

Raises:

AssertionError – If leaves’ shapes or dtypes for any two trees differ.

chex.assert_trees_all_equal_structs(*trees: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) None[source]#

Checks that trees have the same structure.

Parameters:

*trees – A sequence of (at least 2) trees to assert equal structure between.

Raises:
  • ValueError – If trees does not contain at least 2 elements.

  • AssertionError – If structures of any two trees are different.

Generic Assertions

chex.assert_axis_dimension(tensor: Array | ndarray | bool_ | number, axis: int, expected: int) None[source]#

Checks that tensor.shape[axis] == expected.

Parameters:
  • tensor – A JAX array.

  • axis – An integer specifying which axis to assert.

  • expected – An expected value of tensor.shape[axis].

Raises:

AssertionError – The dimension of the specified axis does not match the prescribed value.

chex.assert_axis_dimension_comparator(tensor: Array | ndarray | bool_ | number, axis: int, pass_fn: Callable[[int], bool], error_string: str)[source]#

Asserts that pass_fn(tensor.shape[axis]) passes.

Used to implement ==, >, >=, <, <= checks.

Parameters:
  • tensor – A JAX array.

  • axis – An integer specifying which axis to assert.

  • pass_fn – A callable which takes the size of the given dimension and returns false when the assertion should fail.

  • error_string – A string which is inserted in assertion failure messages: ‘expected tensor to have dimension {error_string} on axis …’.

Raises:

AssertionError – If pass_fn(tensor.shape[axis]) does not return true.

chex.assert_axis_dimension_gt(tensor: Array | ndarray | bool_ | number, axis: int, val: int) None[source]#

Checks that tensor.shape[axis] > val.

Parameters:
  • tensor – A JAX array.

  • axis – An integer specifying which axis to assert.

  • val – A value tensor.shape[axis] must be greater than.

Raises:

AssertionError – If the dimension of axis is <= val.

chex.assert_axis_dimension_gteq(tensor: Array | ndarray | bool_ | number, axis: int, val: int) None[source]#

Checks that tensor.shape[axis] >= val.

Parameters:
  • tensor – A JAX array.

  • axis – An integer specifying which axis to assert.

  • val – A value tensor.shape[axis] must be greater than or equal to.

Raises:

AssertionError – If the dimension of axis is < val.

chex.assert_axis_dimension_lt(tensor: Array | ndarray | bool_ | number, axis: int, val: int) None[source]#

Checks that tensor.shape[axis] < val.

Parameters:
  • tensor – A JAX array.

  • axis – An integer specifying which axis to assert.

  • val – A value tensor.shape[axis] must be less than.

Raises:

AssertionError – If the dimension of axis is >= val.

chex.assert_axis_dimension_lteq(tensor: Array | ndarray | bool_ | number, axis: int, val: int) None[source]#

Checks that tensor.shape[axis] <= val.

Parameters:
  • tensor – A JAX array.

  • axis – An integer specifying which axis to assert.

  • val – A value tensor.shape[axis] must be less than or equal to.

Raises:

AssertionError – If the dimension of axis is > val.

chex.assert_equal(first: Any, second: Any) None[source]#

Checks that the two objects are equal as determined by the == operator.

Arrays with more than one element cannot be compared. Use assert_trees_all_close to compare arrays.

Parameters:
  • first – A first object.

  • second – A second object.

Raises:

AssertionError – If not (first == second).

chex.assert_equal_rank(inputs: Sequence[Array | ndarray | bool_ | number]) None[source]#

Checks that all arrays have the same rank.

Parameters:

inputs – A collection of arrays.

Raises:
  • AssertionError – If the ranks of all arrays do not match.

  • ValueError – If inputs is not a collection of arrays.

chex.assert_equal_size(inputs: Sequence[Array | ndarray | bool_ | number]) None[source]#

Checks that all arrays have the same size.

Parameters:

inputs – A collection of arrays.

Raises:

AssertionError – If the sizes of all arrays do not match.

chex.assert_equal_shape(inputs: Sequence[Array | ndarray | bool_ | number], *, dims: int | Sequence[int] | None = None) None[source]#

Checks that all arrays have the same shape.

Parameters:
  • inputs – A collection of arrays.

  • dims – An optional integer or sequence of integers. If not provided, every dimension of every shape must match. If provided, equality of shape will only be asserted for the specified dim(s), e.g. to ensure all of a group of arrays have the same size in the first two dimensions, call assert_equal_shape(tensors_list, dims=(0, 1)).

Raises:
  • AssertionError – If the shapes of all arrays at specified dims do not match.

  • ValueError – If the provided dims are invalid indices into any of arrays; or if inputs is not a collection of arrays.

chex.assert_equal_shape_prefix(inputs: Sequence[Array | ndarray | bool_ | number], prefix_len: int) None[source]#

Checks that the leading prefix_len dims of all inputs have same shape.

Parameters:
  • inputs – A collection of input arrays.

  • prefix_len – A number of leading dimensions to compare; each input’s shape will be sliced to shape[:prefix_len]. Negative values are accepted and have the conventional Python indexing semantics.

Raises:
  • AssertionError – If the shapes of all arrays do not match.

  • ValueError – If inputs is not a collection of arrays.

chex.assert_equal_shape_suffix(inputs: Sequence[Array | ndarray | bool_ | number], suffix_len: int) None[source]#

Checks that the final suffix_len dims of all inputs have same shape.

Parameters:
  • inputs – A collection of input arrays.

  • suffix_len – A number of trailing dimensions to compare; each input’s shape will be sliced to shape[-suffix_len:]. Negative values are accepted and have the conventional Python indexing semantics.

Raises:
  • AssertionError – If the shapes of all arrays do not match.

  • ValueError – If inputs is not a collection of arrays.
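The slicing expressions quoted in the parameter descriptions, shape[:prefix_len] and shape[-suffix_len:], behave as follows on a plain tuple. The helpers shape_prefix and shape_suffix are hypothetical and only reproduce those two expressions:

```python
def shape_prefix(shape, prefix_len):
  return shape[:prefix_len]


def shape_suffix(shape, suffix_len):
  return shape[-suffix_len:]


shape = (8, 16, 64, 3)

leading_two = shape_prefix(shape, 2)     # (8, 16)
all_but_last = shape_prefix(shape, -1)   # (8, 16, 64)
trailing_two = shape_suffix(shape, 2)    # (64, 3)
all_but_first = shape_suffix(shape, -1)  # shape[1:] == (16, 64, 3)

print(leading_two, all_but_last, trailing_two, all_but_first)
```

A negative length therefore means "everything except the last (or first) few dimensions", following ordinary Python slice semantics.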

chex.assert_exactly_one_is_none(first: Any, second: Any) None[source]#

Checks that one and only one of the arguments is None.

Parameters:
  • first – A first object.

  • second – A second object.

Raises:

AssertionError – If (first is None) xor (second is None) is False.

chex.assert_is_broadcastable(shape_a: Sequence[int], shape_b: Sequence[int]) None[source]#

Checks that an array of shape_a is broadcastable to one of shape_b.

Parameters:
  • shape_a – A shape of the array to check.

  • shape_b – A target shape after broadcasting.

Raises:

AssertionError – If shape_a is not broadcastable to shape_b.
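Broadcasting "to" a target shape is one-directional. The following pure-Python sketch applies the standard NumPy broadcasting rule in that direction; the helper is_broadcastable_to is hypothetical and is not taken from the Chex source:

```python
def is_broadcastable_to(shape_a, shape_b):
  """Returns True if an array of shape_a can be broadcast to shape_b."""
  if len(shape_a) > len(shape_b):
    return False
  # Compare dimensions from the trailing end; each dimension of shape_a
  # must either be 1 or match the corresponding dimension of shape_b.
  for dim_a, dim_b in zip(reversed(shape_a), reversed(shape_b)):
    if dim_a != 1 and dim_a != dim_b:
      return False
  return True


print(is_broadcastable_to((3, 1), (2, 3, 4)))  # True
print(is_broadcastable_to((1,), (5,)))         # True
print(is_broadcastable_to((5,), (1,)))         # False: cannot shrink 5 to 1
print(is_broadcastable_to((2, 3), (3,)))       # False: rank is too high
```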

chex.assert_is_divisible(numerator: int, denominator: int) None[source]#

Checks that numerator is divisible by denominator.

Parameters:
  • numerator – A numerator.

  • denominator – A denominator.

Raises:

AssertionError – If numerator is not divisible by denominator.

chex.assert_not_both_none(first: Any, second: Any) None[source]#

Checks that at least one of the arguments is not None.

Parameters:
  • first – A first object.

  • second – A second object.

Raises:

AssertionError – If (first is None) and (second is None).

chex.assert_numerical_grads(f: Callable[[...], Array | ndarray | bool_ | number], f_args: Sequence[Array | ndarray | bool_ | number], order: int, atol: float = 0.01, **check_kwargs) None[source]#

Checks that autodiff and numerical gradients of a function match.

Parameters:
  • f – A function to check.

  • f_args – Arguments of the function.

  • order – An order of gradients.

  • atol – An absolute tolerance.

  • **check_kwargs – Kwargs for jax.test_util.check_grads.

Raises:

AssertionError – If automatic differentiation gradients deviate from finite difference gradients.

chex.assert_rank(inputs: float | int | Array | ndarray | bool_ | number | Sequence[Array | ndarray | bool_ | number], expected_ranks: int | Set[int] | Sequence[int | Set[int]]) None[source]#

Checks that the rank of all inputs matches specified expected_ranks.

Valid usages include:

assert_rank(x, 0)                      # x is scalar
assert_rank(x, 2)                      # x is a rank-2 array
assert_rank(x, {0, 2})                 # x is scalar or rank-2 array
assert_rank([x, y], 2)                 # x and y are rank-2 arrays
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar or rank-2 arrays

Parameters:
  • inputs – An array or a sequence of arrays.

  • expected_ranks – A sequence of expected ranks associated with each input, where the expected rank is either an integer or set of integer options; if all inputs have same rank, a single scalar or set of scalars may be passed as expected_ranks.

Raises:
  • AssertionError – If lengths of inputs and expected_ranks don’t match; if expected_ranks has wrong type; if the ranks of inputs do not match expected_ranks.

  • ValueError – If expected_ranks is not an integer and not a sequence of integers.

chex.assert_scalar(x: float | int) None[source]#

Checks that x is a scalar, as defined in pytypes.py (int or float).

Parameters:

x – An object to check.

Raises:

AssertionError – If x is not a scalar as per definition in pytypes.py.

chex.assert_scalar_in(x: Any, min_: float | int, max_: float | int, included: bool = True) None[source]#

Checks that the argument is a scalar within a segment (borders included by default).

Parameters:
  • x – An object to check.

  • min_ – A left border of the segment.

  • max_ – A right border of the segment.

  • included – Whether to include the borders of the segment in the set of allowed values.

Raises:

AssertionError – If x is not a scalar; if x falls out of the segment.

chex.assert_scalar_negative(x: float | int) None[source]#

Checks that a scalar is negative.

Parameters:

x – A value to check.

Raises:

AssertionError – If x is not a scalar or is not strictly negative.

chex.assert_scalar_non_negative(x: float | int) None[source]#

Checks that a scalar is non-negative.

Parameters:

x – A value to check.

Raises:

AssertionError – If x is not a scalar or is negative.

chex.assert_scalar_positive(x: float | int) None[source]#

Checks that a scalar is positive.

Parameters:

x – A value to check.

Raises:

AssertionError – If x is not a scalar or is not strictly positive.

chex.assert_size(inputs: float | int | Array | ndarray | bool_ | number | Sequence[Array | ndarray | bool_ | number], expected_sizes: Sequence[int | Set[int] | ellipsis | None] | Sequence[Sequence[int | Set[int] | ellipsis | None]]) None[source]#

Checks that the size of all inputs matches specified expected_sizes.

Valid usages include:

assert_size(x, 1)                   # x is scalar (size 1)
assert_size([x, y], (2, {1, 3}))    # x has size 2, y has size 1 or 3
assert_size([x, y], (2, ...))       # x has size 2, y has any size
assert_size([x, y], 1)              # x and y are scalar (size 1)
assert_size((x, y), (5, 2))         # x has size 5, y has size 2

Parameters:
  • inputs – An array or a sequence of arrays.

  • expected_sizes – A sequence of expected sizes associated with each input, where the expected size is a sequence of integer and None dimensions; if all inputs have the same size, a single size may be passed as expected_sizes.

Raises:

AssertionError – If the lengths of inputs and expected_sizes do not match; if expected_sizes has wrong type; if size of input does not match expected_sizes.

chex.assert_shape(inputs: float | int | Array | ndarray | bool_ | number | Sequence[Array | ndarray | bool_ | number], expected_shapes: Sequence[int | Set[int] | ellipsis | None] | Sequence[Sequence[int | Set[int] | ellipsis | None]]) None[source]#

Checks that the shape of all inputs matches specified expected_shapes.

Valid usages include:

assert_shape(x, ())                  # x is scalar
assert_shape(x, (2, 3))              # x has shape (2, 3)
assert_shape(x, (2, {1, 3}))         # x has shape (2, 1) or (2, 3)
assert_shape(x, (2, None))           # x has rank 2 and `x.shape[0] == 2`
assert_shape(x, (2, ...))            # x has rank >= 1 and `x.shape[0] == 2`
assert_shape([x, y], ())             # x and y are scalar
assert_shape([x, y], [(), (2,3)])    # x is scalar and y has shape (2, 3)
Parameters:
  • inputs – An array or a sequence of arrays.

  • expected_shapes – A sequence of expected shapes associated with each input, where each expected shape is a sequence of integer and None dimensions; if all inputs have the same shape, a single shape may be passed as expected_shapes.

Raises:

AssertionError – If the lengths of inputs and expected_shapes do not match; if expected_shapes has wrong type; if shape of input does not match expected_shapes.

chex.assert_type(inputs: float | int | Array | ndarray | bool_ | number | Sequence[Array | ndarray | bool_ | number], expected_types: str | type[Any] | dtype | SupportsDType | Sequence[str | type[Any] | dtype | SupportsDType]) None[source]#

Checks that the type of all inputs matches specified expected_types.

If the expected type is a Python type or abstract dtype (e.g. np.floating), assert that the input has the same sub-type. If the expected type is a concrete dtype (e.g. np.float32), assert that the input’s type is the same.

Example usage:

assert_type(7, int)
assert_type(7.1, float)
assert_type(False, bool)
assert_type([7, 8], int)
assert_type([7, 7.1], [int, float])
assert_type(np.array(7), int)
assert_type(np.array(7.1), float)
assert_type([jnp.array([7, 8]), np.array(7.1)], [int, float])
assert_type(jnp.array(1., dtype=jnp.bfloat16), jnp.bfloat16)
assert_type(jnp.ones(1, dtype=np.int8), np.int8)
Parameters:
  • inputs – An array or a sequence of arrays or scalars.

  • expected_types – A sequence of expected types associated with each input; if all inputs have same type, a single type may be passed as expected_types.

Raises:

AssertionError – If lengths of inputs and expected_types don’t match; if expected_types contains unsupported pytype; if the types of inputs do not match the expected types.

Shapes and Named Dimensions#

class chex.Dimensions(**dim_sizes)[source]#

A lightweight utility that maps strings to shape tuples.

The most basic usage is:

>>> dims = chex.Dimensions(B=3, T=5, N=7)  # You can specify any letters.
>>> dims['NBT']
(7, 3, 5)

This is useful when dealing with many differently shaped arrays. For instance, let’s check the shape of this array:

>>> x = jnp.array([[2, 0, 5, 6, 3],
...                [5, 4, 4, 3, 3],
...                [0, 0, 5, 2, 0]])
>>> chex.assert_shape(x, dims['BT'])

The dimension sizes can be accessed directly, e.g. dims.N == 7. This can be useful in many applications. For instance, let’s one-hot encode our array.

>>> y = jax.nn.one_hot(x, dims.N)
>>> chex.assert_shape(y, dims['BTN'])

You can also store the shape of a given array in dims, e.g.

>>> z = jnp.array([[0, 6, 0, 2],
...                [4, 2, 2, 4]])
>>> dims['XY'] = z.shape
>>> dims
Dimensions(B=3, N=7, T=5, X=2, Y=4)

You can access the flat size of a shape as

>>> dims.size('BT')  # Same as prod(dims['BT']).
15

Similarly, you can flatten axes together by wrapping them in parentheses:

>>> dims['(BT)N']
(15, 7)

You can set a wildcard dimension, cf. chex.assert_shape():

>>> dims.W = None
>>> dims['BTW']
(3, 5, None)

Or you can use the wildcard character ‘*’ directly:

>>> dims['BT*']
(3, 5, None)

Single digits are interpreted as literal integers. Note that this notation is limited to single-digit literals.

>>> dims['BT123']
(3, 5, 1, 2, 3)

Support for single digits was mainly included to accommodate dummy axes introduced for consistent broadcasting. For instance, instead of using jnp.expand_dims you could do the following:

>>> w = y * x  # Cannot broadcast (3, 5, 7) with (3, 5)
Traceback (most recent call last):
    ...
ValueError: Incompatible shapes for broadcasting: ((3, 5, 7), (1, 3, 5))
>>> w = y * x.reshape(dims['BT1'])
>>> chex.assert_shape(w, dims['BTN'])

Sometimes you only care about some array dimensions but not all. You can use an underscore to ignore an axis, e.g.

>>> chex.assert_rank(y, 3)
>>> dims['__M'] = y.shape  # Skip the first two axes.

Finally note that a single-character key returns a tuple of length one.

>>> dims['M']
(7,)

Utils#

chex.disable_asserts() None[source]#

Disables all Chex assertions.

Use wisely.

chex.enable_asserts() None[source]#

Enables Chex assertions.

chex.clear_trace_counter() None[source]#

Clears Chex traces’ counter for assert_max_traces checks.

Use it to isolate unit tests that rely on assert_max_traces, by calling it at the start of the test case.

chex.if_args_not_none(fn, *args, **kwargs)[source]#

Wraps a Chex assertion so that it is only evaluated if none of the positional args is None.

Warnings#

chex.create_deprecated_function_alias(fun, new_name, deprecated_alias)[source]#

Create a deprecated alias for a function.

Example usage:

>>> g = create_deprecated_function_alias(f, 'path.f', 'path.g')

Parameters:
  • fun – the deprecated function.

  • new_name – the new name to use (you may include the path for clarity).

  • deprecated_alias – the old name (you may include the path for clarity).

Returns:

the wrapped function.

chex.warn_deprecated_function(fun, replacement)[source]#

A decorator to mark a function definition as deprecated.

Example usage:

>>> @functools.partial(chex.warn_deprecated_function, replacement='g')
... def f(a, b):
...   return a + b

Parameters:
  • fun – the deprecated function.

  • replacement – the name of the function to be used instead.

Returns:

the wrapped function.

chex.warn_keyword_args_only_in_future(fun, *, n=0)#

Warns if more than n positional arguments are passed to fun.

For instance:

>>> @functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
... def f(a, b, c=1):
...   return a + b + c

Will emit a DeprecationWarning if f is called with more than one positional argument (e.g. both f(1, 2, 3) and f(1, 2, c=3) emit a warning).

Parameters:
  • fun – the function to wrap.

  • n – the number of positional arguments to allow.

Returns:

A wrapped function that emits a warning if more than n positional arguments are passed.

chex.warn_only_n_pos_args_in_future(fun, n)[source]#

Warns if more than n positional arguments are passed to fun.

For instance:

>>> @functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
... def f(a, b, c=1):
...   return a + b + c

Will emit a DeprecationWarning if f is called with more than one positional argument (e.g. both f(1, 2, 3) and f(1, 2, c=3) emit a warning).

Parameters:
  • fun – the function to wrap.

  • n – the number of positional arguments to allow.

Returns:

A wrapped function that emits a warning if more than n positional arguments are passed.

Backend restriction#

chex.restrict_backends(*, allowed: Sequence[str] | None = None, forbidden: Sequence[str] | None = None)[source]#

Disallows JAX compilation for certain backends.

Parameters:
  • allowed – Names of backend platforms (e.g. ‘cpu’ or ‘tpu’) for which compilation is still to be permitted.

  • forbidden – Names of backend platforms for which compilation is to be forbidden.

Yields:

None, in a context where compilation for forbidden platforms will raise a RestrictedBackendError.

Raises:

ValueError – if neither allowed nor forbidden is specified (i.e. they are both None), or if anything is both allowed and forbidden.

Dataclasses#

chex.dataclass(cls=None, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, mappable_dataclass=True)[source]#

JAX-friendly wrapper for dataclasses.dataclass().

This wrapper class registers new dataclasses with JAX so that tree utils operate correctly. Additionally a replace method is provided making it easy to operate on the class when made immutable (frozen=True).

Parameters:
  • cls – A class to decorate.

  • init – See dataclasses.dataclass().

  • repr – See dataclasses.dataclass().

  • eq – See dataclasses.dataclass().

  • order – See dataclasses.dataclass().

  • unsafe_hash – See dataclasses.dataclass().

  • frozen – See dataclasses.dataclass().

  • kw_only – See dataclasses.dataclass().

  • mappable_dataclass – If True (the default), methods to make the class implement the collections.abc.Mapping interface will be generated and the class will include collections.abc.Mapping in its base classes. True is the default, because being an instance of Mapping makes chex.dataclass compatible with e.g. jax.tree_util.tree_* methods, the tree library, or methods related to tensorflow/python/utils/nest.py. As a side-effect, e.g. np.testing.assert_array_equal will only check the field names are equal and not the content. Use chex.assert_tree_* instead.

Returns:

A JAX-friendly dataclass.

chex.mappable_dataclass(cls)[source]#

Exposes a dataclass as a collections.abc.Mapping descendant.

Allows dataclasses to be traversed by methods from the dm-tree library.

NOTE: this changes the dataclass constructor to be dict-like, i.e. positional args are not supported, although generators/iterables can be used.

Parameters:

cls – A dataclass to mutate.

Returns:

Mutated dataclass implementing collections.abc.Mapping interface.

chex.register_dataclass_type_with_jax_tree_util(data_class)[source]#

Register an existing dataclass so JAX knows how to handle it.

This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information.

Parameters:

data_class – A class created using dataclasses.dataclass. It must be constructable from keyword arguments corresponding to the members exposed in instance.__dict__.

Fakes#

fake_jit([enable_patching])

Context manager for patching jax.jit with the identity function.

fake_pmap([enable_patching, jit_result, ...])

Context manager for patching jax.pmap with jax.vmap.

fake_pmap_and_jit([enable_pmap_patching, ...])

Context manager for patching jax.jit and jax.pmap.

set_n_cpu_devices([n])

Forces XLA to use n CPU threads as host devices.

Transformations#

chex.fake_jit(enable_patching: bool = True) FakeContext[source]#

Context manager for patching jax.jit with the identity function.

This is intended to be used as a debugging tool to programmatically enable or disable JIT compilation.

Can be used either as a context managed scope:

with chex.fake_jit():
  @jax.jit
  def foo(x):
    ...

or by calling start and stop:

fake_jit_context = chex.fake_jit()
fake_jit_context.start()

@jax.jit
def foo(x):
  ...

fake_jit_context.stop()
Parameters:

enable_patching – Whether to patch jax.jit.

Returns:

Context where jax.jit is patched with the identity function, and JAX is configured to avoid jitting internally whenever possible in functions such as jax.lax.scan.

chex.fake_pmap(enable_patching: bool = True, jit_result: bool = False, ignore_axis_index_groups: bool = False, fake_parallel_axis: bool = False) FakeContext[source]#

Context manager for patching jax.pmap with jax.vmap.

This is intended to be used as a debugging tool to programmatically replace pmap transformations with a non-parallel vmap transformation.

Can be used either as a context managed scope:

with chex.fake_pmap():
  @jax.pmap
  def foo(x):
    ...

or by calling start and stop:

fake_pmap_context = chex.fake_pmap()
fake_pmap_context.start()
@jax.pmap
def foo(x):
  ...
fake_pmap_context.stop()
Parameters:
  • enable_patching – Whether to patch jax.pmap.

  • jit_result – Whether the transformed function should be jitted despite not being pmapped.

  • ignore_axis_index_groups – Whether to force any parallel operation within the context to set axis_index_groups to be None. This is a compatibility option to allow users of the axis_index_groups parameter to run under the fake_pmap context. This feature is not currently supported in vmap, and will fail, so we force the parameter to be None. Warning: This will produce different results from running under jax.pmap.

  • fake_parallel_axis – Whether to fake a parallel axis.

Returns:

Context where jax.pmap is patched with jax.vmap.

chex.fake_pmap_and_jit(enable_pmap_patching: bool = True, enable_jit_patching: bool = True) FakeContext[source]#

Context manager for patching jax.jit and jax.pmap.

This is a convenience function, equivalent to nested chex.fake_pmap and chex.fake_jit contexts.

Note that calling (the true implementation of) jax.pmap will compile the function, so faking jax.jit in this case will not stop the function from being compiled.

Parameters:
  • enable_pmap_patching – Whether to patch jax.pmap.

  • enable_jit_patching – Whether to patch jax.jit.

Returns:

Context where jax.pmap and jax.jit are patched with jax.vmap and the identity function, respectively.

Devices#

chex.set_n_cpu_devices(n: int | None = None) None[source]#

Forces XLA to use n CPU threads as host devices.

This allows jax.pmap to be tested on a single-CPU platform. This utility only takes effect before XLA backends are initialized, i.e. before any JAX operation is executed (including jax.devices() etc.). See google/jax#1408.

Parameters:

n – A required number of CPU devices (FLAGS.chex_n_cpu_devices is used by default).

Raises:

RuntimeError – If XLA backends were already initialized.

Pytypes#

Array

alias of Array | ndarray | bool_ | number

ArrayBatched

alias of Array

ArrayDevice

alias of Array

ArrayDeviceTree

alias of Array | Iterable[ArrayDeviceTree] | Mapping[Any, ArrayDeviceTree]

ArrayDType

alias of str | type[Any] | dtype | SupportsDType

ArrayNumpy

alias of ndarray

ArrayNumpyTree

alias of ndarray | Iterable[ArrayNumpyTree] | Mapping[Any, ArrayNumpyTree]

ArraySharded

alias of Array

ArrayTree

alias of Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]

Device

A descriptor of an available device.

Numeric

alias of Array | ndarray | bool_ | number | float | int

PRNGKey

alias of Array

PyTreeDef

Scalar

alias of float | int

Shape

alias of Sequence[int | Any]

Variants#

class chex.ChexVariantType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

An enumeration of available Chex variants.

Use self.variant.type to get the type of the current test variant. See the docstring of chex.variants for more information.

class chex.TestCase(*args, **kwargs)[source]#

A class for Chex tests that use variants.

See the docstring for chex.variants for more information.

Note: chex.variants returns a generator producing one test per variant. Therefore, the used test class must support dynamic unrolling of these generators during module import. It is implemented (and battle-tested) in absl.parameterized.TestCase, and here we subclass from it.

chex.variants(test_method='__no__default__', with_jit: bool = False, without_jit: bool = False, with_device: bool = False, without_device: bool = False, with_pmap: bool = False) VariantsTestCaseGenerator#

Decorates a test to expose Chex variants.

The decorated test has access to a decorator called self.variant, which may be applied to functions to test different JAX behaviors. Consider:

@chex.variants(with_jit=True, without_jit=True)
def test(self):
  @self.variant
  def f(x, y):
    return x + y

  self.assertEqual(f(1, 2), 3)

In this example, the function test will be called twice: once with f jitted (i.e. using jax.jit) and another where f is not jitted.

Variants with_jit=True and with_pmap=True accept additional arguments specific to them. Example:

@chex.variants(with_jit=True)
def test(self):
  @self.variant(static_argnums=(1,))
  def f(x, y):
    # `y` is not traced.
    return x + y

  self.assertEqual(f(1, 2), 3)

Variant with_pmap=True also accepts broadcast_args_to_devices (whether to broadcast each input argument to all participating devices), reduce_fn (a function to apply to results of pmapped fn), and n_devices (number of devices to use in the pmap computation). See the docstring of _with_pmap for more details (including default values).

If used with absl.testing.parameterized, @chex.variants must wrap it:

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters('test', *args)
def test(self, *args):
  ...

Test classes that use this wrapper must inherit from parameterized.TestCase. For more examples see variants_test.py.

Parameters:
  • test_method – A test method to decorate.

  • with_jit – Whether to test with jax.jit.

  • without_jit – Whether to test without jax.jit. Any jit compilation done within the test method will not be affected.

  • with_device – Whether to test with args placed on device, using jax.device_put.

  • without_device – Whether to test with args (explicitly) not placed on device, using jax.device_get.

  • with_pmap – Whether to test with jax.pmap, with computation duplicated across devices.

Returns:

A decorated test_method.

chex.all_variants(test_method='__no__default__', with_jit: bool = True, without_jit: bool = True, with_device: bool = True, without_device: bool = True, with_pmap: bool = True) VariantsTestCaseGenerator#

Equivalent to chex.variants but with flipped defaults.

chex.params_product(*params_lists: Sequence[Sequence[Any]], named: bool = False) Sequence[Sequence[Any]][source]#

Generates a cartesian product of params_lists.

See tests from variants_test.py for examples of usage.

Parameters:
  • *params_lists – A list of params combinations.

  • named – Whether to generate test names (for absl.parameterized.named_parameters(…)).

Returns:

A cartesian product of params_lists combinations.