Skip to content

tracksdata.functional

Functional utilities for graph operations.

Classes:

Functions:

TilingScheme dataclass

TilingScheme(
    tile_shape: tuple[S, ...],
    overlap_shape: tuple[S, ...],
    attrs: list[str] | None = None,
)

Tiling scheme for the graph. Graph will be sliced with 'tile_shape' + 2 * 'overlap_shape' per axis.

Parameters:

  • tile_shape

    (tuple[S, ...]) –

    The shape of the tile.

  • overlap_shape

    (tuple[S, ...]) –

    The overlap between tiles PER SIDE.

  • attrs

    (list[str] | None, default: None ) –

    The attributes to include in the tile. If None, all attributes will be included. By default "t", "z", "y", "x" are included. If some columns are not present, they will be ignored.

ancestral_connected_edges

ancestral_connected_edges(
    input_graph: BaseGraph,
    reference_graph: BaseGraph,
    match: bool = True,
) -> list[int]

Let an ancestral path be any sequence from (target, source)-edges in the reference_graph. This function returns the subset of edges in the input_graph that are part of an ancestral path in the reference graph.

IMPORTANT: This function updates the input_graph in place when matching with the reference_graph.

Parameters:

  • input_graph

    (BaseGraph) –

    The input graph.

  • reference_graph

    (BaseGraph) –

    The reference graph.

  • match

    (bool, default: True ) –

    Whether to match the input graph with the reference graph.

Source code in src/tracksdata/functional/_labeling.py
def ancestral_connected_edges(
    input_graph: BaseGraph,
    reference_graph: BaseGraph,
    match: bool = True,
) -> list[int]:
    """
    Select the edges of `input_graph` that lie on an ancestral path of `reference_graph`.

    An ancestral path is any sequence built from (target, source)-edges in the
    `reference_graph`. The edges of the `input_graph` are annotated with the
    reference tracklet ids of their matched endpoints and then compared
    against the ancestral edges of the reference tracklet graph.

    IMPORTANT: with `match=True` the `input_graph` is updated in place by
    `input_graph.match(reference_graph)`.

    Parameters
    ----------
    input_graph : BaseGraph
        The input graph.
    reference_graph : BaseGraph
        The reference graph. Tracklet ids are assigned to it when it does not
        expose a tracklet-id node attribute yet.
    match : bool, optional
        Whether to match the input graph with the reference graph. When False,
        the input graph must already carry the matched-node-id node attribute.

    Returns
    -------
    list[int]
        The result of `_input_graph_ancestral_edges` for the annotated edges
        (presumably edge ids of `input_graph` -- confirm against that helper).

    Raises
    ------
    ValueError
        If `match=False` and the input graph has no matched-node-id attribute.
    """
    node_key = DEFAULT_ATTR_KEYS.NODE_ID
    tracklet_key = DEFAULT_ATTR_KEYS.TRACKLET_ID
    matched_key = DEFAULT_ATTR_KEYS.MATCHED_NODE_ID

    # Reuse existing tracklet ids when available, otherwise compute them now.
    if tracklet_key in reference_graph.node_attr_keys:
        tracklet_graph = reference_graph.tracklet_graph()
    else:
        tracklet_graph = reference_graph.assign_tracklet_ids()

    if match:
        input_graph.match(reference_graph)
    else:
        if matched_key not in input_graph.node_attr_keys:
            raise ValueError(
                "`ancestral_connected_edges` requires the input graph to previously matched "
                f"and have a `{DEFAULT_ATTR_KEYS.MATCHED_NODE_ID}` column when `match=False`"
            )

    matched_nodes = input_graph.node_attrs(attr_keys=[node_key, matched_key])
    reference_nodes = reference_graph.node_attrs(attr_keys=[node_key, tracklet_key])

    # Keep only nodes whose matched id is non-negative, then look up the
    # tracklet id of the reference node each of them was matched to.
    matched_nodes = matched_nodes.filter(pl.col(matched_key) >= 0)
    nodes_with_tracklet = matched_nodes.join(
        reference_nodes,
        left_on=matched_key,
        right_on=node_key,
        how="left",
    )

    # No edge attributes are requested here (empty `attr_keys`).
    bare_edges = input_graph.edge_attrs(attr_keys=[])
    annotated_edges = join_node_attrs_to_edges(
        node_attrs=nodes_with_tracklet.select(node_key, tracklet_key),
        edge_attrs=bare_edges,
    )

    return _input_graph_ancestral_edges(
        edge_attrs=annotated_edges,
        ancestral_edges=_ancestral_edges(tracklet_graph),
    )

apply_tiled

apply_tiled(
    graph: BaseGraph,
    tiling_scheme: TilingScheme,
    func: MapFunc,
    *,
    agg_func: None,
) -> Iterator[T]
apply_tiled(
    graph: BaseGraph,
    tiling_scheme: TilingScheme,
    func: MapFunc,
    *,
    agg_func: ReduceFunc,
) -> R
apply_tiled(
    graph: BaseGraph,
    tiling_scheme: TilingScheme,
    func: MapFunc,
    *,
    agg_func: ReduceFunc | None = None,
) -> Iterator[T] | R

Apply a function to a graph tiled by the tiling scheme. Graph will be sliced with 'tile_shape' + 2 * 'overlap_shape' per axis.

Parameters:

  • graph

    (BaseGraph) –

    The graph to apply the function to.

  • tiling_scheme

    (TilingScheme) –

    The tiling scheme to use.

  • func

    (MapFunc) –

    The function to apply to each tile. It takes two arguments: (1) filtered_graph_with_overlap, the subgraph inside the tile including the overlap; and (2) filtered_graph, the subgraph inside the tile without the overlap. If all overlaps are 0, filtered_graph_with_overlap == filtered_graph with minimal overhead.

  • agg_func

    (ReduceFunc | None, default: None ) –

    The function to reduce the results of the function. If None, the results will be yielded.

Returns:

  • Iterator[T] | R

    The results of the function. If agg_func is provided, the results will be reduced. Otherwise, the results will be yielded.

Source code in src/tracksdata/functional/_apply.py
def apply_tiled(
    graph: BaseGraph,
    tiling_scheme: TilingScheme,
    func: MapFunc,
    *,
    agg_func: ReduceFunc | None = None,
) -> Iterator[T] | R:
    """
    Run `func` over every tile of `graph` defined by `tiling_scheme`.
    Graph will be sliced with 'tile_shape' + 2 * 'overlap_shape' per axis.

    Parameters
    ----------
    graph : BaseGraph
        The graph to apply the function to.
    tiling_scheme : TilingScheme
        The tiling scheme to use.
    func : MapFunc
        The function to apply to each tile.
        It takes two arguments:
        - filtered_graph_with_overlap: the subgraph inside the tile with the overlap
        - filtered_graph: the subgraph inside the tile without the overlap
        If all overlaps are 0, filtered_graph_with_overlap == filtered_graph with minimal overhead.

    agg_func : ReduceFunc | None, optional
        Reducer called once with the iterator of per-tile results.
        If None, the iterator itself is returned.

    Returns
    -------
    Iterator[T] | R
        `agg_func(results)` when `agg_func` is provided, otherwise the
        per-tile results as produced by `_yield_apply_tiled`.
    """
    # The yielding logic lives in a helper: a Python function containing
    # `yield` is always a generator function, so it could not also hand back
    # the reduced value through a plain `return`.
    tile_results = _yield_apply_tiled(
        graph=graph,
        tiling_scheme=tiling_scheme,
        func=func,
    )

    # Without a reducer the caller consumes the per-tile results directly.
    if agg_func is None:
        return tile_results

    return agg_func(tile_results)

join_node_attrs_to_edges

join_node_attrs_to_edges(
    node_attrs: DataFrame,
    edge_attrs: DataFrame,
    node_id_key: str = DEFAULT_ATTR_KEYS.NODE_ID,
    source_key: str = DEFAULT_ATTR_KEYS.EDGE_SOURCE,
    target_key: str = DEFAULT_ATTR_KEYS.EDGE_TARGET,
    source_prefix: str = "source_",
    target_prefix: str = "target_",
    how: JoinStrategy = "left",
) -> pl.DataFrame

Add node attributes to edge attributes by joining on the node ID.

Parameters:

  • node_attrs

    (DataFrame) –

    Node attributes.

  • edge_attrs

    (DataFrame) –

    Edge attributes.

  • node_id_key

    (str, default: NODE_ID ) –

    The key of the node ID column.

  • source_key

    (str, default: EDGE_SOURCE ) –

    The key of the source column.

  • target_key

    (str, default: EDGE_TARGET ) –

    The key of the target column.

  • source_prefix

    (str, default: 'source_' ) –

    The prefix of the source column.

  • target_prefix

    (str, default: 'target_' ) –

    The prefix of the target column.

  • how

    (JoinStrategy, default: 'left' ) –

    The join type, where the "left" dataframe is the edge attributes and the "right" dataframe is the node attributes.

Returns:

  • DataFrame

    Edge attributes with node attributes added.

Examples:

node_attrs = pl.DataFrame({"node_id": [1, 2, 3], "a": [1, 2, 3], "b": [4, 5, 6]})
edge_attrs = pl.DataFrame({"source": [1, 2], "target": [2, 3]})
join_node_attrs_to_edges(node_attrs, edge_attrs)
# shape: (2, 5)
# ┌──────────┬──────────┬──────────┬──────────┬────────────────┬────────────────┐
# │ source_a ┆ source_b ┆ target_a ┆ target_b ┆ source_node_id ┆ target_node_id │
# │ ---      ┆ ---      ┆ ---      ┆ ---      ┆ ---            ┆ ---            │
# │ i64      ┆ i64      ┆ i64      ┆ i64      ┆ i64            ┆ i64            │
# ╞══════════╪══════════╪══════════╪══════════╪════════════════╪════════════════╡
# │ 1        ┆ 4        ┆ 2        ┆ 5        ┆ 1              ┆ 2              │
# │ 2        ┆ 5        ┆ 3        ┆ 6        ┆ 2              ┆ 3              │
# └──────────┴──────────┴──────────┴──────────┴────────────────┴────────────────┘
Source code in src/tracksdata/functional/_edges.py
def join_node_attrs_to_edges(
    node_attrs: pl.DataFrame,
    edge_attrs: pl.DataFrame,
    node_id_key: str = DEFAULT_ATTR_KEYS.NODE_ID,
    source_key: str = DEFAULT_ATTR_KEYS.EDGE_SOURCE,
    target_key: str = DEFAULT_ATTR_KEYS.EDGE_TARGET,
    source_prefix: str = "source_",
    target_prefix: str = "target_",
    how: JoinStrategy = "left",
) -> pl.DataFrame:
    """
    Attach the attributes of both endpoint nodes to every edge.

    The node attributes are joined onto the edges twice: once through the
    source column (attribute columns renamed with `source_prefix`) and once
    through the target column (attribute columns renamed with `target_prefix`).

    Parameters
    ----------
    node_attrs : pl.DataFrame
        Node attributes; must contain the `node_id_key` column.
    edge_attrs : pl.DataFrame
        Edge attributes; must contain the `source_key` and `target_key` columns.
    node_id_key : str, optional
        The key of the node ID column.
    source_key : str, optional
        The key of the source column.
    target_key : str, optional
        The key of the target column.
    source_prefix : str, optional
        Prefix given to the node attribute columns joined through the source.
    target_prefix : str, optional
        Prefix given to the node attribute columns joined through the target.
    how : JoinStrategy, optional
        The join type, where the "left" dataframe is the edge attributes and
        the "right" dataframe is the node attributes.

    Returns
    -------
    pl.DataFrame
        Edge attributes with node attributes added.

    Raises
    ------
    ValueError
        If `node_id_key` is not one of the columns of `node_attrs`.

    Examples
    --------
    ```python
    node_attrs = pl.DataFrame({"node_id": [1, 2, 3], "a": [1, 2, 3], "b": [4, 5, 6]})
    edge_attrs = pl.DataFrame({"source": [1, 2], "target": [2, 3]})
    joined = join_node_attrs_to_edges(
        node_attrs,
        edge_attrs,
        node_id_key="node_id",
        source_key="source",
        target_key="target",
    )
    # `joined` carries the edge columns plus "source_a", "source_b",
    # "target_a" and "target_b"; e.g. the edge (1, 2) gets
    # source_a=1, source_b=4, target_a=2, target_b=5.
    ```
    """
    # Every node column except the id is an attribute that gets a prefix.
    attr_columns = node_attrs.columns
    attr_columns.remove(node_id_key)

    ordered_nodes = node_attrs.select(node_id_key, *attr_columns)

    def _with_prefix(prefix: str) -> pl.DataFrame:
        # The id column keeps its name so it can still serve as the join key.
        return ordered_nodes.rename({col: f"{prefix}{col}" for col in attr_columns})

    with_source = edge_attrs.join(
        _with_prefix(source_prefix),
        left_on=source_key,
        right_on=node_id_key,
        how=how,
    )
    with_both = with_source.join(
        _with_prefix(target_prefix),
        left_on=target_key,
        right_on=node_id_key,
        how=how,
    )

    return with_both

rx_digraph_to_napari_dict

rx_digraph_to_napari_dict(
    tracklet_graph: PyDiGraph,
) -> dict[int, list[int]]

Convert a tracklet graph to a napari-ready dictionary. The input graph has (parent -> child) edges (forward in time); it is converted to a dictionary that maps each child to the list of its parents (backward in time).

Parameters

tracklet_graph : rx.PyDiGraph The tracklet graph to convert.

Returns

dict[int, list[int]] A dictionary mapping each child to the list of its parents.

Source code in src/tracksdata/functional/_napari.py
def rx_digraph_to_napari_dict(
    tracklet_graph: rx.PyDiGraph,
) -> dict[int, list[int]]:
    """
    Convert a tracklet graph to a napari-ready dictionary.

    Every edge of the input graph is read as (parent, child), i.e. forward in
    time. The result points backward in time: each child is mapped to the
    list of its parents, in the order the edges are listed.

    Parameters
    ----------
    tracklet_graph : rx.PyDiGraph
        The tracklet graph to convert.

    Returns
    -------
    dict[int, list[int]]
        A dictionary mapping each child to the list of its parents. Keys and
        values are the node indices reported by `edge_list()`.
        NOTE(review): `to_napari_format` translates such indices through the
        node payloads (`graph[index]`); confirm whether callers of this
        function expect the same translation.
    """
    dict_graph: dict[int, list[int]] = {}
    # `edge_list()` yields the (source, target) endpoints of every edge;
    # `edges()` returns the edge payloads instead and cannot be unpacked as
    # a (parent, child) pair.
    for parent, child in tracklet_graph.edge_list():
        dict_graph.setdefault(child, []).append(parent)
    return dict_graph

to_napari_format

to_napari_format(
    graph: BaseGraph,
    shape: tuple[int, ...] | None,
    solution_key: str | None,
    output_tracklet_id_key: str,
    mask_key: None,
) -> tuple[pl.DataFrame, dict[int, int]]
to_napari_format(
    graph: BaseGraph,
    shape: tuple[int, ...] | None,
    solution_key: str | None,
    output_tracklet_id_key: str,
    mask_key: str,
) -> tuple[pl.DataFrame, dict[int, int], GraphArrayView]
to_napari_format(
    graph: BaseGraph,
    shape: tuple[int, ...] | None = None,
    solution_key: str | None = DEFAULT_ATTR_KEYS.SOLUTION,
    output_tracklet_id_key: str = DEFAULT_ATTR_KEYS.TRACKLET_ID,
    mask_key: str | None = None,
    chunk_shape: tuple[int] | None = None,
    buffer_cache_size: int | None = None,
) -> (
    tuple[pl.DataFrame, dict[int, int], GraphArrayView]
    | tuple[pl.DataFrame, dict[int, int]]
)

Convert the subgraph of solution nodes to a napari-ready format.

This includes: (1) a tracks layer with the solution tracks; (2) a graph with the parent-child relationships for the solution tracks; and (3) a labels layer with the solution nodes, if mask_key is provided.

IMPORTANT: This function will reset the track ids if they already exist.

Parameters:

  • graph

    (BaseGraph) –

    The graph to convert.

  • shape

    (tuple[int, ...] | None, default: None ) –

    The shape of the labels layer. If None, the shape is inferred from the graph metadata shape key.

  • solution_key

    (str, default: SOLUTION ) –

    The key of the solution attribute. If None, the graph is not filtered by the solution attribute.

  • output_tracklet_id_key

    (str, default: TRACKLET_ID ) –

    The key of the output track id attribute.

  • mask_key

    (str | None, default: None ) –

    The key of the mask attribute.

  • chunk_shape

    (tuple[int] | None, default: None ) –

    The chunk shape for the labels layer. If None, the default chunk size is used.

  • buffer_cache_size

    (int, default: None ) –

    The maximum number of buffers to keep in the cache for the labels layer. If None, the default buffer cache size is used.

Examples:

labels = ...
graph = ...
tracks_data, dict_graph, array_view = to_napari_format(graph, labels.shape, mask_key="mask")

Returns:

  • tuple[DataFrame, dict[int, int], GraphArrayView] | tuple[DataFrame, dict[int, int]]
    • tracks_data: The tracks data as a polars DataFrame.
    • dict_graph: A dictionary mapping each child track to its parent track.
    • array_view: The array view of the solution graph if mask_key is provided.
Source code in src/tracksdata/functional/_napari.py
def to_napari_format(
    graph: BaseGraph,
    shape: tuple[int, ...] | None = None,
    solution_key: str | None = DEFAULT_ATTR_KEYS.SOLUTION,
    output_tracklet_id_key: str = DEFAULT_ATTR_KEYS.TRACKLET_ID,
    mask_key: str | None = None,
    chunk_shape: tuple[int] | None = None,
    buffer_cache_size: int | None = None,
) -> (
    tuple[
        pl.DataFrame,
        dict[int, int],
        "GraphArrayView",
    ]
    | tuple[
        pl.DataFrame,
        dict[int, int],
    ]
):
    """
    Export the solution subgraph in the pieces napari needs.

    The output consists of:
    - the data of a tracks layer with the solution tracks
    - a graph with the parent-child relationships for the solution tracks
    - a labels layer with the solution nodes if `mask_key` is provided.

    IMPORTANT: This function will reset the track ids if they already exist.

    Parameters
    ----------
    graph : BaseGraph
        The graph to convert.
    shape : tuple[int, ...] | None, optional
        The shape of the labels layer. If None, the shape is inferred from the graph metadata `shape` key.
    solution_key : str, optional
        The key of the solution attribute. If None, the graph is not filtered by the solution attribute.
    output_tracklet_id_key : str, optional
        The key of the output track id attribute.
    mask_key : str | None, optional
        The key of the mask attribute. In this function it is only compared
        against None to decide whether the array view is built; its value is
        not forwarded to `GraphArrayView`.
    chunk_shape : tuple[int] | None, optional
        The chunk shape for the labels layer. If None, the default chunk size is used.
    buffer_cache_size : int, optional
        The maximum number of buffers to keep in the cache for the labels layer.
        If None, the default buffer cache size is used.

    Examples
    --------

    ```python
    labels = ...
    graph = ...
    tracks_data, dict_graph, array_view = to_napari_format(graph, labels.shape, mask_key="mask")
    ```

    Returns
    -------
    tuple[pl.DataFrame, dict[int, int], GraphArrayView] | tuple[pl.DataFrame, dict[int, int]]
        - tracks_data: The tracks data as a polars DataFrame, with the columns
          ordered as track id, time, then the spatial columns.
        - dict_graph: A dictionary mapping each child to its parent, using the
          node payloads of the graph returned by `assign_tracklet_ids`.
        - array_view: The array view of the solution graph if `mask_key` is provided.
    """
    if solution_key is None:
        solution_graph = graph
    else:
        # `== True` is deliberate: the comparison result is what `filter` receives.
        solution_filter = graph.filter(
            NodeAttr(solution_key) == True,
            EdgeAttr(solution_key) == True,
        )
        solution_graph = solution_filter.subgraph()

    shape = _validate_shape(shape, solution_graph, "to_napari_format")

    tracks_graph = solution_graph.assign_tracklet_ids(output_tracklet_id_key)

    # Edges are listed as (parent, child) indices; the dictionary is keyed by
    # the child's payload and stores the parent's payload.
    dict_graph = {}
    for parent_idx, child_idx in tracks_graph.edge_list():
        child_id = tracks_graph[child_idx]
        parent_id = tracks_graph[parent_idx]
        dict_graph[child_id] = parent_id

    # The first entry of `shape` is not spatial, so only the trailing
    # `len(shape) - 1` axis names are kept.
    spatial_cols = ["z", "y", "x"][1 - len(shape) :]

    column_order = (output_tracklet_id_key, DEFAULT_ATTR_KEYS.T, *spatial_cols)

    tracks_data = solution_graph.node_attrs(attr_keys=list(column_order))
    # Enforce the column order explicitly.
    tracks_data = tracks_data.select(list(column_order))

    if mask_key is None:
        return tracks_data, dict_graph

    # Imported only when the array view is actually requested.
    from tracksdata.array._graph_array import GraphArrayView

    array_view = GraphArrayView(
        solution_graph,
        shape=shape,
        attr_key=output_tracklet_id_key,
        chunk_shape=chunk_shape,
        buffer_cache_size=buffer_cache_size,
    )

    return tracks_data, dict_graph, array_view