Skip to content

OmniChampagne

Solver ID: OMNI-CHAMPAGNE

Usage

from invert import Solver

# fwd = ...    (mne.Forward object)
# evoked = ... (mne.Evoked object)

solver = Solver("OMNI-CHAMPAGNE")
# OmniChampagne learns its sparse Bayesian hyper-parameters from the data,
# so the data object must be passed alongside the forward model.
solver.make_inverse_operator(fwd, evoked)
stc = solver.apply_inverse_operator(evoked)
stc.plot()

Overview

Adaptive sparse Bayesian solver that fits two models, a dipole-only Champagne-style model and a multi-order patch-dictionary model, and keeps the one with the lower penalized loss (SBL cost plus a per-active-atom complexity penalty).

References

  1. Lukas Hecker 2025, unpublished

API Reference

Bases: BaseSolver

Adaptive Champagne/Flex-like solver via model selection.

Source code in invert/solvers/bayesian/omni_champagne.py
class SolverOmniChampagne(BaseSolver):
    """Adaptive Champagne/Flex-like solver via model selection.

    Two sparse Bayesian learning (SBL) models are fitted to the same data:

    * a dipole-only model whose dictionary is the leadfield itself, and
    * a patch model whose dictionary is the leadfield projected onto a
      multi-order graph-diffusion basis (orders >= 1).

    The model with the lower penalized loss (SBL cost plus
    ``complexity_penalty`` times the number of surviving atoms) supplies the
    final linear inverse operator.
    """

    meta = SolverMeta(
        slug="omni-champagne",
        full_name="OmniChampagne",
        category="Bayesian",
        description=(
            "Adaptive sparse Bayesian solver that selects between a dipole-only "
            "Champagne-style model and a multi-order patch-dictionary model based "
            "on penalized evidence."
        ),
        references=[
            "Lukas Hecker 2025, unpublished",
        ],
    )

    def __init__(
        self,
        name: str = "OmniChampagne",
        # Patch basis settings (match simulator defaults)
        n_orders: int = 3,
        diffusion_parameter: float = 0.1,
        adjacency_type: str = "spatial",
        adjacency_distance: float = 3e-3,
        # SBL settings
        update_rule: str = "MacKay",
        max_iter: int = 2000,
        pruning_thresh: float = 1e-3,
        convergence_criterion: float = 1e-8,
        # Model selection (penalize extra active atoms)
        complexity_penalty: float = 0.2,
        **kwargs,
    ) -> None:
        """Store solver hyper-parameters, coerced to their declared types.

        Parameters
        ----------
        name : str
            Display name; also used to label the stored inverse operator.
        n_orders : int
            Number of diffusion orders in the patch basis (order 0 is the
            identity). With ``n_orders <= 1`` the patch model is identical to
            the dipole model.
        diffusion_parameter : float
            Step size ``a`` of the smoothing operator ``I - a * Laplacian``.
        adjacency_type : str
            ``"spatial"`` uses ``mne.spatial_src_adjacency``; any other value
            uses ``mne.spatial_dist_adjacency`` with ``adjacency_distance``.
        adjacency_distance : float
            Distance threshold for ``mne.spatial_dist_adjacency``
            (presumably metres, as MNE expects -- confirm).
        update_rule : str
            Case-insensitive gamma update: ``"MacKay"`` (default, also used for
            any unrecognized value), ``"EM"``, or ``"Convexity"`` / ``"MM"``.
        max_iter : int
            Maximum number of SBL iterations per model.
        pruning_thresh : float
            Atoms with ``gamma <= pruning_thresh * max(gamma)`` are pruned.
        convergence_criterion : float
            Relative loss-change threshold for early stopping.
        complexity_penalty : float
            Cost added per active atom during model selection.
        **kwargs
            Forwarded to ``BaseSolver.__init__``.
        """
        self.name = name
        self.n_orders = int(n_orders)
        self.diffusion_parameter = float(diffusion_parameter)
        self.adjacency_type = str(adjacency_type)
        self.adjacency_distance = float(adjacency_distance)

        self.update_rule = str(update_rule)
        self.max_iter = int(max_iter)
        self.pruning_thresh = float(pruning_thresh)
        self.convergence_criterion = float(convergence_criterion)

        self.complexity_penalty = float(complexity_penalty)

        super().__init__(**kwargs)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def make_inverse_operator(  # type: ignore[override]
        self,
        forward,
        mne_obj,
        *args,
        alpha: str | float = "auto",
        **kwargs,
    ):
        """Fit both candidate models to the data and store the winner.

        Parameters
        ----------
        forward : mne.Forward
            Forward solution, handled by ``BaseSolver.make_inverse_operator``.
        mne_obj
            Data object (e.g. ``mne.Evoked``) unpacked via
            ``unpack_data_obj``; its data drive the SBL fits.
        *args, alpha, **kwargs
            Forwarded to ``BaseSolver.make_inverse_operator``.

        Returns
        -------
        SolverOmniChampagne
            ``self``, with ``inverse_operators`` holding one operator.
        """
        super().make_inverse_operator(forward, *args, alpha=alpha, **kwargs)
        data = self.unpack_data_obj(mne_obj)

        inv_op = self._fit_and_build_inverse_operator(data)
        self.inverse_operators = [InverseOperator(inv_op, self.name)]
        return self

    # ------------------------------------------------------------------
    # Model selection + inverse construction
    # ------------------------------------------------------------------
    def _fit_and_build_inverse_operator(self, Y: np.ndarray) -> np.ndarray:
        """Fit the dipole and patch SBL models and return the selected inverse.

        Parameters
        ----------
        Y : np.ndarray
            Data matrix of shape (n_chans, n_times).

        Returns
        -------
        np.ndarray
            Inverse operator of shape (n_dipoles, n_chans).
        """
        n_chans, n_dipoles = self.leadfield.shape

        # Isotropic noise level = mean channel variance of the data.
        # NOTE(review): this variance includes signal power, so it presumably
        # over-estimates the noise; confirm this is the intended heuristic.
        C_y = self.data_covariance(Y, center=True, ddof=1)
        alpha_noise = float(np.trace(C_y) / n_chans)
        noise_cov = alpha_noise * np.eye(n_chans)

        sbl_kwargs = {
            "noise_cov": noise_cov,
            "max_iter": self.max_iter,
            "pruning_thresh": self.pruning_thresh,
            "conv_crit": self.convergence_criterion,
            "update_rule": self.update_rule,
        }

        # Model A: dipole-only dictionary (the leadfield itself).
        L_dip = self.leadfield
        fit_dip = self._fit_sbl(L_dip, Y, **sbl_kwargs)
        W_dip = self._inverse_from_fit(L_dip, fit_dip, noise_cov=noise_cov)

        # Model B: patch dictionary (multi-order diffusion basis).
        sources_full = self._build_simulator_basis(n_dipoles)
        # Drop the order-0 (identity) block: model A already covers it. With a
        # single order there is nothing to drop and model B equals model A.
        sources = (
            sources_full[n_dipoles:, :]
            if sources_full.shape[0] > n_dipoles
            else sources_full
        )
        L_patch = self.leadfield @ sources.T  # (n_chans, n_candidates)
        fit_patch = self._fit_sbl(L_patch, Y, **sbl_kwargs)
        W_patch_coeff = self._inverse_from_fit(L_patch, fit_patch, noise_cov=noise_cov)
        # Map coefficients back to dipole space (sources.T is
        # (n_dipoles, n_candidates)). Rows of W_patch_coeff for pruned atoms
        # are zero, so only the active columns contribute.
        active = fit_patch.active_set
        W_patch = sources.T[:, active] @ W_patch_coeff[active]

        # Penalized score: SBL loss + cost per surviving atom. Lower wins;
        # ties go to the dipole model.
        score_dip = fit_dip.loss + self.complexity_penalty * len(fit_dip.active_set)
        score_patch = fit_patch.loss + self.complexity_penalty * len(
            fit_patch.active_set
        )
        return W_patch if score_patch < score_dip else W_dip

    # ------------------------------------------------------------------
    # Basis construction (match SimulationGenerator logic)
    # ------------------------------------------------------------------
    def _build_simulator_basis(self, n_dipoles: int) -> np.ndarray:
        """Return stacked multi-order basis S with shape (n_candidates, n_dipoles).

        Block 0 is the identity. Block ``k`` is block ``k-1`` multiplied by the
        diffusion operator ``G = I - diffusion_parameter * Laplacian(adjacency)``,
        with each column then rescaled so that its maximum is 1 (column maxima
        are clamped at 1e-12 to avoid division by zero). With
        ``n_orders <= 1`` only the identity is returned.
        """
        if self.n_orders <= 1:
            return np.eye(n_dipoles)

        identity = np.eye(n_dipoles)
        if self.adjacency_type == "spatial":
            adjacency = mne.spatial_src_adjacency(self.forward["src"], verbose=0)
        else:
            adjacency = mne.spatial_dist_adjacency(
                self.forward["src"], self.adjacency_distance, verbose=None
            )
        adjacency = csr_matrix(adjacency)
        G = csr_matrix(identity - self.diffusion_parameter * laplacian(adjacency))

        # Keep every block sparse and stack once at the end, instead of
        # densifying the whole growing stack on each iteration.
        blocks = [csr_matrix(identity)]
        for _ in range(1, self.n_orders):
            new_block = blocks[-1] @ G
            col_max = np.maximum(new_block.max(axis=0).toarray().ravel(), 1e-12)
            # csr_matrix(...) normalizes the result type across SciPy versions
            # (``multiply`` with a dense operand may return COO or dense).
            blocks.append(csr_matrix(new_block.multiply(1.0 / col_max[np.newaxis])))

        return vstack(blocks).toarray()

    # ------------------------------------------------------------------
    # Sparse Bayesian learning core
    # ------------------------------------------------------------------
    def _fit_sbl(
        self,
        L_orig: np.ndarray,
        Y_scaled: np.ndarray,
        *,
        noise_cov: np.ndarray,
        max_iter: int,
        pruning_thresh: float,
        conv_crit: float,
        update_rule: str,
    ) -> _SBLFit:
        """Run SBL with hard pruning on the dictionary ``L_orig``.

        All atoms start with ``gamma = 1``. Each iteration updates the gammas
        with the chosen rule, prunes atoms at or below
        ``pruning_thresh * max(gamma)``, and evaluates the SBL cost
        ``tr(Sigma_y^-1 C_yy) + log|Sigma_y|`` with ``C_yy = Y Y^T / n_times``.
        Iteration stops early when all gammas vanish, when nothing survives
        pruning, or when the relative loss change falls below ``conv_crit``
        in magnitude.

        Parameters
        ----------
        L_orig : np.ndarray
            Dictionary of shape (n_chans, n_atoms).
        Y_scaled : np.ndarray
            Data of shape (n_chans, n_times).
        noise_cov : np.ndarray
            Noise covariance of shape (n_chans, n_chans).
        max_iter, pruning_thresh, conv_crit, update_rule
            See ``__init__``. Unrecognized ``update_rule`` values use MacKay.

        Returns
        -------
        _SBLFit
            Surviving column indices of ``L_orig``, their gammas, the inverse
            model covariance for that state, and its loss.
        """
        n_atoms = L_orig.shape[1]
        n_times = Y_scaled.shape[1]

        # Loop invariants.
        C_yy = (Y_scaled @ Y_scaled.T) / n_times
        rule = update_rule.lower()

        active_set = np.arange(n_atoms)
        L_act = L_orig
        gam_act = np.ones(n_atoms)

        loss_prev = None
        for _ in range(max_iter):
            # Posterior mean of the active atoms for the current gammas.
            _, Sigma_y_inv = self._marginal_covariance(L_act, gam_act, noise_cov)
            mu_x = (L_act.T @ Sigma_y_inv @ Y_scaled) * gam_act[:, None]

            # Gamma update.
            upper = np.mean(mu_x**2, axis=1)
            z_diag = np.sum(L_act * (Sigma_y_inv @ L_act), axis=0)
            if rule in ("convexity", "mm"):
                gam_new = np.sqrt(upper / (z_diag + 1e-20))
            elif rule == "em":
                # Posterior variance diag(Sigma_x) plus second moment of mu_x.
                gam_new = (gam_act - gam_act**2 * z_diag) + upper
            else:  # MacKay (default, also for unrecognized rules)
                gam_new = upper / (gam_act * z_diag + 1e-20)

            gam_new[~np.isfinite(gam_new)] = 0.0
            gam_new = np.maximum(gam_new, 0.0)
            if float(np.linalg.norm(gam_new)) == 0.0:
                break

            # Prune atoms that are negligible relative to the strongest one.
            keep = np.flatnonzero(gam_new > pruning_thresh * float(gam_new.max()))
            if keep.size == 0:
                break

            active_set = active_set[keep]
            gam_act = gam_new[keep]
            L_act = L_act[:, keep]

            # Loss of the post-pruning state (also the state we return).
            Sigma_y, Sigma_y_inv = self._marginal_covariance(L_act, gam_act, noise_cov)
            loss = self._sbl_loss(Sigma_y, Sigma_y_inv, C_yy)

            # Converged once the loss moves by less than conv_crit
            # (relative) in either direction; an exactly unchanged loss
            # counts as converged.
            converged = loss_prev is not None and abs(loss_prev - loss) < conv_crit * (
                abs(loss_prev) + 1e-20
            )
            loss_prev = loss
            if converged:
                break

        if loss_prev is None:
            # No loss was computed in the loop (max_iter == 0, or stopped in the
            # first iteration): evaluate the current state directly.
            Sigma_y, Sigma_y_inv = self._marginal_covariance(L_act, gam_act, noise_cov)
            loss_prev = self._sbl_loss(Sigma_y, Sigma_y_inv, C_yy)

        return _SBLFit(
            active_set=active_set.astype(int, copy=False),
            gammas=gam_act.astype(float, copy=False),
            sigma_y_inv=Sigma_y_inv,
            loss=float(loss_prev),
        )

    @staticmethod
    def _marginal_covariance(
        L: np.ndarray, gammas: np.ndarray, noise_cov: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(Sigma_y, Sigma_y^-1)`` for ``noise_cov + L diag(gammas) L^T``."""
        Sigma_y = noise_cov + (L * gammas) @ L.T
        Sigma_y = 0.5 * (Sigma_y + Sigma_y.T)  # enforce exact symmetry
        return Sigma_y, SolverOmniChampagne._robust_inv(Sigma_y)

    @staticmethod
    def _sbl_loss(
        Sigma_y: np.ndarray, Sigma_y_inv: np.ndarray, C_yy: np.ndarray
    ) -> float:
        """Return the SBL cost ``tr(Sigma_y^-1 C_yy) + log|Sigma_y|``.

        Eigenvalues are clipped at 1e-20 before the log so a numerically
        singular covariance still yields a finite log-determinant.
        """
        data_fit = float(np.sum(Sigma_y_inv * C_yy.T))  # == tr(Sigma_y_inv @ C_yy)
        eigvals = np.linalg.eigvalsh(Sigma_y)
        log_det = float(np.sum(np.log(np.maximum(eigvals, 1e-20))))
        return data_fit + log_det

    @staticmethod
    def _inverse_from_fit(
        L: np.ndarray, fit: _SBLFit, *, noise_cov: np.ndarray
    ) -> np.ndarray:
        """Build full inverse operator W (n_atoms, n_chans) for a given dictionary.

        ``W = Gamma L^T Sigma_y^-1`` where ``Gamma`` holds the fitted gammas on
        the active set and zeros elsewhere, so rows of pruned atoms are zero.
        """
        n_atoms = L.shape[1]

        gam_full = np.zeros(n_atoms, dtype=float)
        gam_full[fit.active_set] = fit.gammas

        _, Sigma_y_inv = SolverOmniChampagne._marginal_covariance(
            L, gam_full, noise_cov
        )
        # Row-scale instead of forming the dense (n_atoms, n_atoms) diag(Gamma).
        return gam_full[:, None] * (L.T @ Sigma_y_inv)  # (n_atoms, n_chans)

    @staticmethod
    def _robust_inv(M: np.ndarray) -> np.ndarray:
        """Invert ``M``, falling back to the pseudo-inverse if it is singular."""
        try:
            return np.linalg.inv(M)
        except np.linalg.LinAlgError:
            return np.linalg.pinv(M)

__init__

__init__(
    name: str = "OmniChampagne",
    n_orders: int = 3,
    diffusion_parameter: float = 0.1,
    adjacency_type: str = "spatial",
    adjacency_distance: float = 0.003,
    update_rule: str = "MacKay",
    max_iter: int = 2000,
    pruning_thresh: float = 0.001,
    convergence_criterion: float = 1e-08,
    complexity_penalty: float = 0.2,
    **kwargs,
) -> None
Source code in invert/solvers/bayesian/omni_champagne.py
def __init__(
    self,
    name: str = "OmniChampagne",
    # Patch basis settings (match simulator defaults)
    n_orders: int = 3,
    diffusion_parameter: float = 0.1,
    adjacency_type: str = "spatial",
    adjacency_distance: float = 3e-3,
    # SBL settings
    update_rule: str = "MacKay",
    max_iter: int = 2000,
    pruning_thresh: float = 1e-3,
    convergence_criterion: float = 1e-8,
    # Model selection (penalize extra active atoms)
    complexity_penalty: float = 0.2,
    **kwargs,
) -> None:
    """Record the solver's hyper-parameters, coerced to their declared types.

    Patch-basis settings (``n_orders``, ``diffusion_parameter``,
    ``adjacency_type``, ``adjacency_distance``), SBL settings
    (``update_rule``, ``max_iter``, ``pruning_thresh``,
    ``convergence_criterion``) and the model-selection
    ``complexity_penalty`` are stored as attributes. Remaining keyword
    arguments go to the base-class initializer.
    """
    self.name = name

    # Patch-basis settings.
    (
        self.n_orders,
        self.diffusion_parameter,
        self.adjacency_type,
        self.adjacency_distance,
    ) = (
        int(n_orders),
        float(diffusion_parameter),
        str(adjacency_type),
        float(adjacency_distance),
    )

    # Sparse Bayesian learning settings.
    (
        self.update_rule,
        self.max_iter,
        self.pruning_thresh,
        self.convergence_criterion,
    ) = (
        str(update_rule),
        int(max_iter),
        float(pruning_thresh),
        float(convergence_criterion),
    )

    # Model selection.
    self.complexity_penalty = float(complexity_penalty)

    super().__init__(**kwargs)

make_inverse_operator

make_inverse_operator(
    forward,
    mne_obj,
    *args,
    alpha: str | float = "auto",
    **kwargs,
)
Source code in invert/solvers/bayesian/omni_champagne.py
def make_inverse_operator(  # type: ignore[override]
    self,
    forward,
    mne_obj,
    *args,
    alpha: str | float = "auto",
    **kwargs,
):
    super().make_inverse_operator(forward, *args, alpha=alpha, **kwargs)
    data = self.unpack_data_obj(mne_obj)

    inv_op = self._fit_and_build_inverse_operator(data)
    self.inverse_operators = [InverseOperator(inv_op, self.name)]
    return self