diff --git a/crepe/__init__.py b/crepe/__init__.py
index 52eacea..a740984 100755
--- a/crepe/__init__.py
+++ b/crepe/__init__.py
@@ -1,2 +1,9 @@
 from .version import version as __version__
-from .core import get_activation, predict, process_file
+from .core import (
+    get_activation,
+    predict,
+    process_file,
+    to_viterbi_cents_impl,
+    to_viterbi_cents_legacy,
+    to_viterbi_cents_fast,
+)
diff --git a/crepe/cli.py b/crepe/cli.py
index aaa96de..59c26c6 100644
--- a/crepe/cli.py
+++ b/crepe/cli.py
@@ -9,6 +9,7 @@
 
 
 def run(filename, output=None, model_capacity='full', viterbi=False,
+        viterbi_impl='legacy',
         save_activation=False, save_plot=False, plot_voicing=False,
         no_centering=False, step_size=10, verbose=True):
     """
@@ -27,6 +28,8 @@
         :func:`~crepe.core.build_and_load_model`
     viterbi : bool
         Apply viterbi smoothing to the estimated pitch curve. False by default.
+    viterbi_impl : {'legacy', 'fast'}
+        Implementation used when `viterbi=True`.
     save_activation : bool
         Save the output activation matrix to an .npy file. False by default.
     save_plot: bool
@@ -77,6 +80,7 @@
     process_file(file, output=output,
                  model_capacity=model_capacity,
                  viterbi=viterbi,
+                 viterbi_impl=viterbi_impl,
                  center=(not no_centering),
                  save_activation=save_activation,
                  save_plot=save_plot,
@@ -134,6 +138,9 @@
     parser.add_argument('--viterbi', '-V', action='store_true',
                         help='perform Viterbi decoding to smooth the pitch '
                              'curve')
+    parser.add_argument('--viterbi-impl', default='legacy',
+                        choices=['legacy', 'fast'],
+                        help='implementation used when --viterbi is enabled')
     parser.add_argument('--save-activation', '-a', action='store_true',
                         help='save the output activation matrix to a .npy '
                              'file')
@@ -165,6 +172,7 @@
         output=args.output,
         model_capacity=args.model_capacity,
         viterbi=args.viterbi,
+        viterbi_impl=args.viterbi_impl,
         save_activation=args.save_activation,
         save_plot=args.save_plot,
         plot_voicing=args.plot_voicing,
diff --git a/crepe/core.py b/crepe/core.py
index 03ea52e..d246569 100644
--- a/crepe/core.py
+++ b/crepe/core.py
@@ -21,6 +21,8 @@
 # the model is trained on 16kHz audio
 model_srate = 16000
 
+viterbi_impls = ('legacy', 'fast')
+
 
 def build_and_load_model(model_capacity):
     """
@@ -124,6 +126,25 @@ def to_viterbi_cents(salience):
     Find the Viterbi path using a transition prior that induces pitch
     continuity.
     """
+    return to_viterbi_cents_impl(salience, impl='legacy')
+
+
+def to_viterbi_cents_impl(salience, impl='legacy'):
+    """
+    Find the Viterbi path using the requested implementation.
+    """
+    if impl == 'legacy':
+        return to_viterbi_cents_legacy(salience)
+    if impl == 'fast':
+        return to_viterbi_cents_fast(salience)
+    raise ValueError('expected viterbi_impl to be one of {}, got {}'.format(
+        viterbi_impls, impl))
+
+
+def to_viterbi_cents_legacy(salience):
+    """
+    Legacy hmmlearn-backed Viterbi smoothing path.
+    """
     from hmmlearn import hmm
 
     # uniform prior on the starting pitch
@@ -153,6 +174,86 @@
                      range(len(observations))])
 
 
+def to_viterbi_cents_fast(salience):
+    """
+    Exact structured Viterbi smoothing path for CREPE's local transition graph.
+    """
+    observations = np.argmax(salience, axis=1).astype(np.int64, copy=False)
+    path = _viterbi_path_fast(observations)
+    return np.array([to_local_average_cents(salience[i, :], path[i]) for i in
+                     range(len(observations))])
+
+
+def _viterbi_fast_structure():
+    """
+    Precompute the exact local predecessor structure for CREPE's transition.
+    """
+    cached = getattr(_viterbi_fast_structure, 'cached', None)
+    if cached is not None:
+        return cached
+
+    states = 360
+    starting = np.ones(states, dtype=np.float64) / states
+
+    xx, yy = np.meshgrid(range(states), range(states))
+    transition = np.maximum(12 - abs(xx - yy), 0).astype(np.float64)
+    transition = transition / np.sum(transition, axis=1)[:, None]
+
+    self_emission = 0.1
+    emission = (np.eye(states, dtype=np.float64) * self_emission +
+                np.ones(shape=(states, states), dtype=np.float64) *
+                ((1 - self_emission) / states))
+
+    valid = transition > 0
+    width = int(np.max(np.sum(valid, axis=0)))
+    source_idx = np.zeros((states, width), dtype=np.int16)
+    log_trans = np.full((states, width), -np.inf, dtype=np.float64)
+
+    for target in range(states):
+        sources = np.flatnonzero(valid[:, target]).astype(np.int16)
+        source_idx[target, :len(sources)] = sources
+        log_trans[target, :len(sources)] = np.log(
+            transition[sources.astype(np.int64), target])
+
+    cached = {
+        'state_idx': np.arange(states, dtype=np.int64),
+        'log_starting': np.log(starting),
+        'source_idx': source_idx,
+        'log_trans': log_trans,
+        'log_emission': np.log(emission)
+    }
+    _viterbi_fast_structure.cached = cached
+    return cached
+
+
+def _viterbi_path_fast(observations):
+    """
+    Exact structured Viterbi decode of CREPE's argmax observations.
+    """
+    structure = _viterbi_fast_structure()
+    source_idx = structure['source_idx']
+    log_trans = structure['log_trans']
+    state_idx = structure['state_idx']
+    log_emission = structure['log_emission']
+
+    prev = structure['log_starting'] + log_emission[observations[0]]
+    backpointers = np.empty((len(observations), 360), dtype=np.int16)
+    backpointers[0] = np.arange(360, dtype=np.int16)
+
+    for frame, observation in enumerate(observations[1:], start=1):
+        candidates = prev[source_idx] + log_trans
+        best_offsets = np.argmax(candidates, axis=1)
+        best_sources = source_idx[state_idx, best_offsets]
+        backpointers[frame] = best_sources
+        prev = candidates[state_idx, best_offsets] + log_emission[observation]
+
+    path = np.empty((len(observations),), dtype=np.int16)
+    path[-1] = int(np.argmax(prev))
+    for frame in range(len(observations) - 1, 0, -1):
+        path[frame - 1] = backpointers[frame, path[frame]]
+    return path
+
+
 def get_activation(audio, sr, model_capacity='full', center=True, step_size=10,
                    verbose=1):
     """
@@ -213,7 +314,8 @@
 
 
 def predict(audio, sr, model_capacity='full',
-            viterbi=False, center=True, step_size=10, verbose=1):
+            viterbi=False, center=True, step_size=10, verbose=1,
+            viterbi_impl='legacy'):
     """
     Perform pitch estimation on given audio
 
@@ -229,6 +331,9 @@
         :func:`~crepe.core.build_and_load_model`
     viterbi : bool
         Apply viterbi smoothing to the estimated pitch curve. False by default.
+    viterbi_impl : {'legacy', 'fast'}
+        Implementation used when `viterbi=True`. Defaults to the current
+        `hmmlearn` path (`legacy`).
     center : boolean
         - If `True` (default), the signal `audio` is padded so that frame
           `D[:, t]` is centered at `audio[t * hop_length]`.
@@ -258,7 +363,7 @@
     confidence = activation.max(axis=1)
 
     if viterbi:
-        cents = to_viterbi_cents(activation)
+        cents = to_viterbi_cents_impl(activation, impl=viterbi_impl)
     else:
         cents = to_local_average_cents(activation)
 
@@ -272,7 +377,8 @@
 
 def process_file(file, output=None, model_capacity='full', viterbi=False,
                  center=True, save_activation=False, save_plot=False,
-                 plot_voicing=False, step_size=10, verbose=True):
+                 plot_voicing=False, step_size=10, verbose=True,
+                 viterbi_impl='legacy'):
     """
     Use the input model to perform pitch estimation on the input file.
 
@@ -288,6 +394,8 @@
         :func:`~crepe.core.build_and_load_model`
     viterbi : bool
         Apply viterbi smoothing to the estimated pitch curve. False by default.
+    viterbi_impl : {'legacy', 'fast'}
+        Implementation used when `viterbi=True`.
     center : boolean
         - If `True` (default), the signal `audio` is padded so that frame
           `D[:, t]` is centered at `audio[t * hop_length]`.
@@ -320,6 +428,7 @@
         audio, sr,
         model_capacity=model_capacity,
         viterbi=viterbi,
+        viterbi_impl=viterbi_impl,
         center=center,
         step_size=step_size,
         verbose=1 * verbose)
@@ -363,4 +472,3 @@
         imwrite(plot_file, (255 * image).astype(np.uint8))
         if verbose:
             print("CREPE: Saved the salience plot at {}".format(plot_file))
-
diff --git a/scripts/benchmark_viterbi.py b/scripts/benchmark_viterbi.py
new file mode 100644
index 0000000..eaff004
--- /dev/null
+++ b/scripts/benchmark_viterbi.py
@@ -0,0 +1,141 @@
+from __future__ import print_function
+
+import argparse
+import os
+import time
+
+import numpy as np
+from scipy.io import wavfile
+
+import crepe
+from crepe import core
+
+
+def time_call(fn, warmup, repeats):
+    last = None
+    for _ in range(warmup):
+        last = fn()
+    t0 = time.perf_counter()
+    for _ in range(repeats):
+        last = fn()
+    return ((time.perf_counter() - t0) * 1000.0 / repeats), last
+
+
+def synthetic_salience(frames, seed):
+    rng = np.random.RandomState(seed)
+    return rng.uniform(low=0.0, high=1.0, size=(frames, 360)).astype(np.float64)
+
+
+def has_weights(model_capacity):
+    return os.path.isfile(
+        os.path.join(os.path.dirname(core.__file__),
+                     'model-{}.h5'.format(model_capacity)))
+
+
+def benchmark_salience(salience, warmup, repeats):
+    results = []
+    for impl in ['legacy', 'fast']:
+        if impl == 'legacy':
+            try:
+                __import__('hmmlearn')
+            except ImportError:
+                results.append((impl, 'skipped', None))
+                continue
+        ms, _ = time_call(
+            lambda: core.to_viterbi_cents_impl(salience, impl=impl),
+            warmup,
+            repeats)
+        results.append((impl, 'ok', ms))
+    return results
+
+
+def benchmark_predict(audio, sr, model_capacity, warmup, repeats, verbose):
+    results = []
+    for impl in ['legacy', 'fast']:
+        if impl == 'legacy':
+            try:
+                __import__('hmmlearn')
+            except ImportError:
+                results.append((impl, 'skipped', None))
+                continue
+        ms, _ = time_call(
+            lambda: crepe.predict(
+                audio,
+                sr,
+                model_capacity=model_capacity,
+                viterbi=True,
+                viterbi_impl=impl,
+                verbose=verbose),
+            warmup,
+            repeats)
+        results.append((impl, 'ok', ms))
+    return results
+
+
+def print_results(title, results):
+    print('## {}'.format(title))
+    print('| Impl | Status | Mean time |')
+    print('|------|--------|-----------|')
+    for impl, status, ms in results:
+        value = '' if ms is None else '**{:.3f} ms**'.format(ms)
+        print('| `{}` | {} | {} |'.format(impl, status, value))
+    print()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--frames', type=int, nargs='+', default=[512, 2048])
+    parser.add_argument('--warmup', type=int, default=2)
+    parser.add_argument('--repeats', type=int, default=10)
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--include-sweep', action='store_true')
+    parser.add_argument('--model-capacity', default='tiny',
+                        choices=['tiny', 'small', 'medium', 'large', 'full'])
+    parser.add_argument('--verbose', type=int, default=0)
+    args = parser.parse_args()
+
+    print('# CREPE Viterbi Benchmark')
+    print()
+    print('- `frames`: `{}`'.format(args.frames))
+    print('- `warmup`: `{}`'.format(args.warmup))
+    print('- `repeats`: `{}`'.format(args.repeats))
+    print('- `seed`: `{}`'.format(args.seed))
+    print('- `include_sweep`: `{}`'.format(args.include_sweep))
+    print('- `model_capacity`: `{}`'.format(args.model_capacity))
+    print()
+
+    for frames in args.frames:
+        salience = synthetic_salience(frames, seed=args.seed + frames)
+        print_results('Synthetic decoder core: {} frames'.format(frames),
+                      benchmark_salience(salience, args.warmup, args.repeats))
+
+    if not args.include_sweep:
+        return
+
+    if not has_weights(args.model_capacity):
+        print('> ⚠️ Sweep benchmark skipped: model weight file for `{}` is not '
+              'present.'.format(args.model_capacity))
+        return
+
+    try:
+        __import__('hmmlearn')
+    except ImportError:
+        print('> ⚠️ Sweep benchmark skipped: `hmmlearn` is not installed.')
+        return
+
+    sweep_path = os.path.join(os.path.dirname(__file__), '..', 'tests', 'sweep.wav')
+    sr, audio = wavfile.read(sweep_path)
+    activation = crepe.get_activation(
+        audio,
+        sr,
+        model_capacity=args.model_capacity,
+        verbose=args.verbose)
+    print_results('Sweep activation decoder core: {} frames'.format(len(activation)),
+                  benchmark_salience(activation, args.warmup, args.repeats))
+    print_results('Sweep full predict(): {} frames'.format(len(activation)),
+                  benchmark_predict(audio, sr, args.model_capacity,
                                    args.warmup, args.repeats, args.verbose))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/test_viterbi_impl.py b/tests/test_viterbi_impl.py
new file mode 100644
index 0000000..3e6ced8
--- /dev/null
+++ b/tests/test_viterbi_impl.py
@@ -0,0 +1,94 @@
+import numpy as np
+import pytest
+
+import crepe
+from crepe import cli
+from crepe import core
+
+
+def synthetic_salience(frames, seed=0):
+    rng = np.random.RandomState(seed)
+    return rng.uniform(low=0.0, high=1.0, size=(frames, 360)).astype(np.float64)
+
+
+def dense_reference_path(observations):
+    structure = core._viterbi_fast_structure()
+    states = 360
+    xx, yy = np.meshgrid(range(states), range(states))
+    transition = np.maximum(12 - abs(xx - yy), 0).astype(np.float64)
+    transition = transition / np.sum(transition, axis=1)[:, None]
+    log_transition = np.log(transition)
+    log_starting = structure['log_starting']
+    log_emission = structure['log_emission']
+
+    prev = log_starting + log_emission[observations[0]]
+    backpointers = np.empty((len(observations), states), dtype=np.int16)
+    backpointers[0] = np.arange(states, dtype=np.int16)
+
+    for frame, observation in enumerate(observations[1:], start=1):
+        candidates = prev[:, None] + log_transition
+        best_sources = np.argmax(candidates, axis=0)
+        backpointers[frame] = best_sources.astype(np.int16)
+        prev = candidates[best_sources, np.arange(states)] + log_emission[observation]
+
+    path = np.empty((len(observations),), dtype=np.int16)
+    path[-1] = int(np.argmax(prev))
+    for frame in range(len(observations) - 1, 0, -1):
+        path[frame - 1] = backpointers[frame, path[frame]]
+    return path
+
+
+class TestViterbiFast:
+    def test_fast_path_matches_dense_reference(self):
+        salience = synthetic_salience(64, seed=123)
+        observations = np.argmax(salience, axis=1)
+        fast = core._viterbi_path_fast(observations)
+        dense = dense_reference_path(observations)
+        np.testing.assert_array_equal(fast, dense)
+
+    def test_fast_cents_matches_legacy(self):
+        pytest.importorskip("hmmlearn")
+        salience = synthetic_salience(48, seed=321)
+        legacy = core.to_viterbi_cents_impl(salience, impl='legacy')
+        fast = core.to_viterbi_cents_impl(salience, impl='fast')
+        np.testing.assert_allclose(fast, legacy, rtol=0.0, atol=0.0)
+
+    def test_invalid_impl_raises(self):
+        salience = synthetic_salience(8, seed=7)
+        with pytest.raises(ValueError):
+            core.to_viterbi_cents_impl(salience, impl='nope')
+
+    def test_predict_routes_fast_impl(self, monkeypatch):
+        salience = synthetic_salience(16, seed=11)
+        monkeypatch.setattr(core, 'get_activation',
+                            lambda *args, **kwargs: salience)
+        _, frequency, confidence, activation = core.predict(
+            np.zeros((16000,), dtype=np.float32),
+            16000,
+            viterbi=True,
+            viterbi_impl='fast',
+            verbose=0)
+        cents = core.to_viterbi_cents_fast(salience)
+        expected_frequency = 10 * 2 ** (cents / 1200)
+        np.testing.assert_allclose(frequency, expected_frequency)
+        np.testing.assert_allclose(confidence, salience.max(axis=1))
+        np.testing.assert_allclose(activation, salience)
+
+    def test_cli_run_passes_viterbi_impl(self, monkeypatch, tmp_path):
+        wav_path = tmp_path / 'dummy.wav'
+        wav_path.write_bytes(b'RIFF')
+        captured = {}
+
+        def fake_process_file(file, **kwargs):
+            captured['file'] = file
+            captured['kwargs'] = kwargs
+
+        monkeypatch.setattr(cli, 'process_file', fake_process_file)
+        cli.run(
+            [str(wav_path)],
+            viterbi=True,
+            viterbi_impl='fast',
+            verbose=False)
+        assert captured['file'] == str(wav_path)
+        assert captured['kwargs']['viterbi'] is True
+        assert captured['kwargs']['viterbi_impl'] == 'fast'