Axolotl neuron regeneration

This tutorial will walk you through the steps needed to train a neural network that learns a differentiation potential from spatial transcriptomics data across time, using STORIES. As a demo dataset, we will use a subset of the axolotl neuron regeneration data from the ARTISTA dataset by Wei et al.

Imports

STORIES relies on AnnData for input/output, and on the JAX ecosystem for GPU computations.

[1]:
import os

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import optax
import scanpy as sc

import stories

Check that JAX detects the GPU. If not, refer to the JAX installation page.

[2]:
import jax

jax.devices()
[2]:
[cuda(id=0)]

Load the data

You can download the processed demo dataset in the .h5ad format from Figshare, using the command below. The original data comes from the ARTISTA dataset.

[3]:
#!wget -O artista_traj_processed.h5ad "https://figshare.com/ndownloader/files/47845111?private_link=209a3b6408ea8849d5ec"

The important fields for this tutorial are the following:

  • adata.obs["time"]: a number indicating the time point. Time points do not have to be evenly spaced. Importantly, STORIES assumes one slice per timepoint.

  • adata.obs["growth"]: a score computed from estimated proliferation and apoptosis rates, which estimates the number of descendants of a cell. STORIES can weight cells proportionally to this score. This is optional; by default, all cells have the same weight.

  • adata.obsm["X_pca_harmony"]: a batch-corrected PCA embedding of gene expression data. This is the input of the model.

  • adata.obsm["X_isomap"]: a 2D embedding of the data, used for visualization. Here, we use Isomap, but you could use PHATE or UMAP instead.

  • adata.obsm["spatial"]: 2D spatial coordinates of the cells.
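Because STORIES assumes one slice per time point, it can be worth checking this before training. The sketch below is a hypothetical helper (not part of STORIES) that, given a time label and a slice identifier for every cell, reports each time point that spans more than one slice. Which column identifies slices is an assumption; in this dataset, adata.obs["Batch"] is a plausible candidate.

```python
import numpy as np


def check_one_slice_per_timepoint(time, slice_id):
    """Return {time point: slice ids} for time points with more than one slice.

    Hypothetical helper: an empty dict means every time point has exactly
    one slice, as STORIES assumes.
    """
    time = np.asarray(time)
    slice_id = np.asarray(slice_id)
    offending = {}
    for t in np.unique(time):
        slices = np.unique(slice_id[time == t])
        if len(slices) > 1:
            offending[t.item()] = slices.tolist()
    return offending


# Toy example: time point 2 wrongly mixes two slices.
time = np.array([0, 0, 1, 1, 2, 2])
slice_id = np.array(["s0", "s0", "s1", "s1", "s2", "s3"])
print(check_one_slice_per_timepoint(time, slice_id))  # {2: ['s2', 's3']}
```

Once the data is loaded, a call such as check_one_slice_per_timepoint(adata.obs["time"], adata.obs["Batch"]) should return an empty dict if the assumption holds.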

[4]:
# Load the data
adata = ad.read_h5ad("artista_traj_processed.h5ad")
adata
[4]:
AnnData object with n_obs × n_vars = 5904 × 10000
    obs: 'CellID', 'spatial_leiden_e30_s8', 'Batch', 'cell_id', 'seurat_clusters', 'inj_uninj', 'D_V', 'Annotation', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'n_genes', 'apoptosis', 'nsc', 'growth', 'time'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'means', 'variances', 'residual_variances', 'highly_variable_rank', 'highly_variable_nbatches', 'highly_variable_intersection', 'highly_variable'
    uns: 'Annotation_colors', 'hvg', 'neighbors', 'pca', 'umap'
    obsm: 'X_isomap', 'X_pca', 'X_pca_harmony', 'X_spatial', 'X_umap', 'spatial'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'connectivities', 'distances'

Preprocess the data

In order to make the gene expression and spatial terms more comparable, we suggest some basic rescaling of the data.

[5]:
# Keep the first 20 principal components, then divide by the embedding's maximum value.
adata.obsm["X_pca_harmony"] = adata.obsm["X_pca_harmony"][:, :20]
adata.obsm["X_pca_harmony"] /= adata.obsm["X_pca_harmony"].max()
[6]:
# Center and scale each timepoint in space.
# Importantly, we have only one slice per time point.
timepoints = np.sort(np.unique(adata.obs["time"]))
adata.obsm["spatial"] = adata.obsm["spatial"].astype(float)
for t in timepoints:
    idx = adata.obs["time"] == t

    mu = np.mean(adata.obsm["spatial"][idx, :], axis=0)
    adata.obsm["spatial"][idx, :] -= mu

    std = np.std(adata.obsm["spatial"][idx, :], axis=0)
    adata.obsm["spatial"][idx, :] /= std
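The loop above standardizes each slice independently, so that at every time point each spatial axis has zero mean and unit standard deviation. As a sketch for checking this behavior on toy coordinates (not the ARTISTA data), the same operation can be written in vectorized NumPy with per-time-point sums:

```python
import numpy as np


def standardize_per_timepoint(coords, time):
    """Center and scale each spatial axis separately within every time point."""
    coords = np.asarray(coords, dtype=float)
    # `inverse[i]` is the index (into `labels`) of cell i's time point.
    labels, inverse, counts = np.unique(
        np.asarray(time), return_inverse=True, return_counts=True
    )
    inverse = inverse.ravel()
    # Per-time-point means of each axis, then centering.
    sums = np.zeros((len(labels), coords.shape[1]))
    np.add.at(sums, inverse, coords)
    centered = coords - (sums / counts[:, None])[inverse]
    # Per-time-point population standard deviations (ddof=0, like np.std).
    sq_sums = np.zeros_like(sums)
    np.add.at(sq_sums, inverse, centered**2)
    stds = np.sqrt(sq_sums / counts[:, None])
    return centered / stds[inverse]


# Toy slices with very different spatial extents.
coords = np.array([[0.0, 0.0], [2.0, 4.0], [100.0, 100.0], [300.0, 500.0]])
time = np.array([0, 0, 1, 1])
print(standardize_per_timepoint(coords, time))  # every row is (-1, -1) or (1, 1)
```

Applied to the raw coordinates, this should give the same result as the loop, up to floating-point error; both slices end up on the same scale regardless of their original extent.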
[7]:
# One panel per time point, showing each slice in standardized coordinates.
n = len(timepoints)
fig, axes = plt.subplots(
    1, n, figsize=(5 * n, 5), sharey=True, sharex=True, squeeze=False,
    constrained_layout=True,
)
for i, t in enumerate(timepoints):
    idx = adata.obs["time"] == t
    sc.pl.embedding(
        adata[idx], "spatial", color="Annotation", s=30, ax=axes[0, i], show=False
    )
[Figure: spatial coordinates of the cells at each time point, colored by annotation]

Train the model

The most important parameter when defining the model is quadratic_weight, the relative weight of space compared to gene expression. Here, we set it to 1e-3. Larger values give more importance to space, and a value of 0 ignores space.

[8]:
# Initialize the model.
model = stories.SpaceTime(quadratic_weight=1e-3)
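To build intuition for quadratic_weight, here is a toy NumPy illustration. It is an assumption made for exposition, not STORIES' actual loss: a coupling P between two small slices is scored by a linear gene expression cost plus quadratic_weight times a Gromov-Wasserstein-style term that penalizes couplings distorting pairwise spatial distances. The function name toy_fused_cost and all data below are hypothetical, and the toy scales are not comparable to the 1e-3 used on the real data.

```python
import numpy as np


def toy_fused_cost(x_expr, y_expr, x_space, y_space, P, quadratic_weight):
    """Toy objective: expression cost + quadratic_weight * spatial distortion.

    Illustrative assumption only; this is not STORIES' actual loss.
    """
    # Linear term: squared expression distances weighted by the coupling P.
    C = ((x_expr[:, None, :] - y_expr[None, :, :]) ** 2).sum(axis=-1)
    linear = np.sum(C * P)
    # Pairwise spatial distances within each slice.
    Dx = np.linalg.norm(x_space[:, None, :] - x_space[None, :, :], axis=-1)
    Dy = np.linalg.norm(y_space[:, None, :] - y_space[None, :, :], axis=-1)
    # Gromov-Wasserstein-style term: sum (Dx[i,k] - Dy[j,l])^2 P[i,j] P[k,l].
    L = (Dx[:, None, :, None] - Dy[None, :, None, :]) ** 2
    quadratic = np.einsum("ijkl,ij,kl->", L, P, P)
    return linear + quadratic_weight * quadratic


# Three cells per slice, same spatial layout in both slices.
x_space = y_space = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0]])
x_expr = np.array([[0.0], [0.0], [1.0]])
y_expr = np.array([[0.0], [1.0], [0.0]])
identity = np.eye(3) / 3  # keeps every cell at its spatial position
swapped = np.eye(3)[[0, 2, 1]] / 3  # matches cells 1 and 2 crosswise

for w in [0.0, 0.1, 1.0]:
    costs = [
        toy_fused_cost(x_expr, y_expr, x_space, y_space, P, w)
        for P in (identity, swapped)
    ]
    print(w, costs)  # w=0 favors `swapped`; w=0.1 and w=1 favor `identity`
```

With a weight of 0, only expression matters and the spatially scrambled coupling wins; as the weight grows, preserving the spatial layout becomes more important than matching expression exactly.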

Now we can train the model. model.fit() accepts additional parameters, such as the number of iterations and the batch size. As the optimizer, we use AdamW with a cosine-decay learning-rate schedule. Training stops when the validation loss stops improving, and the weights with the best validation loss are kept.
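For reference, optax.cosine_decay_schedule(1e-2, 10_000) lowers the learning rate from 1e-2 to 0 along a half cosine over 10,000 steps, then holds it at the final value. The plain-Python sketch below mirrors optax's cosine decay formula as we understand it (default alpha=0.0 and exponent=1.0), only to make the schedule explicit; the helper name cosine_decay is hypothetical, and in practice you pass the optax schedule directly, as in the next cell.

```python
import math


def cosine_decay(step, init_value=1e-2, decay_steps=10_000, alpha=0.0):
    """Learning rate at `step` for a cosine decay from init_value to alpha * init_value."""
    step = min(step, decay_steps)  # the rate is held constant after decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return init_value * ((1.0 - alpha) * cosine + alpha)


for step in [0, 2_500, 5_000, 10_000, 12_000]:
    print(step, cosine_decay(step))  # 0.01 at step 0, about 0.005 at 5000, 0 from 10000 on
```

Because training may stop early (here, after about 5,750 iterations), the learning rate never reaches zero in this run.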

[9]:
scheduler = optax.cosine_decay_schedule(1e-2, 10_000)
model.fit(
    adata=adata,
    time_key="time",
    omics_key="X_pca_harmony",
    space_key="spatial",
    weight_key="growth",
    optimizer=optax.adamw(scheduler),
    checkpoint_manager=f"{os.getcwd()}/ckpt_axolotl",
)
 58%|█████▊    | 5751/10000 [48:10<35:35,  1.99it/s, iteration=5752, train_loss=0.2511269, val_loss=0.34566662]
Met early stopping criteria, breaking...
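The early stopping rule described above can be sketched as a patience counter on the validation loss that remembers the best step. This is a generic, hypothetical illustration: the patience value, the min_delta threshold, and the exact criterion STORIES uses are assumptions, not taken from its code.

```python
def early_stopping(val_losses, patience=3, min_delta=0.0):
    """Return (best_step, stop_step) for a sequence of validation losses.

    Training stops once `patience` consecutive steps fail to improve on the
    best loss by more than `min_delta`; the weights from `best_step` are kept.
    """
    best_loss = float("inf")
    best_step = None
    waited = 0
    for step, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_step, waited = loss, step, 0
        else:
            waited += 1
            if waited >= patience:
                return best_step, step
    return best_step, len(val_losses) - 1


losses = [1.0, 0.8, 0.6, 0.61, 0.62, 0.65, 0.5]
print(early_stopping(losses, patience=3))  # (2, 5): the later 0.5 is never reached
```

Keeping the parameters from best_step, rather than the last step, is what makes the final model correspond to the best validation loss.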

Display the potential

STORIES learned a potential function. Let us apply it to all cells, populating the AnnData object with a new field adata.obs["potential"].

[10]:
stories.tools.compute_potential(adata, model, "X_pca_harmony")
[11]:
palette = {
    "IMN": "#FFE368",
    "dpEX": "#FF6666",
    "mpEX": "#d33f6a",
    "nptxEX": "#ef9708",
    "rIPC1": "#2CCD39",
    "rIPC2": "#7E0AD1",
    "reaEGC": "#00F2CE",
    "wntEGC": "#1ce6ff",
}
[12]:
sc.pl.embedding(
    adata,
    basis="isomap",
    color=["Annotation", "potential"],
    vmax="p98",  # Colorbar up to the 98th percentile
)
[Figure: Isomap embedding colored by annotation (left) and by potential (right)]

Display the velocity

The opposite of the potential’s gradient points towards more differentiated states. Let us compute this velocity for all cells, populating the AnnData object with a new field adata.obsm["X_velo"].

[13]:
stories.tools.compute_velocity(adata, model, "X_pca_harmony")
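As described above, the velocity of a cell is the negative gradient of the potential with respect to its embedding. The JAX sketch below illustrates this with a hypothetical quadratic toy_potential standing in for the trained network: jax.grad differentiates the potential for one cell, and jax.vmap maps that over all cells.

```python
import jax
import jax.numpy as jnp


def toy_potential(x):
    """Hypothetical stand-in for the learned potential of one cell."""
    return 0.5 * jnp.sum(x**2)


# Velocity of one cell = -grad(potential); vmap applies it to every cell.
velocity_fn = jax.vmap(lambda x: -jax.grad(toy_potential)(x))
potential_fn = jax.vmap(toy_potential)

cells = jnp.array([[1.0, 2.0], [-3.0, 0.5]])  # 2 cells in a 2D embedding
velocity = velocity_fn(cells)  # equals -cells for this quadratic potential

# A small step along the velocity lowers the potential of every cell.
moved = cells + 0.1 * velocity
print(potential_fn(cells), potential_fn(moved))
```

Following the velocity decreases the potential, which is why it points towards more differentiated (lower-potential) states.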

STORIES provides a convenient wrapper function around CellRank to visualize this velocity. Trajectories from EGCs to mature neurons clearly emerge.

[14]:
stories.tools.plot_velocity(
    adata,
    "X_pca_harmony",
    basis="isomap",
    color="Annotation",
    palette=palette,
    s=50,
)
100%|██████████| 5904/5904 [00:02<00:00, 2396.57cell/s]
100%|██████████| 5904/5904 [00:01<00:00, 3474.23cell/s]
[Figure: velocity projected on the Isomap embedding, colored by annotation]

Transcription factors

Finally, STORIES provides a function that performs a Wilcoxon rank-sum test for TF enrichment among the best-scoring genes. It expects a TSV file from the TRRUST database, which provides separate files of TF-target interactions for human and mouse.

[20]:
#!wget https://www.grnpedia.org/trrust/data/trrust_rawdata.human.tsv
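To illustrate what such a test can look like, the sketch below parses TRRUST-formatted lines (tab-separated TF, target, mode of regulation, and PubMed IDs, without a header) into TF-to-target sets. It then runs a one-sided Wilcoxon rank-sum (Mann-Whitney U) test comparing the scores of each TF's targets with those of all other genes. The helper names, the per-gene scores, and min_targets are hypothetical, and this is not necessarily how stories.tools.tf_enrich ranks genes or configures its test.

```python
from collections import defaultdict

from scipy.stats import mannwhitneyu


def read_trrust(lines):
    """Map each TF to the set of its targets from TRRUST-style TSV lines."""
    targets = defaultdict(set)
    for line in lines:
        fields = line.rstrip("\n").split("\t")
        if len(fields) >= 2:
            targets[fields[0]].add(fields[1])
    return dict(targets)


def tf_rank_sum(gene_scores, tf_targets, min_targets=3):
    """One-sided rank-sum p-value per TF: are its targets scored higher?"""
    results = {}
    for tf, targets in tf_targets.items():
        in_set = [s for g, s in gene_scores.items() if g in targets]
        out_set = [s for g, s in gene_scores.items() if g not in targets]
        if len(in_set) < min_targets or not out_set:
            continue  # too few measured targets to test
        results[tf] = mannwhitneyu(in_set, out_set, alternative="greater").pvalue
    return results


# Hypothetical scores (genes g0..g19 scored 0..19) and toy TRRUST-style lines.
gene_scores = {f"g{i}": float(i) for i in range(20)}
lines = [f"TF_A\tg{i}\tActivation\t0\n" for i in range(15, 20)]
lines += [f"TF_B\tg{i}\tRepression\t0\n" for i in range(5)]
print(tf_rank_sum(gene_scores, read_trrust(lines)))  # TF_A small p, TF_B p = 1
```

With the downloaded file, one could build the TF-target sets with `with open("trrust_rawdata.human.tsv") as f: tf_targets = read_trrust(f)`, using gene scores keyed by human gene names such as those extracted in the next cell.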

Since TRRUST has no file for axolotl, the following code extracts human gene names from the axolotl gene annotations stored in adata.var_names.

[46]:
import re


def extract_gene_name(text):
    # Regular expression patterns ordered by priority
    patterns = [
        r"(\b\w+)\s*\|",  # gene name without hint (highest priority)
        r"(\b\w+)\[hs\]",  # [hs] hint
        r"(\b\w+)\[nr\]",  # [nr] hint
        r"(\b\w+)\[.*?\]",  # any other hint
        r"(AMEX[\da-zA-Z]+)",  # AMEX code
    ]

    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(1)

    return text  # default return if no pattern matches


adata.var["clean_name"] = [extract_gene_name(g) for g in adata.var_names]
[72]:
stories.tools.tf_enrich(
    adata, trrust_path="trrust_rawdata.human.tsv", gene_key="clean_name"
)
100%|██████████| 591/591 [00:01<00:00, 341.58it/s]
100%|██████████| 591/591 [00:01<00:00, 339.47it/s]
[Figure: TF enrichment results]