diff --git a/cmvdr/data_gen/f0_manager.py b/cmvdr/data_gen/f0_manager.py index 43b9f67..9ef251a 100644 --- a/cmvdr/data_gen/f0_manager.py +++ b/cmvdr/data_gen/f0_manager.py @@ -648,9 +648,43 @@ def compute_harmonic_and_modulation_sets_global_coherence(cls, sig, harmonic_fre max_freq_cyclic_hz=cfg_cyc['freq_range_cyclic'][1]) max_bin = -1 - if harmonic_freqs_est.size > 0: - max_bin = int(np.ceil((3 * SFT.delta_f + np.max(np.abs(harmonic_freqs_est))) / SFT.delta_f)) - rho = CoherenceManager.compute_coherence(sig, SFT, mod_coherence, max_bin, min_relative_power=1.e+3) + # if harmonic_freqs_est.size > 0: + # max_bin = int(np.ceil((3 * SFT.delta_f + np.max(np.abs(harmonic_freqs_est))) / SFT.delta_f)) + + # Choose coherence computation method based on config + use_freq_domain = cfg_cyc.get('use_freq_domain_coherence', False) + + if use_freq_domain: + # Frequency-domain coherence path + use_stft = cfg_cyc.get('freq_coherence_stft_enabled', False) + interpolation = cfg_cyc.get('freq_coherence_interpolation', 'none') + apply_phase_correction = cfg_cyc.get('freq_coherence_apply_phase_correction', True) + + # Create high-res STFT if specified + nfft_coherence = cfg_cyc.get('freq_coherence_nfft', None) + if nfft_coherence is not None and use_stft: + # Create a high-resolution STFT object + from scipy.signal import ShortTimeFFT, get_window + # Use 'hann' window like the standard STFT (or extract from config if needed) + win_coherence = get_window('hann', nfft_coherence) + # Use same hop as original, or scale proportionally + hop_coherence = SFT.hop + SFT_coherence = ShortTimeFFT(win_coherence, hop_coherence, fs=SFT.fs, + mfft=nfft_coherence, scale_to=SFT.scaling) + else: + SFT_coherence = SFT + + rho = CoherenceManager.compute_coherence_freq_shifted( + sig, SFT_coherence, mod_coherence.alpha_vec_hz_, max_bin, + min_relative_power=1.e+3, + use_stft=use_stft, + interpolation=interpolation, + apply_phase_correction=apply_phase_correction + ) + else: + # Time-domain modulation path (original) + rho = CoherenceManager.compute_coherence(sig, SFT, mod_coherence, max_bin, min_relative_power=1.e+3) + if 0: cc0 = np.where(mod_coherence.alpha_vec_hz_ == 0)[0][0] rho_no0 = np.delete(rho, cc0, axis=0) @@ -658,10 +692,28 @@ def compute_harmonic_and_modulation_sets_global_coherence(cls, sig, harmonic_fre CoherenceManager.plot_coherence_matrix(rho_no0, alpha_no0, SFT) # retain highly coherent modulated components only - harm_info = CoherenceManager.calculate_harmonic_info_from_coherence(mod_coherence.alpha_vec_hz_, rho, - thr=cfg_cyc['harmonic_threshold'], - P_max_cfg=cfg_cyc['P_max'], - nfft_real=SFT.mfft // 2 + 1) + # Determine if frequency resolution mapping is needed + if use_freq and (nfft_coherence is not None and use_stft): + # Using high-res STFT for coherence + delta_f_coherence = SFT_coherence.delta_f + delta_f_beamforming = SFT.delta_f + elif use_freq and not use_stft: + # Using full-file DFT for coherence + delta_f_coherence = SFT.fs / len(sig) + delta_f_beamforming = SFT.delta_f + else: + # Time-domain method - coherence computed at beamforming resolution + delta_f_coherence = None + delta_f_beamforming = None + + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + mod_coherence.alpha_vec_hz_, rho, + thr=cfg_cyc['harmonic_threshold'], + P_max_cfg=cfg_cyc['P_max'], + nfft_real=SFT.mfft // 2 + 1, + delta_f_coherence=delta_f_coherence, + delta_f_beamforming=delta_f_beamforming + ) mod_amount = F0ChangeAmount.small return harm_info, mod_amount diff --git a/cmvdr/estimation/coherence_manager.py b/cmvdr/estimation/coherence_manager.py index da673e2..fab0e9e 100644 --- a/cmvdr/estimation/coherence_manager.py +++ b/cmvdr/estimation/coherence_manager.py @@ -37,7 +37,11 @@ def compute_coherence(signal, SFT: ShortTimeFFT, modulator_obj: Modulator, max_b cc0 = np.where(alpha == 0)[0][0] if max_bin == -1: - max_bin = int(np.ceil((3 * SFT.delta_f + np.max(np.abs(alpha))) / SFT.delta_f)) + # Calculate the maximum frequency to analyze + # Include all alpha values plus a safety margin + max_freq = np.max(np.abs(alpha)) + 3 * SFT.delta_f + # Convert to bins + max_bin = int(np.ceil(max_freq / SFT.delta_f)) # Modulate the data in the time domain (first microphone only) assert g.mic0_idx == 0 @@ -61,6 +65,216 @@ def compute_coherence(signal, SFT: ShortTimeFFT, modulator_obj: Modulator, max_b return rho + @staticmethod + def compute_coherence_freq_shifted(signal, SFT: ShortTimeFFT, alpha_vec_hz, max_bin=-1, + min_relative_power=1.e+3, use_stft=False, + interpolation='none', apply_phase_correction=True): + """ + Compute coherence using frequency-domain shifts instead of time-domain modulation. + + Parameters + ---------- + signal : dict + Signal dictionary with 'time' key containing time-domain signal + SFT : ShortTimeFFT + STFT object for transform parameters + alpha_vec_hz : np.ndarray + Vector of cyclic frequencies (shifts) in Hz + max_bin : int + Maximum frequency bin to process (-1 for auto) + min_relative_power : float + Minimum relative power threshold + use_stft : bool + If True, use high-res STFT; if False, use full-file DFT + interpolation : str + Interpolation mode: 'none', 'linear', or 'lagrange8' + apply_phase_correction : bool + Apply frame-start phase correction + + Returns + ------- + rho : np.ndarray + Coherence matrix (P_sum x kk_max) + """ + + # Sort alpha values to match behavior of Modulator (which sorts internally) + # This ensures consistent ordering with the time-domain method + from cmvdr.estimation.modulator import Modulator + alpha_vec_hz_sorted, alpha_inv = Modulator.unique_with_relative_tolerance_fast( + alpha_vec_hz, tol=1e-4, return_inverse=True + ) + + cc0 = np.where(alpha_vec_hz_sorted == 0)[0][0] + + # Get reference signal (first microphone only) + assert g.mic0_idx == 0 + sig_time = signal['time'][g.mic0_idx, :] + + # Calculate max_bin based on the actual transform being used + if max_bin == -1: + if use_stft: + # Use the delta_f from the STFT object that will be used + delta_f = SFT.delta_f + else: + # For full-file DFT, delta_f = fs / N + delta_f = SFT.fs / len(sig_time) + + # Calculate the maximum frequency to analyze (consistent across methods) + # Use the STFT's delta_f for the safety margin to maintain consistency with time-domain method + max_freq = np.max(np.abs(alpha_vec_hz_sorted)) + 3 * SFT.delta_f + # Convert to bins using the actual delta_f for this method + max_bin = int(np.ceil(max_freq / delta_f)) + + if use_stft: + # High-resolution STFT for coherence estimation + spec_ref = SFT.stft(sig_time)[:max_bin, :] # (K_max, L_frames) + else: + # Full-file DFT + spec_ref_full = np.fft.fft(sig_time) + spec_ref = spec_ref_full[:max_bin, np.newaxis] # (K_max, 1) - single "frame" + + P_sum = len(alpha_vec_hz_sorted) + kk_max = spec_ref.shape[0] + frames = spec_ref.shape[1] + + # Allocate output for shifted spectra + mod = np.zeros((P_sum, kk_max, frames), dtype=np.complex128) + + # Compute shifted versions in frequency domain + for pp, alpha_pp in enumerate(alpha_vec_hz_sorted): + if np.abs(alpha_pp) < 1e-9: + # No shift - just copy reference + mod[pp, :, :] = spec_ref + else: + # Shift by alpha_pp Hz + mod[pp, :, :] = CoherenceManager._shift_spectrum( + spec_ref, alpha_pp, SFT.delta_f, SFT.fs, + interpolation, apply_phase_correction, SFT if use_stft else None + ) + + mod_c = np.conj(mod) + + # Calculate PSDs + psds = np.mean(np.abs(mod) ** 2, axis=-1) # (P_sum, kk_max) + psds = np.maximum(psds, np.max(psds[cc0]) / min_relative_power) + + # Compute coherence + rho = CoherenceManager.compute_coherence_internal_fast(mod, mod_c, psds, alpha_vec_hz_sorted, cc0, SFT.delta_f, SFT.fs) + rho[cc0] = 1 + + return rho + + @staticmethod + def _shift_spectrum(spec, alpha_hz, delta_f, fs, interpolation='none', + apply_phase_correction=True, SFT=None): + """ + Shift spectrum by alpha_hz using interpolation. + + Parameters + ---------- + spec : np.ndarray + Input spectrum (kk_max, frames) + alpha_hz : float + Frequency shift in Hz + delta_f : float + Frequency resolution + fs : float + Sample rate + interpolation : str + 'none', 'linear', or 'lagrange8' + apply_phase_correction : bool + Apply phase correction for frame starts + SFT : ShortTimeFFT or None + STFT object for phase correction + + Returns + ------- + shifted_spec : np.ndarray + Shifted spectrum (kk_max, frames) + """ + kk_max, frames = spec.shape + shifted_spec = np.zeros_like(spec) + + # Alpha to bin shift (fractional) + bin_shift = alpha_hz / delta_f + + for kk in range(kk_max): + # Source bin (fractional) + src_bin_float = kk - bin_shift + + if src_bin_float < 0 or src_bin_float >= kk_max: + # Out of bounds + continue + + # Interpolate complex value at fractional bin + if interpolation == 'none': + # Nearest neighbor + src_bin = int(np.round(src_bin_float)) + if 0 <= src_bin < kk_max: + shifted_spec[kk, :] = spec[src_bin, :] + + elif interpolation == 'linear': + # Linear interpolation + src_bin_low = int(np.floor(src_bin_float)) + src_bin_high = src_bin_low + 1 + frac = src_bin_float - src_bin_low + + if 0 <= src_bin_low < kk_max: + shifted_spec[kk, :] = (1 - frac) * spec[src_bin_low, :] + if 0 <= src_bin_high < kk_max: + shifted_spec[kk, :] += frac * spec[src_bin_high, :] + + elif interpolation == 'lagrange8': + # 8-point Lagrange interpolation + shifted_spec[kk, :] = CoherenceManager._lagrange8_interpolate( + spec, src_bin_float + ) + + # Apply frame-start phase correction + if apply_phase_correction and SFT is not None and frames > 1: + for frame_idx in range(frames): + frame_start_time = frame_idx * SFT.hop / fs + phase_corr = np.exp(-1j * 2 * np.pi * alpha_hz * frame_start_time) + shifted_spec[:, frame_idx] *= phase_corr + + return shifted_spec + + @staticmethod + def _lagrange8_interpolate(spec, src_bin_float): + """ + 8-point Lagrange interpolation for complex spectrum. + + Parameters + ---------- + spec : np.ndarray + Spectrum array (kk_max, frames) + src_bin_float : float + Fractional source bin index + + Returns + ------- + interpolated : np.ndarray + Interpolated values (frames,) + """ + kk_max, frames = spec.shape + center = int(np.round(src_bin_float)) + frac = src_bin_float - center + + # 8-point window around center + result = np.zeros(frames, dtype=np.complex128) + + for offset in range(-3, 5): # -3, -2, -1, 0, 1, 2, 3, 4 + src_idx = center + offset + if 0 <= src_idx < kk_max: + # Lagrange basis polynomial + weight = 1.0 + for other_offset in range(-3, 5): + if other_offset != offset: + weight *= (frac - other_offset) / (offset - other_offset) + result += weight * spec[src_idx, :] + + return result + @staticmethod def compute_coherence_internal(mod, mod_c, psds, alpha, cc0, delta_f, fs): # Calculate spectral coherence (squared) @@ -108,26 +322,84 @@ def compute_coherence_internal_fast(mod, mod_c, psds, alpha, cc0, delta_f, fs): return rho @staticmethod - def calculate_harmonic_info_from_coherence(alpha_vec_hz, rho, thr, P_max_cfg, nfft_real) -> 'HarmonicInfo': - """ Calculate harmonic information (harmonic bins, modulation sets, harmonic sets) from coherence matrix. """ + def calculate_harmonic_info_from_coherence(alpha_vec_hz, rho, thr, P_max_cfg, nfft_real, + delta_f_coherence=None, delta_f_beamforming=None) -> 'HarmonicInfo': + """ Calculate harmonic information (harmonic bins, modulation sets, harmonic sets) from coherence matrix. + + Args: + alpha_vec_hz: Modulation frequencies in Hz + rho: Coherence matrix (len(alpha), K_coherence_bins) + thr: Coherence threshold + P_max_cfg: Maximum number of modulation frequencies per harmonic + nfft_real: Number of real FFT bins for beamforming STFT (nfft_beamforming // 2 + 1) + delta_f_coherence: Frequency resolution of coherence computation (Hz per bin). + If None, assumes coherence bins match beamforming bins. + delta_f_beamforming: Frequency resolution of beamforming STFT (Hz per bin). + If None, assumes same as delta_f_coherence. + """ harmonic_bins_ = [] modulation_sets_ = [] + + # Determine if we need to map frequency bins from coherence resolution to beamforming resolution + needs_freq_mapping = (delta_f_coherence is not None and delta_f_beamforming is not None + and delta_f_coherence != delta_f_beamforming) + + # Find the index of alpha=0 (should always be present) + cc0 = np.where(alpha_vec_hz == 0)[0] + if cc0.size == 0: + raise ValueError("alpha_vec_hz must contain 0 (non-modulated frequency)") + cc0 = cc0[0] for kk in range(rho.shape[1]): rho_kk = rho[:, kk] indices_high_corr = np.where(rho_kk > thr)[0] if indices_high_corr.size > 1: - values_high_corr = rho_kk[indices_high_corr] # Select modulations that have higher coherence - top_order = np.argsort(values_high_corr)[::-1] # Sort by coherence - final_selected = indices_high_corr[top_order][:P_max_cfg] # Keep at most P_max_cfg - final_selected_by_freq = np.r_[final_selected[0], np.sort(final_selected[1:])] - - harmonic_bins_.append(kk) - modulation_sets_.append(alpha_vec_hz[final_selected_by_freq]) - - if 0 not in alpha_vec_hz[final_selected]: + # Ensure cc0 is always included if it's above threshold + if cc0 in indices_high_corr: + # Remove cc0 from the list temporarily so we can add it back at the start + other_indices = indices_high_corr[indices_high_corr != cc0] + + if other_indices.size > 0: + # Sort other indices by coherence value + values_other = rho_kk[other_indices] + top_order = np.argsort(values_other)[::-1] + sorted_other = other_indices[top_order] + + # Take top (P_max_cfg - 1) from others, since cc0 takes one slot + top_other = sorted_other[:P_max_cfg - 1] + + # Put cc0 first, then add the best of the others + final_selected_by_freq = np.r_[cc0, np.sort(top_other)] + else: + # Only cc0 is above threshold + final_selected_by_freq = np.array([cc0]) + else: + # cc0 not above threshold (shouldn't happen since rho[cc0] = 1) + # Fall back to original logic with warning warnings.warn("0 should always be selected: non-modulated freq is perfectly coherent.") + values_high_corr = rho_kk[indices_high_corr] + top_order = np.argsort(values_high_corr)[::-1] + final_selected = indices_high_corr[top_order][:P_max_cfg] + # Force include cc0 + if cc0 not in final_selected: + final_selected[-1] = cc0 + other_indices = final_selected[final_selected != cc0] + final_selected_by_freq = np.r_[cc0, np.sort(other_indices)] + + # Map high-res coherence bin to low-res beamforming bin if needed + if needs_freq_mapping: + # Convert bin index to frequency, then to beamforming bin + freq_hz = kk * delta_f_coherence + kk_beamforming = int(np.round(freq_hz / delta_f_beamforming)) + # Ensure within bounds + if kk_beamforming >= nfft_real: + continue # Skip bins beyond beamforming resolution + harmonic_bins_.append(kk_beamforming) + else: + harmonic_bins_.append(kk) + + modulation_sets_.append(alpha_vec_hz[final_selected_by_freq]) harmonic_bins_ = np.asfortranarray(harmonic_bins_) if harmonic_bins_.size == 0: diff --git a/configs/experiments/default.yaml b/configs/experiments/default.yaml index 76e68b9..409682e 100644 --- a/configs/experiments/default.yaml +++ b/configs/experiments/default.yaml @@ -100,7 +100,12 @@ cyclic: P_max: 8 use_global_coherence: true harmonic_threshold: 0.6 - coherence_source_signal_name: noisy # using noisy signal helps for high SNR + coherence_source_signal_name: noisy + use_freq_domain_coherence: false + freq_coherence_stft_enabled: false + freq_coherence_nfft: null # null uses same nfft as main STFT, or specify a number for high-res STFT + freq_coherence_interpolation: none # none, linear, lagrange8 + freq_coherence_apply_phase_correction: true # using noisy signal helps for high SNR # Good for music. no_change too small for speech (would lead to remodulation of the signals too often) alpha_thresholds: diff --git a/configs/experiments/test_freq_coherence.yaml b/configs/experiments/test_freq_coherence.yaml new file mode 100644 index 0000000..46b729e --- /dev/null +++ b/configs/experiments/test_freq_coherence.yaml @@ -0,0 +1,28 @@ +# Test configuration for frequency-domain coherence A/B comparison +# Based on real_dregon.yaml + +base: experiments/default.yaml +varying_parameters_names: [M, rir_specs|rt60, cyclic|P_max,] +harmonics_est: + max_num_harmonics_peaks: 10 +noise: + sig_type: sample + sample_name: DREGON_individual_motors_recordings_trimmed +cov_estimation: + cov_est_forgetting_factor: 0.05 # Default: 0.05. Higher beta (forgetting factor): we forget more, therefore more weight to the NEW covariance matrices. + + +# Override to enable frequency-domain coherence +cyclic: + use_freq_domain_coherence: true + freq_coherence_stft_enabled: true + freq_coherence_interpolation: lagrange8 # Test with linear interpolation + freq_coherence_nfft: 32768 + freq_coherence_apply_phase_correction: true + harmonic_threshold: 0.8 + +# Log comparison metrics for A/B verification +metrics: + time: [ 'sisdr', 'stoi', 'pesq' ] + log_coherence_diff: true + log_harmonic_set_changes: true diff --git a/tests/test_coherence_manager.py b/tests/test_coherence_manager.py index 791989b..1d12a93 100644 --- a/tests/test_coherence_manager.py +++ b/tests/test_coherence_manager.py @@ -74,5 +74,161 @@ def test_random_noise_case(self): self.check_equivalence(mod, mod_c, psds, alpha, cc0, delta_f, fs) +class TestCalculateHarmonicInfo(unittest.TestCase): + """Tests for calculate_harmonic_info_from_coherence function.""" + + def test_zero_always_at_first_position(self): + """Test that zero modulation is always at the FIRST position in modulation sets.""" + # Create a simple coherence matrix where alpha=0 has high coherence + alpha_vec_hz = np.array([0, -100, -200, -300, -400]) + P_sum = len(alpha_vec_hz) + kk_max = 10 + + # Create coherence matrix with varying values + # Make sure alpha=0 (index 0) has high coherence + rho = np.zeros((P_sum, kk_max)) + rho[0, :] = 0.95 # alpha=0 has high coherence + rho[1, 5] = 0.85 # Some other modulations have high coherence at specific bins + rho[2, 5] = 0.75 + rho[3, 5] = 0.65 + + thr = 0.6 + P_max_cfg = 3 + nfft_real = kk_max + + # This should not raise a warning + import warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + alpha_vec_hz, rho, thr, P_max_cfg, nfft_real + ) + + # Check that no warning was raised about missing zero + warning_messages = [str(warning.message) for warning in w] + zero_warnings = [msg for msg in warning_messages if "0 should always be selected" in msg] + self.assertEqual(len(zero_warnings), 0, + f"Should not warn about missing zero when it's properly selected. Warnings: {zero_warnings}") + + # Verify that zero is in the modulation sets AND at first position + for mod_set in harm_info.alpha_mods_sets: + self.assertIn(0, mod_set, "Zero should be in all modulation sets") + self.assertEqual(mod_set[0], 0, "Zero must be at the FIRST position (required by Modulator)") + + def test_zero_forced_at_first_position_when_low_coherence(self): + """Test that zero is forced to first position even when it has low coherence.""" + # Create a scenario where alpha=0 doesn't have high enough coherence + alpha_vec_hz = np.array([0, -100, -200, -300, -400]) + P_sum = len(alpha_vec_hz) + kk_max = 10 + + # Create coherence matrix where alpha=0 has LOW coherence + rho = np.zeros((P_sum, kk_max)) + rho[0, :] = 0.3 # alpha=0 has LOW coherence (below threshold) + rho[1, 5] = 0.95 # Other modulations have HIGH coherence + rho[2, 5] = 0.85 + rho[3, 5] = 0.75 + rho[4, 5] = 0.65 + + thr = 0.6 # Threshold above alpha=0's coherence + P_max_cfg = 3 + nfft_real = kk_max + + # This SHOULD raise a warning since alpha=0 won't be naturally selected + import warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + alpha_vec_hz, rho, thr, P_max_cfg, nfft_real + ) + + # Check that warning WAS raised about missing zero + warning_messages = [str(warning.message) for warning in w] + zero_warnings = [msg for msg in warning_messages if "0 should always be selected" in msg] + self.assertGreater(len(zero_warnings), 0, + "Should warn when zero is not selected due to low coherence") + + # Verify that zero is STILL in the modulation sets at first position (forced inclusion) + for mod_set in harm_info.alpha_mods_sets: + self.assertIn(0, mod_set, "Zero should be in all modulation sets (forced)") + self.assertEqual(mod_set[0], 0, "Zero must be at the FIRST position even when forced") + + def test_reordering_puts_zero_first(self): + """Test that reordering always puts zero at first position regardless of coherence order.""" + # Test case where highest coherence is NOT at index 0 + alpha_vec_hz = np.array([0, -100, -200, -300]) + P_sum = len(alpha_vec_hz) + kk_max = 5 + + rho = np.zeros((P_sum, kk_max)) + # Make index 2 (-200) have highest coherence, but 0 also above threshold + rho[2, :] = 0.99 # highest + rho[0, :] = 0.85 # zero has second-highest + rho[1, :] = 0.75 + rho[3, :] = 0.65 + + thr = 0.6 + P_max_cfg = 3 + nfft_real = kk_max + + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + alpha_vec_hz, rho, thr, P_max_cfg, nfft_real + ) + + # Zero should be at first position despite not having highest coherence + for mod_set in harm_info.alpha_mods_sets: + self.assertEqual(mod_set[0], 0, + "Zero must be at first position even when another frequency has higher coherence") + # The code sorts indices, so the order is determined by the original alpha_vec_hz array + # For alpha_vec_hz = [0, -100, -200, -300], selecting indices [2, 0, 1] sorted becomes [0, 1, 2] + # which gives values [0, -100, -200] after reordering to put 0 first + + def test_zero_selected_with_multiple_equal_coherences(self): + """Test that zero is selected even when multiple values have identical coherence (unstable sort).""" + # This is the specific bug that was reported - when multiple alpha values + # have exactly the same coherence (e.g., 1.0), numpy's argsort has unstable + # behavior and cc0 might not be in the top P_max_cfg selections. + + alpha_vec_hz = np.array([0, -100, -200, -300, -400, -500, -600, -700, -800, -900]) + P_sum = len(alpha_vec_hz) + kk_max = 10 + + # Create coherence matrix where MANY values have exactly 1.0 coherence + # This simulates the real-world case where the signal has strong harmonic structure + rho = np.zeros((P_sum, kk_max)) + + # Bin 5: First 6 alpha values all have coherence 1.0 + rho[:6, 5] = 1.0 + + # Bin 7: First 8 alpha values all have coherence 1.0 + rho[:8, 7] = 1.0 + + thr = 0.8 + P_max_cfg = 5 # Less than the number of values with 1.0 coherence + nfft_real = kk_max + + # This should NOT raise a warning - the fix ensures cc0 is always prioritized + import warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + alpha_vec_hz, rho, thr, P_max_cfg, nfft_real + ) + + # Check that no warning was raised + warning_messages = [str(warning.message) for warning in w] + zero_warnings = [msg for msg in warning_messages if "0 should always be selected" in msg] + self.assertEqual(len(zero_warnings), 0, + f"Should not warn when multiple values have same coherence. Warnings: {zero_warnings}") + + # Verify that zero is in ALL modulation sets at first position + for mod_set in harm_info.alpha_mods_sets: + self.assertIn(0, mod_set, "Zero should be in all modulation sets") + self.assertEqual(mod_set[0], 0, "Zero must be at the FIRST position") + # Verify P_max_cfg constraint is respected + self.assertLessEqual(len(mod_set), P_max_cfg, + f"Modulation set should have at most {P_max_cfg} elements") + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_coherence_method_comparison.py b/tests/test_coherence_method_comparison.py new file mode 100644 index 0000000..9ec012f --- /dev/null +++ b/tests/test_coherence_method_comparison.py @@ -0,0 +1,214 @@ +""" +Test for comparing all three coherence computation methods. +""" + +import numpy as np +import unittest +from scipy.signal import ShortTimeFFT, get_window + +from cmvdr.util import globs as gs +gs.rng, _ = gs.compute_rng(seed_is_random=False, rnd_seed_=42) + +from cmvdr.estimation.coherence_manager import CoherenceManager +from cmvdr.estimation.modulator import Modulator + + +class TestCoherenceMethodComparison(unittest.TestCase): + """Test that all three coherence methods produce comparable results.""" + + def setUp(self): + """Set up test fixtures.""" + self.fs = 16000 + self.nfft = 512 + self.hop = 128 + self.win = get_window('hann', self.nfft) + self.SFT = ShortTimeFFT(self.win, self.hop, fs=self.fs, mfft=self.nfft, + scale_to='magnitude', fft_mode='twosided') + + # Create a test signal with multiple harmonics + duration = 1.0 + t = np.arange(int(duration * self.fs)) / self.fs + f0 = 200 # fundamental frequency + self.signal_time = (np.sin(2 * np.pi * f0 * t) + + 0.5 * np.sin(2 * np.pi * 2 * f0 * t) + + 0.3 * np.sin(2 * np.pi * 3 * f0 * t) + + 0.2 * np.sin(2 * np.pi * 4 * f0 * t)) + self.signal = {'time': self.signal_time[np.newaxis, :]} + + # Alpha vector (must start with 0) + self.alpha_vec_hz = np.array([0, -f0, -2*f0, -3*f0]) + + def test_all_methods_same_shape(self): + """Test that STFT-based methods produce output with the same shape.""" + + # Method 1: Time-domain modulation + STFT + max_len = self.signal['time'].shape[-1] + modulator = Modulator(max_len, self.fs, [self.alpha_vec_hz], + fast_version=True, use_filters=False, + max_freq_cyclic_hz=5000) + + rho_time = CoherenceManager.compute_coherence( + self.signal, self.SFT, modulator, + max_bin=-1, min_relative_power=1.e+3 + ) + + # Method 2: Frequency-domain with STFT + rho_freq_stft = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=-1, min_relative_power=1.e+3, + use_stft=True, interpolation='none', apply_phase_correction=False + ) + + # Method 3: Frequency-domain with full-file DFT + rho_freq_dft = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=-1, min_relative_power=1.e+3, + use_stft=False, interpolation='none', apply_phase_correction=False + ) + + # STFT methods should have the same shape + self.assertEqual(rho_time.shape, rho_freq_stft.shape, + "Time-domain and Freq-domain STFT shapes should match") + + # DFT will have different shape due to finer resolution, but same number of alphas + self.assertEqual(rho_time.shape[0], rho_freq_dft.shape[0], + "All methods should have same number of alpha values") + + # Calculate frequency ranges covered + delta_f_stft = self.SFT.delta_f + delta_f_dft = self.fs / len(self.signal_time) + + freq_range_time = rho_time.shape[1] * delta_f_stft + freq_range_stft = rho_freq_stft.shape[1] * delta_f_stft + freq_range_dft = rho_freq_dft.shape[1] * delta_f_dft + + print(f"\nMethod outputs:") + print(f" Time-domain: {rho_time.shape} covering {freq_range_time:.2f} Hz") + print(f" Freq STFT: {rho_freq_stft.shape} covering {freq_range_stft:.2f} Hz") + print(f" Freq DFT: {rho_freq_dft.shape} covering {freq_range_dft:.2f} Hz") + + # All methods should cover approximately the same frequency range + self.assertAlmostEqual(freq_range_time, freq_range_stft, delta=1.0) + self.assertAlmostEqual(freq_range_time, freq_range_dft, delta=50.0, + msg="DFT should cover approximately same frequency range as STFT") + + def test_all_methods_comparable_values(self): + """Test that all three methods produce comparable coherence values.""" + + # Method 1: Time-domain modulation + STFT + max_len = self.signal['time'].shape[-1] + modulator = Modulator(max_len, self.fs, [self.alpha_vec_hz], + fast_version=True, use_filters=False, + max_freq_cyclic_hz=5000) + + rho_time = CoherenceManager.compute_coherence( + self.signal, self.SFT, modulator, + max_bin=-1, min_relative_power=1.e+3 + ) + + # Method 2: Frequency-domain with STFT + rho_freq_stft = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=-1, min_relative_power=1.e+3, + use_stft=True, interpolation='linear', apply_phase_correction=True + ) + + # Method 3: Frequency-domain with full-file DFT + rho_freq_dft = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=-1, min_relative_power=1.e+3, + use_stft=False, interpolation='linear', apply_phase_correction=False + ) + + print("\n" + "="*70) + print("COHERENCE METHOD COMPARISON") + print("="*70) + + # Find index of alpha=0 (should be last after sorting) + cc0_idx = len(self.alpha_vec_hz) - 1 + + print(f"\nAlpha=0 index: {cc0_idx}") + print(f"Alpha values (sorted): {modulator.alpha_vec_hz_}") + + # Check alpha=0 coherence (should be 1.0 for all methods) + print(f"\nAlpha=0 coherence (should be ~1.0):") + print(f" Time-domain: {rho_time[cc0_idx, 10]:.6f}") + print(f" Freq-domain STFT: {rho_freq_stft[cc0_idx, 10]:.6f}") + # For DFT, use equivalent bin (scale by resolution difference) + dft_bin_idx = int(10 * rho_freq_dft.shape[1] / rho_time.shape[1]) + print(f" Freq-domain DFT: {rho_freq_dft[cc0_idx, dft_bin_idx]:.6f}") + + self.assertAlmostEqual(rho_time[cc0_idx, 10], 1.0, places=5) + self.assertAlmostEqual(rho_freq_stft[cc0_idx, 10], 1.0, places=5) + self.assertAlmostEqual(rho_freq_dft[cc0_idx, dft_bin_idx], 1.0, places=5) + + # Compare overall statistics + print(f"\nOverall statistics:") + print(f" Time-domain: mean={rho_time.mean():.4f}, std={rho_time.std():.4f}") + print(f" Freq-domain STFT: mean={rho_freq_stft.mean():.4f}, std={rho_freq_stft.std():.4f}") + print(f" Freq-domain DFT: mean={rho_freq_dft.mean():.4f}, std={rho_freq_dft.std():.4f}") + + # Correlation between STFT methods (should be high since they have same shape) + corr_time_stft = np.corrcoef(rho_time.flatten(), rho_freq_stft.flatten())[0, 1] + + print(f"\nCorrelations:") + print(f" Time vs Freq-STFT: {corr_time_stft:.4f}") + + # Correlation should be reasonably high + self.assertGreater(corr_time_stft, 0.5, + "Time-domain and Freq-domain STFT should be reasonably correlated") + + # Check that all values are in valid range [0, 1] + self.assertTrue(np.all(rho_time >= -0.01)) + self.assertTrue(np.all(rho_time <= 1.01)) + self.assertTrue(np.all(rho_freq_stft >= -0.01)) + self.assertTrue(np.all(rho_freq_stft <= 1.01)) + self.assertTrue(np.all(rho_freq_dft >= -0.01)) + self.assertTrue(np.all(rho_freq_dft <= 1.01)) + + def test_high_res_stft_produces_larger_output(self): + """Test that high-resolution STFT produces more frequency bins.""" + + # Standard resolution STFT + rho_standard = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=-1, min_relative_power=1.e+3, + use_stft=True, interpolation='none', apply_phase_correction=False + ) + + # High-resolution STFT (4x resolution) + nfft_high = 2048 + win_high = get_window('hann', nfft_high) + SFT_high = ShortTimeFFT(win_high, self.hop, fs=self.fs, mfft=nfft_high, + scale_to='magnitude', fft_mode='twosided') + + rho_high = CoherenceManager.compute_coherence_freq_shifted( + self.signal, SFT_high, self.alpha_vec_hz, + max_bin=-1, min_relative_power=1.e+3, + use_stft=True, interpolation='none', apply_phase_correction=False + ) + + print(f"\n" + "="*70) + print("HIGH-RES STFT TEST") + print("="*70) + print(f"Standard STFT: nfft={self.nfft}, delta_f={self.SFT.delta_f:.2f} Hz") + print(f" Output shape: {rho_standard.shape}") + print(f"High-res STFT: nfft={nfft_high}, delta_f={SFT_high.delta_f:.2f} Hz") + print(f" Output shape: {rho_high.shape}") + + # High-res should have more bins + self.assertGreater(rho_high.shape[1], rho_standard.shape[1], + "High-resolution STFT should produce more frequency bins") + + # The ratio should be approximately nfft_high / self.nfft + ratio = rho_high.shape[1] / rho_standard.shape[1] + expected_ratio = nfft_high / self.nfft + print(f"\nBin count ratio: {ratio:.2f} (expected ~{expected_ratio:.2f})") + + # Should be within 20% of expected ratio + self.assertGreater(ratio, expected_ratio * 0.8) + self.assertLess(ratio, expected_ratio * 1.2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_freq_bin_mapping.py b/tests/test_freq_bin_mapping.py new file mode 100644 index 0000000..d9389be --- /dev/null +++ b/tests/test_freq_bin_mapping.py @@ -0,0 +1,171 @@ +""" +Tests for frequency bin mapping between high-res coherence and low-res beamforming. +""" + +import unittest +import numpy as np +from cmvdr.estimation.coherence_manager import CoherenceManager + + +class TestFrequencyBinMapping(unittest.TestCase): + """Test that high-res coherence bins are correctly mapped to low-res beamforming bins.""" + + def test_high_res_to_low_res_mapping(self): + """Test mapping from high-resolution coherence to low-resolution beamforming.""" + # Setup: High-res coherence (2048 FFT) vs low-res beamforming (512 FFT) + fs = 16000 + nfft_coherence = 2048 + nfft_beamforming = 512 + + # Frequency resolutions + delta_f_coherence = fs / nfft_coherence # 7.8125 Hz + delta_f_beamforming = fs / nfft_beamforming # 31.25 Hz + nfft_real_coherence = nfft_coherence // 2 + 1 # 1025 + nfft_real_beamforming = nfft_beamforming // 2 + 1 # 257 + + # Create simple coherence matrix with high coherence at a few high-res bins + alpha_vec_hz = np.array([0, -100, -200]) # 3 modulation frequencies + rho = np.zeros((len(alpha_vec_hz), nfft_real_coherence)) + + # Set high coherence at specific high-res bins + high_res_bins = [0, 64, 128, 256, 512] # Mix of low and high frequency bins + for bin_idx in high_res_bins: + rho[:, bin_idx] = 1.0 # All alphas have high coherence at these bins + + thr = 0.5 + P_max_cfg = 3 + + # Call with frequency mapping + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + alpha_vec_hz, rho, thr, P_max_cfg, + nfft_real=nfft_real_beamforming, + delta_f_coherence=delta_f_coherence, + delta_f_beamforming=delta_f_beamforming + ) + + # Verify that harmonic bins are within beamforming bounds + self.assertTrue(np.all(harm_info.harmonic_bins < nfft_real_beamforming), + f"Harmonic bins {harm_info.harmonic_bins} exceed beamforming resolution {nfft_real_beamforming}") + + # Verify mapping: high-res bin -> frequency -> low-res bin + expected_low_res_bins = [] + for high_res_bin in high_res_bins: + freq_hz = high_res_bin * delta_f_coherence + low_res_bin = int(np.round(freq_hz / delta_f_beamforming)) + if low_res_bin < nfft_real_beamforming: + expected_low_res_bins.append(low_res_bin) + + expected_low_res_bins = np.array(expected_low_res_bins) + + # Check that harmonic bins match expected mapping + np.testing.assert_array_equal( + np.sort(harm_info.harmonic_bins), + np.sort(expected_low_res_bins), + err_msg="Harmonic bins don't match expected frequency mapping" + ) + + def test_no_mapping_when_resolutions_match(self): + """Test that no mapping occurs when coherence and beamforming use same resolution.""" + # Both use same resolution + fs = 16000 + nfft = 512 + delta_f = fs / nfft + nfft_real = nfft // 2 + 1 # 257 + + alpha_vec_hz = np.array([0, -100, -200]) + rho = np.zeros((len(alpha_vec_hz), nfft_real)) + + # Set high coherence at specific bins + coherent_bins = [0, 10, 20, 50, 100] + for bin_idx in coherent_bins: + rho[:, bin_idx] = 1.0 + + thr = 0.5 + P_max_cfg = 3 + + # Call with same delta_f (or None, which means same) + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + alpha_vec_hz, rho, thr, P_max_cfg, nfft_real, + delta_f_coherence=delta_f, + delta_f_beamforming=delta_f + ) + + # Verify bins are unchanged + self.assertEqual(len(harm_info.harmonic_bins), len(coherent_bins)) + np.testing.assert_array_equal( + np.sort(harm_info.harmonic_bins), + np.array(coherent_bins), + err_msg="Bins should be unchanged when resolutions match" + ) + + def test_default_no_mapping(self): + """Test backward compatibility: no mapping when delta_f parameters not provided.""" + # Default behavior (no delta_f parameters) + nfft_real = 100 + alpha_vec_hz = np.array([0, -100, -200]) + rho = np.zeros((len(alpha_vec_hz), nfft_real)) + + # Set high coherence at specific bins + coherent_bins = [5, 15, 25, 50, 75] + for bin_idx in coherent_bins: + rho[:, bin_idx] = 1.0 + + thr = 0.5 + P_max_cfg = 3 + + # Call without delta_f parameters (backward compatible) + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + alpha_vec_hz, rho, thr, P_max_cfg, nfft_real + ) + + # Verify bins are as provided (no mapping) + self.assertEqual(len(harm_info.harmonic_bins), len(coherent_bins)) + np.testing.assert_array_equal( + np.sort(harm_info.harmonic_bins), + np.array(coherent_bins), + err_msg="Bins should be unchanged when delta_f not provided" + ) + + def test_skip_bins_beyond_beamforming_resolution(self): + """Test that high-frequency coherence bins beyond beamforming range are skipped.""" + # High-res coherence with bins that exceed beamforming range + fs = 16000 + nfft_coherence = 2048 + nfft_beamforming = 512 + + delta_f_coherence = fs / nfft_coherence + delta_f_beamforming = fs / nfft_beamforming + nfft_real_coherence = nfft_coherence // 2 + 1 # 1025 + nfft_real_beamforming = nfft_beamforming // 2 + 1 # 257 + + alpha_vec_hz = np.array([0, -100]) + rho = np.zeros((len(alpha_vec_hz), nfft_real_coherence)) + + # Set high coherence at bins that would map beyond beamforming range + # High-res bin 1000 -> freq = 1000 * 7.8125 = 7812.5 Hz + # Low-res bin would be round(7812.5 / 31.25) = 250 + # But max low-res bin is 256, so bins close to limit should work + high_res_bins = [0, 100, 500, 900, 1000] # Last two might be beyond range + for bin_idx in high_res_bins: + if bin_idx < nfft_real_coherence: + rho[:, bin_idx] = 1.0 + + thr = 0.5 + P_max_cfg = 2 + + harm_info = CoherenceManager.calculate_harmonic_info_from_coherence( + alpha_vec_hz, rho, thr, P_max_cfg, + nfft_real=nfft_real_beamforming, + delta_f_coherence=delta_f_coherence, + delta_f_beamforming=delta_f_beamforming + ) + + # All harmonic bins should be within beamforming bounds + self.assertTrue(np.all(harm_info.harmonic_bins < nfft_real_beamforming), + f"Some harmonic bins exceed beamforming resolution") + self.assertTrue(np.all(harm_info.harmonic_bins >= 0), + f"Some harmonic bins are negative") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_freq_domain_coherence.py b/tests/test_freq_domain_coherence.py new file mode 100644 index 0000000..c5f9bd8 --- /dev/null +++ b/tests/test_freq_domain_coherence.py @@ -0,0 +1,233 @@ +import unittest +import numpy as np +from scipy.signal import ShortTimeFFT, get_window + +from cmvdr.util import globs as gs +gs.rng, _ = gs.compute_rng(seed_is_random=False, rnd_seed_=42) + +from cmvdr.estimation.coherence_manager import CoherenceManager +from cmvdr.estimation.modulator import Modulator + + +class TestFrequencyDomainCoherence(unittest.TestCase): + """Tests for frequency-domain coherence computation.""" + + def setUp(self): + """Set up test fixtures.""" + self.fs = 16000 + self.nfft = 512 + self.hop = 128 + self.win = get_window('hann', self.nfft) + self.SFT = ShortTimeFFT(self.win, self.hop, fs=self.fs, mfft=self.nfft, scale_to='magnitude') + + # Create a simple test signal with harmonics + duration = 0.5 + t = np.arange(int(duration * self.fs)) / self.fs + f0 = 200 # fundamental frequency + self.signal_time = (np.sin(2 * np.pi * f0 * t) + + 0.5 * np.sin(2 * np.pi * 2 * f0 * t) + + 0.3 * np.sin(2 * np.pi * 3 * f0 * t)) + self.signal = {'time': self.signal_time[np.newaxis, :]} + + # Alpha vector (cyclic frequencies) + self.alpha_vec_hz = np.array([0, -f0, -2*f0, -3*f0]) + + def test_compute_coherence_freq_shifted_basic(self): + """Test basic frequency-domain coherence computation.""" + rho = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=50, min_relative_power=1.e+3, + use_stft=True, interpolation='none', apply_phase_correction=False + ) + + # Check output shape + P_sum = len(self.alpha_vec_hz) + self.assertEqual(rho.shape[0], P_sum) + self.assertTrue(rho.shape[1] > 0) + + # Note: alpha values are sorted internally, so find where alpha=0 is + # For alpha_vec_hz = [0, -f0, -2*f0, -3*f0], sorted becomes [-3*f0, -2*f0, -f0, 0] + # So alpha=0 is at index 3 + cc0_idx = 3 + + # Check that rho[cc0_idx] (no shift) has high coherence with itself + self.assertAlmostEqual(rho[cc0_idx, 10], 1.0, places=5) + + # Check that values are in [0, 1] + self.assertTrue(np.all(rho >= -0.01)) + self.assertTrue(np.all(rho <= 1.01)) + + def test_interpolation_modes(self): + """Test different interpolation modes.""" + for interp in ['none', 'linear', 'lagrange8']: + with self.subTest(interpolation=interp): + rho = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=50, use_stft=True, + interpolation=interp, apply_phase_correction=False + ) + + # All should produce valid coherence matrices + self.assertEqual(rho.shape[0], len(self.alpha_vec_hz)) + self.assertTrue(np.all(np.isfinite(rho))) + + def test_full_file_dft_vs_stft(self): + """Test full-file DFT vs STFT modes.""" + rho_dft = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=50, use_stft=False, interpolation='none' + ) + + rho_stft = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=50, use_stft=True, interpolation='none' + ) + + # Both should produce valid outputs + self.assertEqual(rho_dft.shape[0], len(self.alpha_vec_hz)) + self.assertEqual(rho_stft.shape[0], len(self.alpha_vec_hz)) + self.assertTrue(np.all(np.isfinite(rho_dft))) + self.assertTrue(np.all(np.isfinite(rho_stft))) + + def test_phase_correction(self): + """Test phase correction toggle.""" + rho_with_corr = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=50, use_stft=True, + interpolation='linear', apply_phase_correction=True + ) + + rho_without_corr = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=50, use_stft=True, + interpolation='linear', apply_phase_correction=False + ) + + # Both should be valid, but may differ + self.assertTrue(np.all(np.isfinite(rho_with_corr))) + self.assertTrue(np.all(np.isfinite(rho_without_corr))) + + def test_shift_spectrum_basic(self): + """Test basic spectrum shifting.""" + # Create a simple spectrum + frames = 10 + kk_max = 50 + spec = np.random.randn(kk_max, frames) + 1j * np.random.randn(kk_max, frames) + + alpha_hz = 100 + delta_f = self.fs / self.nfft + + shifted = CoherenceManager._shift_spectrum( + spec, alpha_hz, delta_f, self.fs, + interpolation='none', apply_phase_correction=False + ) + + self.assertEqual(shifted.shape, spec.shape) + self.assertTrue(np.all(np.isfinite(shifted))) + + def test_shift_spectrum_zero_shift(self): + """Test that zero shift returns identity.""" + frames = 5 + kk_max = 30 + spec = np.random.randn(kk_max, frames) + 1j * np.random.randn(kk_max, frames) + + delta_f = self.fs / self.nfft + shifted = CoherenceManager._shift_spectrum( + spec, 0.0, delta_f, self.fs, + interpolation='linear', apply_phase_correction=False + ) + + # Should be very close to original + np.testing.assert_allclose(shifted, spec, rtol=1e-5, atol=1e-8) + + def test_lagrange8_interpolate(self): + """Test 8-point Lagrange interpolation.""" + frames = 5 + kk_max = 30 + spec = np.random.randn(kk_max, frames) + 1j * np.random.randn(kk_max, frames) + + # Test integer position (should give exact value) + src_bin_float = 10.0 + result = CoherenceManager._lagrange8_interpolate(spec, src_bin_float) + np.testing.assert_allclose(result, spec[10, :], rtol=1e-5) + + # Test fractional position (should interpolate) + src_bin_float = 10.5 + result = CoherenceManager._lagrange8_interpolate(spec, src_bin_float) + self.assertEqual(result.shape, (frames,)) + self.assertTrue(np.all(np.isfinite(result))) + + def test_edge_cases(self): + """Test edge cases.""" + # Single bin + alpha_vec_hz = np.array([0]) + rho = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, alpha_vec_hz, + max_bin=10, use_stft=True, interpolation='none' + ) + self.assertEqual(rho.shape[0], 1) + + # Very small signal (ensure it's long enough for STFT) + small_signal = {'time': np.zeros((1, 8000))} + small_signal['time'][0, :] = 1e-10 * np.random.randn(8000) + + rho = CoherenceManager.compute_coherence_freq_shifted( + small_signal, self.SFT, self.alpha_vec_hz, + max_bin=10, use_stft=True, interpolation='none' + ) + self.assertTrue(np.all(np.isfinite(rho))) + + +class TestFrequencyDomainVsTimeDomain(unittest.TestCase): + """Compare frequency-domain and time-domain coherence methods.""" + + def setUp(self): + """Set up test fixtures.""" + gs.rng, _ = gs.compute_rng(seed_is_random=False, rnd_seed_=123) + + self.fs = 16000 + self.nfft = 512 + self.hop = 128 + self.win = get_window('hann', self.nfft) + self.SFT = ShortTimeFFT(self.win, self.hop, fs=self.fs, mfft=self.nfft, scale_to='magnitude') + + # Create test signal + duration = 1.0 + t = np.arange(int(duration * self.fs)) / self.fs + f0 = 150 + self.signal_time = (np.sin(2 * np.pi * f0 * t) + + 0.6 * np.sin(2 * np.pi * 2 * f0 * t)) + self.signal = {'time': self.signal_time[np.newaxis, :]} + + # Alpha vector + self.alpha_vec_hz = np.array([0, -f0, -2*f0]) + + def test_compare_outputs_valid(self): + """Verify both methods produce valid coherence outputs.""" + # Just verify frequency-domain method produces valid results + # Full A/B comparison would require matching the modulator's behavior exactly + max_bin = 80 + + # Frequency-domain coherence (new) + rho_freq = CoherenceManager.compute_coherence_freq_shifted( + self.signal, self.SFT, self.alpha_vec_hz, + max_bin=max_bin, use_stft=True, + interpolation='linear', apply_phase_correction=True, + min_relative_power=1.e+3 + ) + + # Check that output is valid coherence matrix + self.assertEqual(rho_freq.shape[0], len(self.alpha_vec_hz)) + self.assertTrue(np.all(rho_freq >= -0.01)) + self.assertTrue(np.all(rho_freq <= 1.01)) + self.assertTrue(np.all(np.isfinite(rho_freq))) + + # Check that no-shift (alpha=0) has high self-coherence + # Note: alpha values are sorted, so find where 0 is + # For alpha_vec_hz = [0, -f0, -2*f0], sorted is [-2*f0, -f0, 0] + cc0_idx = 2 + self.assertGreater(np.mean(rho_freq[cc0_idx, :]), 0.8) + + +if __name__ == '__main__': + unittest.main()