/*
 * Copyright (C) 2006-2021  Music Technology Group - Universitat Pompeu Fabra
 *
 * This file is part of Essentia
 *
 * Essentia is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the Free
 * Software Foundation (FSF), either version 3 of the License, or (at your
 * option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the Affero GNU General Public License
 * version 3 along with this program.  If not, see http://www.gnu.org/licenses/
 */

#include "onsetdetection.h"
#include <complex>
#include "essentiamath.h"

using namespace essentia;
using namespace standard;
using namespace std;

// Static algorithm metadata: registered name, category and the user-facing
// documentation string.
const char* OnsetDetection::name = "OnsetDetection";
const char* OnsetDetection::category = "Rhythm";
// NOTE: the number of methods announced below must match the bullet list that
// follows it (six entries: HFC, complex, complex_phase, flux, melflux, rms).
const char* OnsetDetection::description = DOC("This algorithm computes various onset detection functions. The output of this algorithm should be post-processed in order to determine whether the frame contains an onset or not. Namely, it could be fed to the Onsets algorithm. It is recommended that the input \"spectrum\" is generated by the Spectrum algorithm.\n"
"Six methods are available:\n"
"  - 'HFC', the High Frequency Content detection function which accurately detects percussive events (see HFC algorithm for details).\n"
"  - 'complex', the Complex-Domain spectral difference function [1] taking into account changes in magnitude and phase. It emphasizes note onsets either as a result of significant change in energy in the magnitude spectrum, and/or a deviation from the expected phase values in the phase spectrum, caused by a change in pitch.\n"
"  - 'complex_phase', the simplified Complex-Domain spectral difference function [2] taking into account phase changes, weighted by magnitude. TODO:It reacts better on tonal sounds such as bowed string, but tends to over-detect percussive events.\n"
"  - 'flux', the Spectral Flux detection function which characterizes changes in magnitude spectrum. See Flux algorithm for details.\n"
"  - 'melflux', the spectral difference function, similar to spectral flux, but using half-rectified energy changes in Mel-frequency bands of the spectrum [3].\n"
"  - 'rms', the difference function, measuring the half-rectified change of the RMS of the magnitude spectrum (i.e., measuring overall energy flux) [4].\n"
"\n"
"If using the 'HFC' detection function, make sure to adhere to HFC's input requirements when providing an input spectrum. Input vectors of different size or empty input spectra will raise exceptions.\n"
"If using the 'complex' detection function, suggested parameters for computation of \"spectrum\" and \"phase\" are 44100Hz sample rate, frame size of 1024 and hopSize of 512 samples, which results in a resolution of 11.6ms, and a Hann window.\n"
"\n"
"References:\n"
"  [1] Bello, Juan P., Chris Duxbury, Mike Davies, and Mark Sandler, On the\n"
"  use of phase and energy for musical onset detection in the complex domain,\n"
"  Signal Processing Letters, IEEE 11, no. 6 (2004): 553-556.\n\n"
"  [2] P. Brossier, J. P. Bello, and M. D. Plumbley, \"Fast labelling of notes\n"
"  in music signals,\" in International Symposium on Music Information\n"
"  Retrieval (ISMIR’04), 2004, pp. 331–336.\n\n"
"  [3] D. P. W. Ellis, \"Beat Tracking by Dynamic Programming,\" Journal of\n"
"  New Music Research, vol. 36, no. 1, pp. 51–60, 2007.\n\n"
"  [4] J. Laroche, \"Efficient Tempo and Beat Tracking in Audio Recordings,\"\n"
"  JAES, vol. 51, no. 4, pp. 226–233, 2003.\n");

void OnsetDetection::configure() {
  Real sampleRate = parameter("sampleRate").toReal();
  _method = parameter("method").toLower();

  _hfc->configure("type", "Brossier", "sampleRate", sampleRate);
  _melBands-> configure("sampleRate", sampleRate,
                        "numberBands", 40,
                        "lowFrequencyBound", 0.0,
                        "highFrequencyBound", 4000.0);

  // Use L1 for both 'melflux' and 'flux' methods. Evaluation by Jose Zapata
  // revealed better performance of L1 for 'flux'
  _flux->configure("norm", "L1");
  if (_method == "melflux") {
    _flux->configure("halfRectify", true);
  }

  _firstFrame = true;
}

void OnsetDetection::compute() {
  const vector<Real>& spectrum = _spectrum.get();
  const vector<Real>& phase = _phase.get();

  if (spectrum.empty()) {
    throw EssentiaException("OnsetDetection: OnsetDetection cannot be computed on an empty spectrum");
  }

  Real& onsetDetection = _onsetDetection.get();
  onsetDetection = 0.0;

  // HFC-based detection function for percussive onsets
  if (_method == "hfc") {
    _hfc->input("spectrum").set(spectrum);
    _hfc->output("hfc").set(onsetDetection);
    _hfc->compute();
    return;
  }

  // TODO Evaluation is required to compare 'complex' to 'complex_phase'.
  // Remove 'complex_phase' if it can be substituted by 'complex'


  // Complex-domain detection function for non-percussive onsets (Brossier, [2])
  // this version ignores magnitude difference
  if (_method == "complex_phase") {
    if (spectrum.size() != phase.size()) {
      throw EssentiaException("OnsetDetection: Spectrum and phase cannot be of different size");
    }
    if (phase.size() != _phase_2.size() || phase.size() != _phase_1.size()) {
      _phase_1.resize(phase.size());
      _phase_2.resize(phase.size());
      fill(_phase_1.begin(), _phase_1.end(), Real(0.0));
      fill(_phase_2.begin(), _phase_2.end(), Real(0.0));
    }

    for (int i=0; i<int(phase.size()); ++i) {
      //Real targetPhase = princarg(2*_phase_1[i] - _phase_2[i]);
      // optimization: we do not need the princarg here, because we take the sine of it
      //               just after that
      Real targetPhase = 2*_phase_1[i] + _phase_2[i];

      //Real distance = norm(polar(spectrum[i], targetPhase) - polar(spectrum[i], phase[i]));
      // optimization: |ae**(i*p0) - ae**(i*p)| = ... = 2a*|sin((p-p0)/2)|
      // we do not need to take the abs value either, as we square it next line
      Real distance = 2.0 * spectrum[i] * sin((phase[i]-targetPhase)*0.5);

      onsetDetection += distance * distance;
    }

    _phase_2 = _phase_1;
    _phase_1 = phase;
    return;
  }

  // Complex-domain detection function for non-percussive onsets (Bello, [1])
  if (_method == "complex") {
    if (spectrum.size() != phase.size()) {
      throw EssentiaException("OnsetDetection: Spectrum and phase cannot be of different size");
    }
    if (phase.size() != _phase_2.size() || phase.size() != _phase_1.size()) {
      _phase_1.resize(phase.size());
      _phase_2.resize(phase.size());
      fill(_phase_1.begin(), _phase_1.end(), Real(0.0));
      fill(_phase_2.begin(), _phase_2.end(), Real(0.0));
      _spectrum_1.resize(phase.size());
      fill(_spectrum_1.begin(), _spectrum_1.end(), Real(0.0));
    }

    for (int i=0; i<int(phase.size()); ++i) {
      //Real targetPhase = princarg(2*_phase_1[i] - _phase_2[i]);
      Real targetPhase = 2*_phase_1[i] - _phase_2[i];
      targetPhase = fmod(targetPhase + M_PI, -2 * M_PI) + M_PI;

      //Real distance = norm(polar(_spectrum_1, targetPhase) - polar(spectrum[i], phase[i]));
      // optimization: rotate vectors to map target vector to the real axis (details in [2])
      Real distance = abs(_spectrum_1[i] - polar(spectrum[i], phase[i]-targetPhase));
      onsetDetection += distance;
    }

    _phase_2 = _phase_1;
    _phase_1 = phase;
    _spectrum_1 = spectrum;
    return;
  }

  // Detection function based on spectral flux
  if (_method == "flux") {
    _flux->input("spectrum").set(spectrum);
    _flux->output("flux").set(onsetDetection);
    _flux->compute();
    return;
  }


  // Detection function similar to spectral flux, but computed on Mel-frequency spectrum [3].
  if (_method == "melflux") {
    /*
      Original algorithm:
        - downsample audio to 8kHz mono
        - cut frames with 32ms window, 4ms hop size
        - compute log-magnitude Mel-frequency spectrum (40 bands) in each frame
        - take first-order difference along time in each Mel band, and sum across frequency
        - high-pass filter (cutoff at 0.4 Hz) to remove DC offset of the computed function
        - smooth by convolving with a Gaussian envelope about 20 ms wide

      Modifications:
        - compute Mel bands only for frequencies below 4kHz instead of downsampling
        - skip high-pass filtering and smoothing, they should be done in post-processing
        - in the case of 44100 sample rate, exact frame size should be 1411 samples to match 32ms,
          but we leave this values to be decided by the user
      Note:
        - We manually remove a click in ODF on the first frame because Flux algorithm is
          initialized with zero vector while we feed it with log-magnitudes instead of magnitudes.

    */
    vector <Real> melbands;
    _melBands->input("spectrum").set(spectrum);
    _melBands->output("bands").set(melbands);
    _melBands->compute();

    // take the dB amplitude of the spectrum
    for (int i=0; i<int(melbands.size()); ++i) {
      melbands[i] = amp2db(melbands[i]);
    }
    /*
      Note: D. Ellis implementation looks only at the top 80 dB across all frames. Magnitudes below
      the maximum - 80dB are replaced with this dynamic threshold. This requires a post-processing step that
      we want to avoid. Instead, amp2db outputs a fixed silence threshold if values of magnitude are too low.
    */
    _flux->input("spectrum").set(melbands);
    _flux->output("flux").set(onsetDetection);
    _flux->compute();

    if (_firstFrame) {  // a hack to remove click in the first sample
      onsetDetection = 0;
      _firstFrame = false;
    }
  }

  // Detection function based on the half-rectified change of the RMS of the spectrum
  // NB: evaluation: half-rectifying improved results
  if (_method == "rms") {
    Real rms = 0;
    for (int i=0; i<(int) spectrum.size(); ++i) {
      rms += spectrum[i] * spectrum[i];
    }
    rms = sqrt(rms) / spectrum.size();
    if (_firstFrame) {  // a hack to remove click in the first sample
      onsetDetection = 0;
      _firstFrame = false;
    }
    else {
      onsetDetection = rms - _rmsOld;
      if (onsetDetection < 0) { // half-rectify
        onsetDetection = 0;
      }
    }
    _rmsOld = rms;
  }
}

void OnsetDetection::reset() {
  _phase_1.clear();
  _phase_2.clear();
  _spectrum_1.clear();
  _hfc->reset();
  _flux->reset();
  _melBands->reset();
  _firstFrame = true;
}
