Discretization steps (stories.steps)
Explicit step
- class stories.steps.explicit.ExplicitStep
Bases: ProximalStep
This class implements the explicit proximal step associated with the Wasserstein distance, i.e. v = -∇J(x), where J is a potential.
- inference_step(x: Array, a: Array, potential_fun: Callable, tau: float) Array
Performs an explicit step on the input distribution and returns the updated distribution, given a potential function.
- Parameters:
x (Array) – The input distribution of size (batch_size, n_dims).
a (Array) – The input histogram of size (batch_size,).
potential_fun (Callable) – A potential function.
tau (float) – The time step, which should be greater than 0.
- Returns:
The updated distribution of size (batch_size, n_dims).
- Return type:
Array
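The update described above can be sketched in JAX as a forward-Euler move along v = -∇J(x). This is only an illustration under stated assumptions, not the library's implementation: the name explicit_step_sketch is hypothetical, the potential is assumed to map a single point of shape (n_dims,) to a scalar, and the histogram a is accepted only to mirror the documented signature.

```python
import jax
import jax.numpy as jnp


def explicit_step_sketch(x, a, potential_fun, tau):
    """Illustrative update x + tau * v with v = -grad J(x) (hypothetical helper)."""
    # Assumption: potential_fun maps one point (n_dims,) to a scalar,
    # so vmap applies its gradient to every row of x.
    # `a` is unused here; it is kept only to mirror the documented signature.
    velocity = -jax.vmap(jax.grad(potential_fun))(x)
    return x + tau * velocity


def quadratic_potential(point):
    # Toy potential J(x) = 0.5 * ||x||^2, whose gradient is x itself.
    return 0.5 * jnp.sum(point ** 2)


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
a = jnp.full((2,), 0.5)
x_next = explicit_step_sketch(x, a, quadratic_potential, tau=0.1)
# For this toy potential the update reduces to (1 - tau) * x.
```

With the quadratic toy potential, each point is simply shrunk toward the origin by a factor of (1 - tau), which makes the sketch easy to check by hand.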
- training_step(x: Array, a: Array, potential_network: Module, potential_params: Mapping[str, Mapping[str, Any]], tau: float) Array
Performs an explicit step on the input distribution and returns the updated distribution. Unlike the inference step, this function takes a potential network and its parameters as input instead of a potential function.
- Parameters:
x (Array) – The input distribution of size (batch_size, n_dims).
a (Array) – The input histogram of size (batch_size,).
potential_network (nn.Module) – A potential function parameterized by a neural network.
potential_params (optax.Params) – The parameters of the potential network.
tau (float) – The time step, which should be greater than 0.
- Returns:
The updated distribution of size (batch_size, n_dims).
- Return type:
Array
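One reason to pass the network and its parameters separately is that the step can then be differentiated with respect to the parameters. The sketch below illustrates this under stated assumptions: TinyPotential is a hypothetical stand-in that only mimics the apply(params, x) calling convention of a network module, training_step_sketch is not the library's code, and the histogram-weighted squared error is an arbitrary loss chosen for the example.

```python
import jax
import jax.numpy as jnp


class TinyPotential:
    """Hypothetical stand-in for a potential network exposing apply(params, x)."""

    def apply(self, params, x):
        # Per-cell potential J(x) = 0.5 * w * ||x||^2, shape (batch_size,).
        return 0.5 * params["params"]["w"] * jnp.sum(x ** 2, axis=-1)


def training_step_sketch(x, a, potential_network, potential_params, tau):
    # Close over the parameters so the potential depends on the points only.
    def total_potential(points):
        return jnp.sum(potential_network.apply(potential_params, points))

    # Summing over cells makes row i of the gradient equal to grad J(x_i).
    return x - tau * jax.grad(total_potential)(x)


def loss(potential_params, x, a, target, tau):
    pred = training_step_sketch(x, a, TinyPotential(), potential_params, tau)
    # Histogram-weighted squared error (an illustrative choice of loss).
    return jnp.sum(a * jnp.sum((pred - target) ** 2, axis=-1))


params = {"params": {"w": jnp.array(1.0)}}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
a = jnp.full((2,), 0.5)
target = 0.5 * x
# Gradient of the loss with respect to the potential parameters.
grads = jax.grad(loss)(params, x, a, target, 0.1)
```

The returned grads has the same nested structure as params, so it can be handed to a gradient-based optimizer.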
ICNN implicit step
- class stories.steps.icnn_implicit.ICNNImplicitStep(maxiter: int = 100, implicit_diff: bool = True, log_callback: Callable | None = None, tol: float = 1e-08)
Bases: ProximalStep
This class defines an implicit proximal step corresponding to the squared Wasserstein distance, assuming the transportation plan is the identity (each cell mapped to itself). This step is “implicit” in the sense that instead of computing a velocity field it predicts the next timepoint as an argmin and thus requires solving an optimization problem.
- Parameters:
maxiter (int, optional) – The maximum number of iterations for the optimization loop. Defaults to 100.
implicit_diff (bool, optional) – Whether to differentiate implicitly through the optimization loop. Defaults to True.
log_callback (Callable, optional) – A callback used to log the proximal loss. Defaults to None.
tol (float, optional) – The tolerance for the optimization loop. Defaults to 1e-8.
- inference_step(x: Array, a: Array, potential_fun: Callable, tau: float) Array
Performs an implicit step on the input distribution and returns the updated distribution, given a potential function. If a log_callback was provided, it is used to log the proximal cost.
- Parameters:
x (jax.Array) – The input distribution of size (batch_size, n_dims).
a (jax.Array) – The input histogram of size (batch_size,).
potential_fun (Callable) – A potential function.
tau (float) – The time step, which should be greater than 0.
- Returns:
The updated distribution of size (batch_size, n_dims).
- Return type:
jax.Array
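To make the "argmin" formulation concrete, the following sketch minimizes an assumed proximal cost with plain gradient descent. Everything here is illustrative: the exact objective (a histogram-weighted potential plus the squared displacement divided by 2 * tau), the names proximal_cost and implicit_step_sketch, the step_size argument, and the choice of gradient descent are assumptions, since the source does not state which objective or solver the class uses. Only maxiter and tol mirror the documented parameters.

```python
import jax
import jax.numpy as jnp


def proximal_cost(y, x, a, potential_fun, tau):
    # Assumed objective: weighted potential plus squared displacement / (2 * tau),
    # with each cell matched to itself (identity transport plan).
    potential = jax.vmap(potential_fun)(y)
    displacement = jnp.sum((y - x) ** 2, axis=-1)
    return jnp.sum(a * (potential + displacement / (2.0 * tau)))


def implicit_step_sketch(x, a, potential_fun, tau,
                         maxiter=100, tol=1e-8, step_size=0.5):
    """Hypothetical solver: gradient descent on the proximal cost, started at x."""
    grad_fn = jax.grad(proximal_cost)  # gradient with respect to y
    y = x
    # Plain Python loop for readability; it runs eagerly and is not jit-compatible.
    for _ in range(maxiter):
        g = grad_fn(y, x, a, potential_fun, tau)
        if jnp.linalg.norm(g) <= tol:
            break
        y = y - step_size * g
    return y


def quadratic_potential(point):
    # Toy potential J(y) = 0.5 * ||y||^2.
    return 0.5 * jnp.sum(point ** 2)


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
a = jnp.full((2,), 0.5)
y_next = implicit_step_sketch(x, a, quadratic_potential, tau=0.5)
# For this toy potential the minimizer has the closed form x / (1 + tau).
```

For the quadratic toy potential, setting the gradient of the assumed cost to zero gives y = x / (1 + tau), which provides a closed-form check on the iterative result.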
- training_step(x: Array, a: Array, potential_network: Module, potential_params: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], tau: float) Array
Performs an implicit step on the input distribution and returns the updated distribution. Unlike the inference step, this function takes a potential network and its parameters as input instead of a potential function. Logging is not available here because it would prevent implicit differentiation.
- Parameters:
x (jax.Array) – The input distribution of size (batch_size, n_dims).
a (jax.Array) – The input histogram of size (batch_size,).
potential_network (nn.Module) – A potential function parameterized by a neural network.
potential_params (optax.Params) – The parameters of the potential network.
tau (float) – The time step, which should be greater than 0.
- Returns:
The updated distribution of size (batch_size, n_dims).
- Return type:
jax.Array