Skip to content

Exhaustive Subspace Maximum-Likelihood

Solver ID: ExhaustiveML

Usage

from invert import Solver

# fwd = ...    (mne.Forward object)
# evoked = ... (mne.Evoked object)

solver = Solver("ExhaustiveML")
solver.make_inverse_operator(fwd)
stc = solver.apply_inverse_operator(evoked)
stc.plot()

Overview

Source localization via exhaustive maximum-likelihood subset search (k<=3) with beam search extension (k>3) and BIC model order selection.

References

  1. Wax, M., & Kailath, T. (1985). Detection of signals by information theoretic criteria. IEEE Trans. ASSP, 33(2), 387-392.

API Reference

Bases: BaseSolver

Source localization via exhaustive maximum-likelihood subset search with BIC model order selection.

For k <= k_exhaustive (default 3), evaluates ALL C(n_sources, k) subsets. For k > k_exhaustive, uses beam search extending top-B solutions from k-1.

References

[1] Wax, M., & Kailath, T. (1985). Detection of signals by information theoretic criteria. IEEE Trans. ASSP, 33(2), 387-392.

Source code in invert/solvers/music/exhaustive_subspace_ml.py
class SolverExhaustiveSubspaceML(BaseSolver):
    """Source localization via exhaustive maximum-likelihood subset search
    with BIC model order selection.

    For k <= k_exhaustive (default 3), evaluates ALL C(n_sources, k) subsets.
    For k > k_exhaustive, uses beam search extending top-B solutions from k-1.

    References
    ----------
    [1] Wax, M., & Kailath, T. (1985). Detection of signals by information
        theoretic criteria. IEEE Trans. ASSP, 33(2), 387-392.
    """

    meta = SolverMeta(
        acronym="ExhaustiveML",
        full_name="Exhaustive Subspace Maximum-Likelihood",
        category="Subspace Methods",
        description=(
            "Source localization via exhaustive maximum-likelihood subset "
            "search (k<=3) with beam search extension (k>3) and BIC model "
            "order selection."
        ),
        references=[
            "Wax, M., & Kailath, T. (1985). Detection of signals by information theoretic criteria. IEEE Trans. ASSP, 33(2), 387-392.",
        ],
    )

    def __init__(self, name="ExhaustiveML", **kwargs):
        """Store the solver's display name and forward remaining options
        to ``BaseSolver``.

        Parameters
        ----------
        name : str
            Identifier attached to the resulting inverse operator.
        **kwargs
            Forwarded to ``BaseSolver.__init__``.
        """
        self.name = name
        super().__init__(**kwargs)

    def make_inverse_operator(
        self,
        forward,
        mne_obj=None,
        *args,
        alpha="auto",
        noise_cov: mne.Covariance | None = None,
        k_max=5,
        k_exhaustive=3,
        beam_width=50,
        penalty_mode="bic_per_timepoint",
        **kwargs,
    ):
        """Build the inverse operator by ML subset search with BIC selection.

        Parameters
        ----------
        forward : mne.Forward
            Forward model (leadfield).
        mne_obj : mne.Evoked | None
            Data object whose (whitened) sensor data drives the search.
        alpha : str | float
            Regularization setting passed to the base solver.
        noise_cov : mne.Covariance | None
            Optional noise covariance used for whitening.
        k_max : int
            Maximum number of active sources considered.
        k_exhaustive : int
            Largest subset size evaluated exhaustively (all C(n, k) subsets,
            subject to a 2M-subset feasibility cap).
        beam_width : int
            Number of top-scoring subsets kept as seeds for beam extension.
        penalty_mode : str
            One of ``"bic_per_timepoint"``, ``"aic"``, ``"bic_stacked"``,
            ``"bic_support_only"``; selects the model-order penalty.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``penalty_mode`` is not recognized.
        RuntimeError
            If no candidate subset could be evaluated (pathological case
            where even k=1 exceeds the exhaustive feasibility cap).
        """
        from math import comb

        super().make_inverse_operator(forward, mne_obj, *args, alpha=alpha, **kwargs)
        wf = self.prepare_whitened_forward(noise_cov)

        data = self.unpack_data_obj(mne_obj)
        data_w = wf.sensor_transform @ data  # (n_eff, T)
        G_w = wf.G_white  # (n_eff, n_sources)

        n_eff, n_sources = G_w.shape
        T = data_w.shape[1]

        # Validate the penalty mode up front (fail fast) rather than on the
        # first BIC evaluation inside the search loop.
        penalties = {
            "bic_per_timepoint": lambda k: k * T * np.log(n_eff),
            "aic": lambda k: 2 * k * T,
            "bic_stacked": lambda k: k * (T + 1) * np.log(n_eff * T),
            "bic_support_only": lambda k: k * np.log(n_eff * T),
        }
        if penalty_mode not in penalties:
            raise ValueError(f"Unknown penalty_mode: {penalty_mode}")
        penalty_fn = penalties[penalty_mode]

        Gram = G_w.T @ G_w
        # Ridge for numerical stability (rank-deficient when n_sources >> n_eff)
        eps_gram = 1e-10 * np.trace(Gram) / max(n_sources, 1)
        Gram[np.diag_indices_from(Gram)] += eps_gram
        R = G_w.T @ data_w
        Q = R @ R.T
        total_var = np.sum(data_w**2)

        def _compute_bic(scores, k):
            # Gaussian log-likelihood of the residual power plus order penalty.
            residuals = np.maximum(total_var - scores, 1e-30)
            return n_eff * T * np.log(residuals / (n_eff * T)) + penalty_fn(k)

        def _score_in_chunks(idx_array):
            # Score subset batches in bounded-size chunks to cap peak memory.
            chunks = [
                _batch_score(Gram, Q, idx_array[start : start + chunk_size])
                for start in range(0, len(idx_array), chunk_size)
            ]
            return np.concatenate(chunks)

        best_bic = np.inf
        best_subset = None
        chunk_size = 500_000

        top_subsets: list[tuple[int, ...]] = []

        # Limit exhaustive search to feasible sizes (< 2M subsets).
        _max_exhaustive_subsets = 2_000_000

        for k in range(1, min(k_max, n_sources) + 1):
            if k <= k_exhaustive and comb(n_sources, k) <= _max_exhaustive_subsets:
                # Exhaustive enumeration of all C(n_sources, k) subsets.
                subset_list = list(combinations(range(n_sources), k))
                scores = _score_in_chunks(np.array(subset_list, dtype=np.intp))
            else:
                # Beam search: extend each seed by one new source. Dedupe the
                # candidates first so subsets reachable from multiple parents
                # are scored only once (scores are deterministic).
                candidates: set[tuple[int, ...]] = set()
                for subset in top_subsets:
                    used = set(subset)
                    for j in range(n_sources):
                        if j not in used:
                            candidates.add(tuple(sorted(subset + (j,))))

                if not candidates:
                    break

                subset_list = list(candidates)
                scores = _score_in_chunks(np.array(subset_list, dtype=np.intp))

            bics = _compute_bic(scores, k)
            local_best = int(np.argmin(bics))
            if bics[local_best] < best_bic:
                best_bic = bics[local_best]
                best_subset = subset_list[local_best]

            # Keep the beam_width highest-scoring subsets as seeds for k+1.
            top_indices = np.argsort(scores)[-beam_width:]
            top_subsets = [subset_list[i] for i in top_indices]

        if best_subset is None:
            # Without this guard, np.array(None) below fails with a cryptic
            # dtype error when no subset was ever evaluated.
            raise RuntimeError(
                "Subset search evaluated no candidates; the source space may "
                "be too large for exhaustive k=1 enumeration."
            )

        selected_sources = np.array(best_subset, dtype=np.intp)

        # Least-squares inverse restricted to the selected leadfield columns;
        # all other source rows stay zero.
        G_sel = G_w[:, selected_sources]
        inv_w = np.linalg.lstsq(G_sel, np.eye(n_eff), rcond=None)[0]

        inverse_operator_w = np.zeros((n_sources, n_eff))
        inverse_operator_w[selected_sources, :] = inv_w

        self.inverse_operators = [
            InverseOperator(inverse_operator_w @ wf.sensor_transform, self.name),
        ]
        return self

__init__

__init__(name='ExhaustiveML', **kwargs)
Source code in invert/solvers/music/exhaustive_subspace_ml.py
def __init__(self, name="ExhaustiveML", **kwargs):
    """Store the solver's display name and forward remaining options.

    Parameters
    ----------
    name : str
        Identifier attached to the resulting inverse operator.
    **kwargs
        Forwarded to ``BaseSolver.__init__``.
    """
    self.name = name
    super().__init__(**kwargs)

make_inverse_operator

make_inverse_operator(
    forward,
    mne_obj=None,
    *args,
    alpha="auto",
    noise_cov: Covariance | None = None,
    k_max=5,
    k_exhaustive=3,
    beam_width=50,
    penalty_mode="bic_per_timepoint",
    **kwargs,
)
Source code in invert/solvers/music/exhaustive_subspace_ml.py
def make_inverse_operator(
    self,
    forward,
    mne_obj=None,
    *args,
    alpha="auto",
    noise_cov: mne.Covariance | None = None,
    k_max=5,
    k_exhaustive=3,
    beam_width=50,
    penalty_mode="bic_per_timepoint",
    **kwargs,
):
    """Build the inverse operator by ML subset search with BIC selection.

    Parameters
    ----------
    forward : mne.Forward
        Forward model (leadfield).
    mne_obj : mne.Evoked | None
        Data object whose (whitened) sensor data drives the search.
    alpha : str | float
        Regularization setting passed to the base solver.
    noise_cov : mne.Covariance | None
        Optional noise covariance used for whitening.
    k_max : int
        Maximum number of active sources considered.
    k_exhaustive : int
        Largest subset size evaluated exhaustively (all C(n, k) subsets,
        subject to a 2M-subset feasibility cap).
    beam_width : int
        Number of top-scoring subsets kept as seeds for beam extension.
    penalty_mode : str
        One of ``"bic_per_timepoint"``, ``"aic"``, ``"bic_stacked"``,
        ``"bic_support_only"``; selects the model-order penalty.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If ``penalty_mode`` is not recognized.
    RuntimeError
        If no candidate subset could be evaluated (pathological case where
        even k=1 exceeds the exhaustive feasibility cap).
    """
    from math import comb

    super().make_inverse_operator(forward, mne_obj, *args, alpha=alpha, **kwargs)
    wf = self.prepare_whitened_forward(noise_cov)

    data = self.unpack_data_obj(mne_obj)
    data_w = wf.sensor_transform @ data  # (n_eff, T)
    G_w = wf.G_white  # (n_eff, n_sources)

    n_eff, n_sources = G_w.shape
    T = data_w.shape[1]

    # Validate the penalty mode up front (fail fast) rather than on the
    # first BIC evaluation inside the search loop.
    penalties = {
        "bic_per_timepoint": lambda k: k * T * np.log(n_eff),
        "aic": lambda k: 2 * k * T,
        "bic_stacked": lambda k: k * (T + 1) * np.log(n_eff * T),
        "bic_support_only": lambda k: k * np.log(n_eff * T),
    }
    if penalty_mode not in penalties:
        raise ValueError(f"Unknown penalty_mode: {penalty_mode}")
    penalty_fn = penalties[penalty_mode]

    Gram = G_w.T @ G_w
    # Ridge for numerical stability (rank-deficient when n_sources >> n_eff)
    eps_gram = 1e-10 * np.trace(Gram) / max(n_sources, 1)
    Gram[np.diag_indices_from(Gram)] += eps_gram
    R = G_w.T @ data_w
    Q = R @ R.T
    total_var = np.sum(data_w**2)

    def _compute_bic(scores, k):
        # Gaussian log-likelihood of the residual power plus order penalty.
        residuals = np.maximum(total_var - scores, 1e-30)
        return n_eff * T * np.log(residuals / (n_eff * T)) + penalty_fn(k)

    def _score_in_chunks(idx_array):
        # Score subset batches in bounded-size chunks to cap peak memory.
        chunks = [
            _batch_score(Gram, Q, idx_array[start : start + chunk_size])
            for start in range(0, len(idx_array), chunk_size)
        ]
        return np.concatenate(chunks)

    best_bic = np.inf
    best_subset = None
    chunk_size = 500_000

    top_subsets: list[tuple[int, ...]] = []

    # Limit exhaustive search to feasible sizes (< 2M subsets).
    _max_exhaustive_subsets = 2_000_000

    for k in range(1, min(k_max, n_sources) + 1):
        if k <= k_exhaustive and comb(n_sources, k) <= _max_exhaustive_subsets:
            # Exhaustive enumeration of all C(n_sources, k) subsets.
            subset_list = list(combinations(range(n_sources), k))
            scores = _score_in_chunks(np.array(subset_list, dtype=np.intp))
        else:
            # Beam search: extend each seed by one new source. Dedupe the
            # candidates first so subsets reachable from multiple parents
            # are scored only once (scores are deterministic).
            candidates: set[tuple[int, ...]] = set()
            for subset in top_subsets:
                used = set(subset)
                for j in range(n_sources):
                    if j not in used:
                        candidates.add(tuple(sorted(subset + (j,))))

            if not candidates:
                break

            subset_list = list(candidates)
            scores = _score_in_chunks(np.array(subset_list, dtype=np.intp))

        bics = _compute_bic(scores, k)
        local_best = int(np.argmin(bics))
        if bics[local_best] < best_bic:
            best_bic = bics[local_best]
            best_subset = subset_list[local_best]

        # Keep the beam_width highest-scoring subsets as seeds for k+1.
        top_indices = np.argsort(scores)[-beam_width:]
        top_subsets = [subset_list[i] for i in top_indices]

    if best_subset is None:
        # Without this guard, np.array(None) below fails with a cryptic
        # dtype error when no subset was ever evaluated.
        raise RuntimeError(
            "Subset search evaluated no candidates; the source space may "
            "be too large for exhaustive k=1 enumeration."
        )

    selected_sources = np.array(best_subset, dtype=np.intp)

    # Least-squares inverse restricted to the selected leadfield columns;
    # all other source rows stay zero.
    G_sel = G_w[:, selected_sources]
    inv_w = np.linalg.lstsq(G_sel, np.eye(n_eff), rcond=None)[0]

    inverse_operator_w = np.zeros((n_sources, n_eff))
    inverse_operator_w[selected_sources, :] = inv_w

    self.inverse_operators = [
        InverseOperator(inverse_operator_w @ wf.sensor_transform, self.name),
    ]
    return self