Flex-NL-Champagne

Solver ID: FLEX-NL-CHAMPAGNE

Usage

from invert import Solver

# fwd = ...    (mne.Forward object)
# evoked = ... (mne.Evoked object)

solver = Solver("FLEX-NL-CHAMPAGNE")
solver.make_inverse_operator(fwd, evoked)
stc = solver.apply_inverse_operator(evoked)
stc.plot()
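
For non-default dictionary settings, the solver class can also be instantiated directly. A minimal sketch, assuming the import path from the source listing below and that apply_inverse_operator is provided by BaseSolver as in the wrapper usage above; the parameter values are purely illustrative:

from invert.solvers.bayesian.flex_nl_champagne import SolverFlexNLChampagne

# Larger dictionary: more smoothing orders, stronger diffusion per step.
solver = SolverFlexNLChampagne(n_orders=6, diffusion_parameter=0.2)
solver.make_inverse_operator(fwd, evoked, max_iter=1000)
stc = solver.apply_inverse_operator(evoked)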

Overview

Two-pass flexible-extent Champagne variant using a multi-order diffusion dictionary, with Convexity (MM) updates in a first pass and a MacKay refinement pass on the surviving atoms.
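
The two passes use different hyperparameter updates. Writing the model data covariance as Sigma_y = noise_cov + L diag(gamma) L^T and the posterior source mean as mu = diag(gamma) L^T Sigma_y^-1 Y (exactly as computed in _run_sbl below), the update for atom k with leadfield column l_k over T time samples is

$$
\gamma_k \leftarrow \sqrt{\frac{\frac{1}{T}\sum_{t}\mu_{k,t}^{2}}{l_k^\top \Sigma_y^{-1} l_k}}
\quad\text{(Convexity/MM, pass 1)},
\qquad
\gamma_k \leftarrow \frac{\frac{1}{T}\sum_{t}\mu_{k,t}^{2}}{\gamma_k\, l_k^\top \Sigma_y^{-1} l_k}
\quad\text{(MacKay, pass 2)}.
$$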

References

  1. Lukas Hecker 2025, unpublished

API Reference

Bases: BaseSolver

Two-pass refined FlexChampagne with Convexity and MacKay updates.

Source code in invert/solvers/bayesian/flex_nl_champagne.py
class SolverFlexNLChampagne(BaseSolver):
    """Two-pass refined FlexChampagne with Convexity updates."""

    meta = SolverMeta(
        slug="flex-nl-champagne",
        full_name="Flex-NL-Champagne",
        category="Bayesian",
        description=(
            "Two-pass flexible-extent Champagne variant using a multi-order "
            "diffusion dictionary, with Convexity (MM) updates in a first pass "
            "and a MacKay refinement pass on the surviving atoms."
        ),
        references=[
            "Lukas Hecker 2025, unpublished",
        ],
    )

    def __init__(
        self,
        name="FlexNLChampagne",
        n_orders=4,
        diffusion_parameter=0.1,
        adjacency_type="spatial",
        adjacency_distance=3e-3,
        **kwargs,
    ):
        self.name = name
        self.n_orders = n_orders
        self.diffusion_parameter = diffusion_parameter
        self.adjacency_type = adjacency_type
        self.adjacency_distance = adjacency_distance
        self.is_prepared = False
        super().__init__(**kwargs)

    def make_inverse_operator(
        self,
        forward,
        mne_obj,
        *args,
        alpha="auto",
        max_iter=2000,
        pruning_thresh=1e-3,
        convergence_criterion=1e-8,
        **kwargs,
    ):
        super().make_inverse_operator(forward, *args, alpha=alpha, **kwargs)
        data = self.unpack_data_obj(mne_obj)

        if not self.is_prepared:
            self._prepare_flex()

        inverse_operator = self._two_pass_flex(
            data, pruning_thresh, max_iter, convergence_criterion
        )
        self.inverse_operators = [InverseOperator(inverse_operator, self.name)]
        return self

    def _prepare_flex(self):
        n_dipoles = self.leadfield.shape[1]
        I = np.identity(n_dipoles)

        self.leadfields = [deepcopy(self.leadfield)]
        self.gradients = [csr_matrix(I)]

        if self.n_orders == 0:
            self.is_prepared = True
            return

        if self.adjacency_type == "spatial":
            adjacency = mne.spatial_src_adjacency(self.forward["src"], verbose=0)
        else:
            adjacency = mne.spatial_dist_adjacency(
                self.forward["src"], self.adjacency_distance, verbose=None
            )

        LL = laplacian(adjacency)
        smoothing_operator = csr_matrix(I - self.diffusion_parameter * LL)

        # Powers of the one-step smoothing operator give atoms of increasing spatial extent.
        for i in range(self.n_orders):
            S_i = smoothing_operator ** (i + 1)
            new_lf = self.leadfields[0] @ S_i
            new_grad = self.gradients[0] @ S_i
            self.leadfields.append(new_lf)
            self.gradients.append(new_grad)

        # Normalize each gradient row by its absolute row sum (unit-amplitude atoms).
        for i in range(len(self.gradients)):
            row_sums = np.asarray(self.gradients[i].sum(axis=1)).ravel()
            scaling = 1.0 / np.maximum(np.abs(row_sums), 1e-12)
            self.gradients[i] = csr_matrix(
                self.gradients[i].multiply(scaling.reshape(-1, 1))
            )

        self.is_prepared = True

    def _run_sbl(
        self,
        L,
        Y_scaled,
        noise_cov,
        max_iter,
        pruning_thresh,
        conv_crit,
        update_rule="Convexity",
        init_gammas=None,
    ):
        """Core SBL loop. Returns (active_set, gammas)."""
        n_times = Y_scaled.shape[1]
        n_atoms = L.shape[1]

        gammas = init_gammas if init_gammas is not None else np.ones(n_atoms)
        active_set = np.arange(n_atoms)
        L_act = deepcopy(L)

        loss_list = []

        for i_iter in range(max_iter):
            old_gammas = deepcopy(gammas)

            Sigma_y = noise_cov + (L_act * gammas) @ L_act.T
            Sigma_y = 0.5 * (Sigma_y + Sigma_y.T)
            Sigma_y_inv = self._robust_inv(Sigma_y)
            mu_x = (L_act.T @ Sigma_y_inv @ Y_scaled) * gammas[:, None]

            upper = np.mean(mu_x**2, axis=1)
            L_Sigma = Sigma_y_inv @ L_act
            z_diag = np.sum(L_act * L_Sigma, axis=0)

            if update_rule == "Convexity":
                gammas = np.sqrt(upper / (z_diag + 1e-20))
            else:  # MacKay
                gammas = upper / (gammas * z_diag + 1e-20)

            gammas[~np.isfinite(gammas)] = 0.0
            gammas = np.maximum(gammas, 0.0)

            if np.linalg.norm(gammas) == 0:
                gammas = old_gammas
                break

            # Annealed pruning: start lenient, tighten
            anneal = min(1.0, (i_iter + 1) / 50.0)
            thresh = (pruning_thresh * anneal) * gammas.max()
            keep = np.where(gammas > thresh)[0]
            if len(keep) == 0:
                gammas = old_gammas
                break
            active_set = active_set[keep]
            gammas = gammas[keep]
            L_act = L_act[:, keep]

            Sigma_y = noise_cov + (L_act * gammas) @ L_act.T
            Sigma_y = 0.5 * (Sigma_y + Sigma_y.T)
            Sigma_y_inv = self._robust_inv(Sigma_y)

            data_fit = np.trace(Sigma_y_inv @ Y_scaled @ Y_scaled.T) / n_times
            eigvals = np.linalg.eigvalsh(Sigma_y)
            log_det = float(np.sum(np.log(np.maximum(eigvals, 1e-20))))
            loss = float(data_fit + log_det)
            loss_list.append(loss)

            if len(loss_list) > 1:
                rel_change = (loss_list[-2] - loss) / (abs(loss_list[-2]) + 1e-20)
                if 0 < rel_change < conv_crit:
                    break

        return active_set, gammas

    def _two_pass_flex(self, Y, pruning_thresh, max_iter, conv_crit):
        n_chans, n_dipoles = self.leadfield.shape
        n_times = Y.shape[1]
        n_orders = len(self.leadfields)

        # Build extended leadfield
        L_ext = np.hstack(self.leadfields)
        n_ext = L_ext.shape[1]

        # Noise estimate from data covariance
        C_y = self.data_covariance(Y, center=True, ddof=1)
        alpha_noise = float(np.trace(C_y) / n_chans)
        noise_cov = alpha_noise * np.identity(n_chans)

        Y_scaled = Y  # no additional scaling is applied here

        # === Pass 1: Identify active atoms with Convexity rule ===
        active_set_1, gammas_1 = self._run_sbl(
            L_ext,
            Y_scaled,
            noise_cov,
            max_iter=max_iter,
            pruning_thresh=pruning_thresh,
            conv_crit=conv_crit,
            update_rule="Convexity",
        )

        if len(active_set_1) == 0:
            return np.zeros((n_dipoles, n_chans))

        # === Pass 2: Refine on active atoms with fresh noise estimate ===
        # Use active atoms from pass 1 as the dictionary
        L_refined = L_ext[:, active_set_1]

        # Re-estimate noise from pass 1 residuals
        Sigma_y_1 = noise_cov + (L_refined * gammas_1) @ L_refined.T
        Sigma_y_1 = 0.5 * (Sigma_y_1 + Sigma_y_1.T)
        Sigma_y_1_inv = self._robust_inv(Sigma_y_1)
        mu_x_1 = (L_refined.T @ Sigma_y_1_inv @ Y_scaled) * gammas_1[:, None]
        residuals = Y_scaled - L_refined @ mu_x_1
        refined_noise = float(np.trace(residuals @ residuals.T) / (n_chans * n_times))
        refined_noise = max(refined_noise, 1e-10)
        noise_cov_2 = refined_noise * np.identity(n_chans)

        # Run second pass with MacKay, warm-started from pass 1 gammas
        active_set_2, gammas_2 = self._run_sbl(
            L_refined,
            Y_scaled,
            noise_cov_2,
            max_iter=max_iter,
            pruning_thresh=pruning_thresh,
            conv_crit=conv_crit,
            update_rule="MacKay",
            init_gammas=gammas_1.copy(),
        )

        # Map back to global indices
        final_active = active_set_1[active_set_2]
        final_gammas = gammas_2

        # Reconstruct: collapse per-order gammas, keeping the best order per dipole
        gammas_full = np.zeros(n_ext)
        gammas_full[final_active] = final_gammas

        gammas_per_order = gammas_full.reshape(n_orders, n_dipoles)
        best_order = np.argmax(gammas_per_order, axis=0)
        best_gamma = np.max(gammas_per_order, axis=0)

        active_dipoles = np.where(best_gamma > pruning_thresh * best_gamma.max())[0]

        if len(active_dipoles) == 0:
            return np.zeros((n_dipoles, n_chans))

        L_reduced = np.stack(
            [self.leadfields[best_order[d]][:, d] for d in active_dipoles], axis=1
        )
        gamma_reduced = best_gamma[active_dipoles]

        grad_cols = []
        for d in active_dipoles:
            g = self.gradients[best_order[d]][d].toarray().ravel()
            grad_cols.append(g)
        G = np.stack(grad_cols, axis=1)

        Sigma_y_r = noise_cov_2 + (L_reduced * gamma_reduced) @ L_reduced.T
        Sigma_y_r = 0.5 * (Sigma_y_r + Sigma_y_r.T)
        Sigma_y_r_inv = self._robust_inv(Sigma_y_r)
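        # Gamma-weighted backprojection; the gradient matrix G expands atom estimates to dipole space.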
        inv_op = G @ np.diag(gamma_reduced) @ L_reduced.T @ Sigma_y_r_inv

        return inv_op

    @staticmethod
    def _robust_inv(M):
        try:
            return np.linalg.inv(M)
        except np.linalg.LinAlgError:
            return np.linalg.pinv(M)
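
The dictionary construction in _prepare_flex can be illustrated in isolation. A minimal, self-contained sketch on a toy one-dimensional source space (the names and values below are illustrative, not part of the package):

import numpy as np
from scipy.sparse import csr_matrix, diags
from scipy.sparse.csgraph import laplacian

n_dipoles = 7
# Chain adjacency: dipole i neighbors i - 1 and i + 1.
adjacency = diags([np.ones(n_dipoles - 1)] * 2, offsets=[-1, 1], format="csr")
# One diffusion step, as in _prepare_flex with diffusion_parameter=0.1.
S = csr_matrix(np.identity(n_dipoles) - 0.1 * laplacian(adjacency).toarray())

point_source = np.zeros(n_dipoles)
point_source[3] = 1.0
for order in range(1, 4):
    atom = (S ** order) @ point_source
    print(f"order {order}:", np.round(atom, 3))
# Each order spreads the point source over a wider patch; the solver stacks the
# corresponding leadfields (leadfield @ S ** order) so the SBL passes can pick
# the spatial extent that best explains the data.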

__init__

__init__(
    name="FlexNLChampagne",
    n_orders=4,
    diffusion_parameter=0.1,
    adjacency_type="spatial",
    adjacency_distance=0.003,
    **kwargs,
)
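
Based on how these parameters are used in _prepare_flex: n_orders sets the number of smoothing orders added on top of the unsmoothed (identity) dictionary; diffusion_parameter is the step size of the smoothing operator I - diffusion_parameter * Laplacian; adjacency_type="spatial" builds the source-space graph with mne.spatial_src_adjacency, while any other value uses mne.spatial_dist_adjacency with adjacency_distance (in meters) as the neighborhood radius.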
Source code in invert/solvers/bayesian/flex_nl_champagne.py
def __init__(
    self,
    name="FlexNLChampagne",
    n_orders=4,
    diffusion_parameter=0.1,
    adjacency_type="spatial",
    adjacency_distance=3e-3,
    **kwargs,
):
    self.name = name
    self.n_orders = n_orders
    self.diffusion_parameter = diffusion_parameter
    self.adjacency_type = adjacency_type
    self.adjacency_distance = adjacency_distance
    self.is_prepared = False
    super().__init__(**kwargs)

make_inverse_operator

make_inverse_operator(
    forward,
    mne_obj,
    *args,
    alpha="auto",
    max_iter=2000,
    pruning_thresh=0.001,
    convergence_criterion=1e-08,
    **kwargs,
)
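
Based on their use in _run_sbl: max_iter caps the number of SBL iterations per pass; pruning_thresh is the relative gamma threshold below which atoms are pruned (annealed to full strength over the first 50 iterations); convergence_criterion is the relative decrease of the marginal-likelihood loss below which a pass stops.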
Source code in invert/solvers/bayesian/flex_nl_champagne.py
def make_inverse_operator(
    self,
    forward,
    mne_obj,
    *args,
    alpha="auto",
    max_iter=2000,
    pruning_thresh=1e-3,
    convergence_criterion=1e-8,
    **kwargs,
):
    super().make_inverse_operator(forward, *args, alpha=alpha, **kwargs)
    data = self.unpack_data_obj(mne_obj)

    if not self.is_prepared:
        self._prepare_flex()

    inverse_operator = self._two_pass_flex(
        data, pruning_thresh, max_iter, convergence_criterion
    )
    self.inverse_operators = [InverseOperator(inverse_operator, self.name)]
    return self