class SolverOmniChampagne(BaseSolver):
    """Adaptive Champagne/Flex-like solver via model selection.

    Fits two sparse Bayesian models — a dipole-only leadfield dictionary
    and a multi-order diffusion patch dictionary — and keeps the inverse
    operator of whichever achieves the lower penalized evidence score
    (loss + complexity_penalty * number of active atoms).
    """

    # Registry metadata consumed by the surrounding solver framework
    # (SolverMeta is a project-declared type).
    meta = SolverMeta(
        slug="omni-champagne",
        full_name="OmniChampagne",
        category="Bayesian",
        description=(
            "Adaptive sparse Bayesian solver that selects between a dipole-only "
            "Champagne-style model and a multi-order patch-dictionary model based "
            "on penalized evidence."
        ),
        references=[
            "Lukas Hecker 2025, unpublished",
        ],
    )
def __init__(
    self,
    name: str = "OmniChampagne",
    # Patch basis settings (match simulator defaults)
    n_orders: int = 3,
    diffusion_parameter: float = 0.1,
    adjacency_type: str = "spatial",
    adjacency_distance: float = 3e-3,
    # SBL settings
    update_rule: str = "MacKay",
    max_iter: int = 2000,
    pruning_thresh: float = 1e-3,
    convergence_criterion: float = 1e-8,
    # Model selection (penalize extra active atoms)
    complexity_penalty: float = 0.2,
    **kwargs,
) -> None:
    """Store solver hyper-parameters, then defer to ``BaseSolver``.

    Every numeric/string hyper-parameter is coerced to its canonical
    type exactly once before being attached to the instance; remaining
    keyword arguments are forwarded untouched to the base class.
    """
    self.name = name
    # Coerce each hyper-parameter and bind it as an instance attribute
    # (insertion order matches the original assignment order).
    coerced = {
        "n_orders": int(n_orders),
        "diffusion_parameter": float(diffusion_parameter),
        "adjacency_type": str(adjacency_type),
        "adjacency_distance": float(adjacency_distance),
        "update_rule": str(update_rule),
        "max_iter": int(max_iter),
        "pruning_thresh": float(pruning_thresh),
        "convergence_criterion": float(convergence_criterion),
        "complexity_penalty": float(complexity_penalty),
    }
    for attribute, value in coerced.items():
        setattr(self, attribute, value)
    super().__init__(**kwargs)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def make_inverse_operator(  # type: ignore[override]
    self,
    forward,
    mne_obj,
    *args,
    alpha: str | float = "auto",
    noise_cov: mne.Covariance | None = None,
    **kwargs,
):
    """Fit the adaptive inverse operator and store it on the instance.

    Parameters
    ----------
    forward : forward model, passed through to ``BaseSolver``.
    mne_obj : data container unpacked via ``self.unpack_data_obj``.
    alpha : regularization setting forwarded to ``BaseSolver``.
    noise_cov : optional noise covariance handed to
        ``self.prepare_whitened_forward``.

    Returns
    -------
    ``self``, with ``self.inverse_operators`` set to a one-element list.
    """
    # Base-class bookkeeping/preparation must run before fitting.
    super().make_inverse_operator(forward, mne_obj, *args, alpha=alpha, **kwargs)
    # NOTE(review): `wf.sensor_transform` appears to be a sensor-space
    # whitening matrix produced from `noise_cov` — confirm in BaseSolver.
    wf = self.prepare_whitened_forward(noise_cov)
    data = self.unpack_data_obj(mne_obj)
    data = wf.sensor_transform @ data
    # Model selection + SBL fitting happen on the transformed data.
    inv_op = self._fit_and_build_inverse_operator(data)
    # Pre-compose with the sensor transform so the stored operator can be
    # applied directly to untransformed sensor data.
    self.inverse_operators = [
        InverseOperator(inv_op @ wf.sensor_transform, self.name)
    ]
    return self
# ------------------------------------------------------------------
# Model selection + inverse construction
# ------------------------------------------------------------------
def _fit_and_build_inverse_operator(self, Y: np.ndarray) -> np.ndarray:
    """Fit both candidate models on ``Y`` and return the winning operator.

    Model A runs SBL on the raw dipole leadfield; model B runs SBL on a
    multi-order patch dictionary (order-0 identity excluded). The model
    with the lower penalized loss wins.

    Parameters
    ----------
    Y : sensor data; assumed shape (n_chans, n_times) from the matrix
        products below — NOTE(review): confirm with caller.

    Returns
    -------
    Inverse operator in dipole space, shape (n_dipoles, n_chans).
    """
    n_chans, n_dipoles = self.leadfield.shape
    # Noise estimate from data covariance
    # NOTE(review): trace(C_y)/n_chans is the *total* average sensor
    # power (signal + noise), i.e. a deliberately coarse noise level.
    C_y = self.data_covariance(Y, center=True, ddof=1)
    alpha_noise = float(np.trace(C_y) / n_chans)
    noise_cov = alpha_noise * np.eye(n_chans)
    Y_scaled = Y  # no additional scaling applied; alias kept as-is
    # Model A: dipole-only dictionary
    L_dip = self.leadfield
    fit_dip = self._fit_sbl(
        L_dip,
        Y_scaled,
        noise_cov=noise_cov,
        max_iter=self.max_iter,
        pruning_thresh=self.pruning_thresh,
        conv_crit=self.convergence_criterion,
        update_rule=self.update_rule,
    )
    W_dip = self._inverse_from_fit(L_dip, fit_dip, noise_cov=noise_cov)
    # Model B: patch dictionary (multi-order diffusion basis)
    sources_full = self._build_simulator_basis(n_dipoles)
    # For the patch model we intentionally exclude the order-0 (identity)
    # basis because we evaluate a separate dipole-only model above.
    sources = (
        sources_full[n_dipoles:, :]
        if sources_full.shape[0] > n_dipoles
        else sources_full
    )
    L_patch = self.leadfield @ sources.T  # (n_chans, n_candidates)
    fit_patch = self._fit_sbl(
        L_patch,
        Y_scaled,
        noise_cov=noise_cov,
        max_iter=self.max_iter,
        pruning_thresh=self.pruning_thresh,
        conv_crit=self.convergence_criterion,
        update_rule=self.update_rule,
    )
    W_patch_coeff = self._inverse_from_fit(L_patch, fit_patch, noise_cov=noise_cov)
    # Map coefficients back to dipole space (sources.T is (n_dipoles, n_candidates))
    # Restricting to the active set is exact: inactive rows of
    # W_patch_coeff are zero by construction in _inverse_from_fit.
    W_patch = (
        sources.T[:, fit_patch.active_set] @ W_patch_coeff[fit_patch.active_set]
    )
    # Penalized evidence for selection
    score_dip = fit_dip.loss + self.complexity_penalty * float(
        len(fit_dip.active_set)
    )
    score_patch = fit_patch.loss + self.complexity_penalty * float(
        len(fit_patch.active_set)
    )
    # Strictly lower score wins; ties favor the dipole-only model.
    if score_patch < score_dip:
        return W_patch
    return W_dip
# ------------------------------------------------------------------
# Basis construction (match SimulationGenerator logic)
# ------------------------------------------------------------------
def _build_simulator_basis(self, n_dipoles: int) -> np.ndarray:
    """Return stacked multi-order basis S with shape (n_candidates, n_dipoles).

    Block ``k`` (k = 0 .. n_orders-1) is the k-step diffusion of the
    identity basis through ``G = I - diffusion_parameter * laplacian(A)``,
    with each block max-normalized column-wise (floored at 1e-12 to
    avoid division by ~0), matching the SimulationGenerator logic.

    Parameters
    ----------
    n_dipoles : number of source dipoles (columns of the basis).

    Returns
    -------
    Dense array of shape (n_orders * n_dipoles, n_dipoles), or the
    plain identity when ``n_orders <= 1``.
    """
    if self.n_orders <= 1:
        return np.eye(n_dipoles)
    I = np.eye(n_dipoles)
    # Source-space adjacency: project helper for the "spatial" flavor,
    # MNE's distance-based adjacency otherwise.
    if self.adjacency_type == "spatial":
        adjacency = build_source_adjacency(
            self.forward["src"],
            adjacency_type="spatial",
            adjacency_distance=self.adjacency_distance,
            verbose=0,
        )
    else:
        adjacency = mne.spatial_dist_adjacency(
            self.forward["src"], self.adjacency_distance, verbose=None
        )
    adjacency = csr_matrix(adjacency)
    # One-step graph-diffusion operator.
    G = csr_matrix(I - self.diffusion_parameter * laplacian(adjacency))
    # Propagate the previous *normalized* block through G each round.
    # Perf fix: the old loop densified the entire growing stack every
    # iteration (`sources.toarray()[-n:, -n:]`) merely to recover its
    # last block — the stack always has exactly n_dipoles columns, so
    # carrying the block forward yields identical values without the
    # O(orders^2 * n^2) densification, and we vstack once at the end.
    blocks = [csr_matrix(I)]
    prev_block = blocks[0]
    for _ in range(1, self.n_orders):
        propagated = prev_block @ G
        # Column-wise max-normalization, floored to avoid division by ~0.
        col_max = propagated.max(axis=0).toarray().ravel()
        col_max = np.maximum(col_max, 1e-12)
        prev_block = csr_matrix(propagated / col_max[np.newaxis])
        blocks.append(prev_block)
    return vstack(blocks).toarray()
# ------------------------------------------------------------------
# Sparse Bayesian learning core
# ------------------------------------------------------------------
def _fit_sbl(
    self,
    L_orig: np.ndarray,
    Y_scaled: np.ndarray,
    *,
    noise_cov: np.ndarray,
    max_iter: int,
    pruning_thresh: float,
    conv_crit: float,
    update_rule: str,
) -> _SBLFit:
    """Run the shared SBL iteration and repackage its result as ``_SBLFit``."""
    raw = sbl_iterate(
        L=L_orig,
        Y=Y_scaled,
        noise_cov=noise_cov,
        update_rule=update_rule,
        max_iter=max_iter,
        pruning_thresh=pruning_thresh,
        conv_crit=conv_crit,
    )
    # Normalize dtypes once so downstream indexing/arithmetic is predictable.
    active = raw.active_set.astype(int, copy=False)
    scales = raw.gammas.astype(float, copy=False)
    return _SBLFit(
        active_set=active,
        gammas=scales,
        sigma_y_inv=raw.Sigma_y_inv,
        loss=raw.loss,
    )
@staticmethod
def _inverse_from_fit(
    L: np.ndarray, fit: _SBLFit, *, noise_cov: np.ndarray
) -> np.ndarray:
    """Build full inverse operator W (n_atoms, n_chans) for a given dictionary.

    Atoms outside ``fit.active_set`` get gamma = 0, so their rows of W
    are exactly zero.
    """
    n_chans, n_atoms = L.shape
    # Scatter the pruned gammas back onto the full atom grid.
    gammas_full = np.zeros(n_atoms, dtype=float)
    gammas_full[fit.active_set] = fit.gammas
    # Sensor-space model covariance: noise + L Gamma L^T.
    model_cov = noise_cov + (L * gammas_full) @ L.T
    # Force exact symmetry before inversion.
    model_cov = 0.5 * (model_cov + model_cov.T)
    model_cov_inv = SolverOmniChampagne._robust_inv(model_cov)
    # Posterior-mean operator: Gamma L^T Sigma_y^{-1}.
    return np.diag(gammas_full) @ L.T @ model_cov_inv  # (n_atoms, n_chans)
@staticmethod
def _robust_inv(M: np.ndarray) -> np.ndarray:
    """Invert ``M``, falling back to the pseudo-inverse when singular."""
    try:
        inverse = np.linalg.inv(M)
    except np.linalg.LinAlgError:
        # Singular matrix: use the Moore-Penrose pseudo-inverse instead.
        inverse = np.linalg.pinv(M)
    return inverse