Skip to content

SubspaceSBL (SSM + NL-Champagne)

Solver ID: SUBSPACE-SBL

Usage

from invert import Solver

# fwd = ...    (mne.Forward object)
# evoked = ... (mne.Evoked object)

solver = Solver("SUBSPACE-SBL")
solver.make_inverse_operator(fwd)
stc = solver.apply_inverse_operator(evoked)
stc.plot()

Overview

Two-stage solver that detects sources with signal subspace matching and refines amplitudes/noise parameters using NL-Champagne on the reduced problem.

References

  1. Lukas Hecker 2025, unpublished

API Reference

Bases: BaseSolver

SSM source detection + NLChampagne amplitude refinement.

Source code in invert/solvers/bayesian/subspace_sbl.py
class SolverSubspaceSBL(BaseSolver):
    """SSM source detection + NLChampagne amplitude refinement."""

    meta = SolverMeta(
        slug="subspace-sbl",
        full_name="SubspaceSBL (SSM + NL-Champagne)",
        category="Bayesian",
        description=(
            "Two-stage solver that detects sources with signal subspace matching "
            "and refines amplitudes/noise parameters using NL-Champagne on the "
            "reduced problem."
        ),
        references=[
            "Lukas Hecker 2025, unpublished",
        ],
    )

    def __init__(
        self,
        name="SubspaceSBL",
        n_orders=3,
        scale_leadfield=False,
        diffusion_parameter=0.1,
        adjacency_type="spatial",
        adjacency_distance=3e-3,
        **kwargs,
    ):
        self.name = name
        self.n_orders = n_orders
        self.scale_leadfield = scale_leadfield
        self.diffusion_parameter = diffusion_parameter
        self.adjacency_type = adjacency_type
        self.adjacency_distance = adjacency_distance
        self.is_prepared = False
        super().__init__(**kwargs)

    def make_inverse_operator(
        self,
        forward,
        mne_obj,
        *args,
        alpha="auto",
        n="enhanced",
        max_iter_ssm=5,
        max_iter_nlc=500,
        lambda_reg1=0.001,
        lambda_reg2=0.0001,
        lambda_reg3=0.0,
        pruning_thresh=1e-3,
        convergence_criterion=1e-8,
        **kwargs,
    ):
        super().make_inverse_operator(forward, *args, alpha=alpha, **kwargs)
        data = self.unpack_data_obj(mne_obj)

        if not self.is_prepared:
            self._prepare_flex()

        inverse_operator = self._ssm_nlc(
            data,
            n=n,
            max_iter_ssm=max_iter_ssm,
            max_iter_nlc=max_iter_nlc,
            lambda_reg1=lambda_reg1,
            lambda_reg2=lambda_reg2,
            lambda_reg3=lambda_reg3,
            pruning_thresh=pruning_thresh,
            conv_crit=convergence_criterion,
        )
        self.inverse_operators = [InverseOperator(inverse_operator, self.name)]
        return self

    # ================================================================
    # Stage 1: SSM source detection (exact copy of SSM algorithm)
    # ================================================================

    def _prepare_flex(self):
        n_dipoles = self.leadfield.shape[1]
        I = np.identity(n_dipoles)

        self.leadfields = [deepcopy(self.leadfield)]
        self.gradients = [csr_matrix(I)]

        if self.n_orders == 0:
            self.is_prepared = True
            return

        if self.adjacency_type == "spatial":
            adjacency = mne.spatial_src_adjacency(self.forward["src"], verbose=0)
        else:
            adjacency = mne.spatial_dist_adjacency(
                self.forward["src"], self.adjacency_distance, verbose=None
            )

        LL = laplacian(adjacency)
        if self.diffusion_parameter == "auto":
            alphas = [0.05, 0.075, 0.1, 0.125, 0.15, 0.175]
            smoothing_operators = [csr_matrix(I - a * LL) for a in alphas]
        else:
            smoothing_operators = [
                csr_matrix(I - self.diffusion_parameter * LL),
            ]

        for smoothing_operator in smoothing_operators:
            for i in range(self.n_orders):
                S_i = smoothing_operator ** (i + 1)
                new_lf = self.leadfields[0] @ S_i
                new_grad = self.gradients[0] @ S_i
                if self.scale_leadfield:
                    new_lf /= np.linalg.norm(new_lf, axis=0)
                self.leadfields.append(new_lf)
                self.gradients.append(new_grad)

        for i in range(len(self.gradients)):
            row_sums = self.gradients[i].sum(axis=1).ravel()
            scaling = 1.0 / np.maximum(np.abs(np.asarray(row_sums).ravel()), 1e-12)
            self.gradients[i] = csr_matrix(
                self.gradients[i].multiply(scaling.reshape(-1, 1))
            )

        self.is_prepared = True

    def _ssm_detect(
        self,
        Y,
        n="enhanced",
        max_iter=5,
        lambda_reg1=0.001,
        lambda_reg2=0.0001,
        lambda_reg3=0.0,
    ):
        """Run SSM to detect source locations and extents.

        Returns list of (order, dipole) tuples.
        """
        n_chans, n_dipoles = self.leadfield.shape
        n_time = Y.shape[1]
        leadfields = self.leadfields

        # Determine number of sources
        if isinstance(n, str):
            n_comp = self.estimate_n_sources(Y, method=n)
        else:
            n_comp = deepcopy(n)

        # Scale per channel type
        Y_work = deepcopy(Y)
        channel_types = self.forward["info"].get_channel_types()
        for ch_type in set(channel_types):
            sel = np.where(np.array(channel_types) == ch_type)[0]
            C_ch = Y_work[sel] @ Y_work[sel].T
            scaler = np.sqrt(np.trace(C_ch)) / C_ch.shape[0]
            Y_work[sel] /= scaler

        # SSM data projection matrix
        M_Y = Y_work.T @ Y_work
        YY = M_Y + lambda_reg1 * np.trace(M_Y) * np.eye(n_time)
        P_Y = (Y_work @ np.linalg.inv(YY)) @ Y_work.T
        C = P_Y.T @ P_Y

        P_A = np.zeros((n_chans, n_chans))

        S_SSM = []
        A_q = []

        # Initial source
        S_SSM.append(self._get_source_ssm(C, P_A, leadfields, lambda_reg=lambda_reg3))
        for _ in range(1, n_comp):
            order, location = S_SSM[-1]
            A_q.append(leadfields[order][:, location])
            P_A = self._compute_projection_matrix(A_q, lambda_reg=lambda_reg2)
            S_SSM.append(
                self._get_source_ssm(C, P_A, leadfields, S_SSM, lambda_reg=lambda_reg3)
            )
        A_q.append(leadfields[S_SSM[-1][0]][:, S_SSM[-1][1]])

        # Refinement phase
        S_SSM_2 = deepcopy(S_SSM)
        if len(S_SSM_2) > 1:
            S_prev = deepcopy(S_SSM_2)
            for _j in range(max_iter):
                A_q_j = A_q.copy()
                for qq in range(n_comp):
                    A_temp = np.delete(A_q_j, qq, axis=0)
                    qq_temp = np.delete(S_SSM_2, qq, axis=0)
                    P_A = self._compute_projection_matrix(
                        A_temp, lambda_reg=lambda_reg2
                    )
                    S_SSM_2[qq] = self._get_source_ssm(
                        C, P_A, leadfields, qq_temp, lambda_reg=lambda_reg3
                    )
                    A_q_j[qq] = leadfields[S_SSM_2[qq][0]][:, S_SSM_2[qq][1]]
                if S_SSM_2 == S_prev:
                    break
                S_prev = deepcopy(S_SSM_2)

        return S_SSM_2

    def _get_source_ssm(
        self,
        C,
        P_A,
        leadfields,
        q_ignore=None,
        lambda_reg=0.0,
    ):
        if q_ignore is None:
            q_ignore = []
        n_dipoles = leadfields[0].shape[1]
        n_orders = len(leadfields)

        R = np.eye(P_A.shape[0]) - P_A
        expression = np.zeros((n_orders, n_dipoles))

        for jj in range(n_orders):
            a_s = R @ leadfields[jj]
            upper = np.einsum("ij,ij->j", a_s, C @ a_s)
            lower = np.einsum("ij,ij->j", a_s, a_s) + lambda_reg
            expression[jj] = upper / lower

        if len(q_ignore) > 0:
            for order, dipole in q_ignore:
                expression[order, dipole] = np.nan

        order, dipole = np.unravel_index(np.nanargmax(expression), expression.shape)
        return order, dipole

    @staticmethod
    def _compute_projection_matrix(A_q, lambda_reg=0.0001):
        A_q = np.stack(A_q, axis=1)
        M_A = A_q.T @ A_q
        AA = M_A + lambda_reg * np.trace(M_A) * np.eye(M_A.shape[0])
        P_A = (A_q @ np.linalg.inv(AA)) @ A_q.T
        return P_A

    # ================================================================
    # Stage 2: NLChampagne amplitude refinement
    # ================================================================

    def _nlc_refine(
        self, Y, candidates, max_iter=500, pruning_thresh=1e-3, conv_crit=1e-8
    ):
        """Run NLChampagne on detected sources to refine amplitudes.

        Parameters
        ----------
        Y : array (n_chans, n_times)
        candidates : list of (order, dipole) tuples from SSM

        Returns
        -------
        gamma_refined : array of per-source variances
        llambda : array of per-channel noise variances
        """
        n_chans = Y.shape[0]
        n_times = Y.shape[1]
        k = len(candidates)

        # Build low-rank leadfield from detected sources
        L_sel = np.stack(
            [self.leadfields[order][:, dipole] for order, dipole in candidates], axis=1
        )  # (n_chans, k)

        # Scale data
        Y_scaled = deepcopy(Y)
        Y_scaled /= abs(Y_scaled).mean() + 1e-12

        # Initialize
        alpha = np.ones(k)
        C_y = self.data_covariance(Y_scaled, center=True, ddof=1)
        llambda = np.ones(n_chans) * float(np.trace(C_y) / (n_chans * 100))

        loss_list = []
        for _ in range(max_iter):
            prev_alpha = deepcopy(alpha)

            Sigma_y = (L_sel * alpha) @ L_sel.T + np.diag(llambda)
            Sigma_y = 0.5 * (Sigma_y + Sigma_y.T)
            try:
                Sigma_y_inv = np.linalg.inv(Sigma_y)
            except np.linalg.LinAlgError:
                Sigma_y_inv = np.linalg.pinv(Sigma_y)

            # Alpha update (Convexity/MM)
            s_bar = (L_sel.T @ Sigma_y_inv @ Y_scaled) * alpha[:, None]
            z_hat = np.sum(L_sel * (Sigma_y_inv @ L_sel), axis=0)
            C_s_bar = np.sum(s_bar**2, axis=1) / n_times
            alpha = np.sqrt(C_s_bar / (z_hat + 1e-20))
            alpha[~np.isfinite(alpha)] = 0.0
            alpha = np.maximum(alpha, 0.0)

            # Lambda update (Convex Bound)
            Y_hat = L_sel @ s_bar
            residual_sq = np.sum((Y_scaled - Y_hat) ** 2, axis=1) / n_times
            diag_inv = np.diag(Sigma_y_inv)
            llambda = np.sqrt(residual_sq / (diag_inv + 1e-20))
            llambda = np.maximum(llambda, 1e-10)

            # Convergence
            with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
                sign, log_det = np.linalg.slogdet(Sigma_y)
            if sign <= 0:
                log_det = -np.inf
            summation = (
                np.sum(np.einsum("ti,ij,tj->t", Y_scaled.T, Sigma_y_inv, Y_scaled.T))
                / n_times
            )
            loss = float(log_det + summation)
            loss_list.append(loss)

            if (
                loss == float("-inf")
                or loss == float("inf")
                or np.linalg.norm(alpha) == 0
            ):
                alpha = prev_alpha
                break

            if len(loss_list) > 1:
                change = abs(1 - loss_list[-1] / (loss_list[-2] + 1e-20))
                if change < conv_crit:
                    break

        return alpha, llambda

    # ================================================================
    # Combined pipeline
    # ================================================================

    def _ssm_nlc(
        self,
        Y,
        n="enhanced",
        max_iter_ssm=5,
        max_iter_nlc=500,
        lambda_reg1=0.001,
        lambda_reg2=0.0001,
        lambda_reg3=0.0,
        pruning_thresh=1e-3,
        conv_crit=1e-8,
    ):
        n_chans, n_dipoles = self.leadfield.shape

        # Stage 1: SSM detection
        candidates = self._ssm_detect(
            Y,
            n=n,
            max_iter=max_iter_ssm,
            lambda_reg1=lambda_reg1,
            lambda_reg2=lambda_reg2,
            lambda_reg3=lambda_reg3,
        )

        # Stage 2: NLChampagne refinement
        gamma, llambda = self._nlc_refine(
            Y,
            candidates,
            max_iter=max_iter_nlc,
            pruning_thresh=pruning_thresh,
            conv_crit=conv_crit,
        )

        # Build final inverse operator
        L_sel = np.stack(
            [self.leadfields[order][:, dipole] for order, dipole in candidates], axis=1
        )
        gradients = np.stack(
            [self.gradients[order][dipole].toarray() for order, dipole in candidates],
            axis=1,
        )[0]

        # Use SBL-refined source covariance instead of identity
        Gamma = np.diag(gamma)
        Sigma_y = np.diag(llambda) + (L_sel * gamma) @ L_sel.T
        Sigma_y = 0.5 * (Sigma_y + Sigma_y.T)
        try:
            Sigma_y_inv = np.linalg.inv(Sigma_y)
        except np.linalg.LinAlgError:
            Sigma_y_inv = np.linalg.pinv(Sigma_y)

        inverse_operator = gradients.T @ Gamma @ L_sel.T @ Sigma_y_inv

        return inverse_operator

    @staticmethod
    def _robust_inv(M):
        try:
            return np.linalg.inv(M)
        except np.linalg.LinAlgError:
            return np.linalg.pinv(M)

__init__

__init__(
    name="SubspaceSBL",
    n_orders=3,
    scale_leadfield=False,
    diffusion_parameter=0.1,
    adjacency_type="spatial",
    adjacency_distance=0.003,
    **kwargs,
)
Source code in invert/solvers/bayesian/subspace_sbl.py
def __init__(
    self,
    name="SubspaceSBL",
    n_orders=3,
    scale_leadfield=False,
    diffusion_parameter=0.1,
    adjacency_type="spatial",
    adjacency_distance=3e-3,
    **kwargs,
):
    self.name = name
    self.n_orders = n_orders
    self.scale_leadfield = scale_leadfield
    self.diffusion_parameter = diffusion_parameter
    self.adjacency_type = adjacency_type
    self.adjacency_distance = adjacency_distance
    self.is_prepared = False
    super().__init__(**kwargs)

make_inverse_operator

make_inverse_operator(
    forward,
    mne_obj,
    *args,
    alpha="auto",
    n="enhanced",
    max_iter_ssm=5,
    max_iter_nlc=500,
    lambda_reg1=0.001,
    lambda_reg2=0.0001,
    lambda_reg3=0.0,
    pruning_thresh=0.001,
    convergence_criterion=1e-08,
    **kwargs,
)
Source code in invert/solvers/bayesian/subspace_sbl.py
def make_inverse_operator(
    self,
    forward,
    mne_obj,
    *args,
    alpha="auto",
    n="enhanced",
    max_iter_ssm=5,
    max_iter_nlc=500,
    lambda_reg1=0.001,
    lambda_reg2=0.0001,
    lambda_reg3=0.0,
    pruning_thresh=1e-3,
    convergence_criterion=1e-8,
    **kwargs,
):
    super().make_inverse_operator(forward, *args, alpha=alpha, **kwargs)
    data = self.unpack_data_obj(mne_obj)

    if not self.is_prepared:
        self._prepare_flex()

    inverse_operator = self._ssm_nlc(
        data,
        n=n,
        max_iter_ssm=max_iter_ssm,
        max_iter_nlc=max_iter_nlc,
        lambda_reg1=lambda_reg1,
        lambda_reg2=lambda_reg2,
        lambda_reg3=lambda_reg3,
        pruning_thresh=pruning_thresh,
        conv_crit=convergence_criterion,
    )
    self.inverse_operators = [InverseOperator(inverse_operator, self.name)]
    return self