%
% Wavelet bispectrum
%
%-------------------------------Copyright----------------------------------
%
% Author: Aleksandra Pidde, aleksandra.pidde@gmail.com
% 
% Related articles:
% J. Newman, A. Pidde, A. Stefanovska, "Defining the Wavelet Bispectrum"
% {Preprint submitted to Applied and Computational Harmonic Analysis, March 5, 2019, preprint:arXiv ?}
%
%------------------------------Documentation-------------------------------
%
% [Bisp, norm, freq, Optional:wavPar] = BispectrumTemp(x, y, z, fs, wavelet, f0, fmin, fmax, nv, p, cutEdges, logBase, T)
% - calculates the wavelet bispectrum [Bisp] and its normalisation [norm] for the signals [x, y, z] sampled at [fs] Hz.
% 
% INPUT:
% x, y, z    - signals
% fs         - sampling frequency of the signals
% wavelet/f0/fmin/fmax/nv
%            - see wav.m for description
% p          - parameter passed on to wav.m; p = 1 is recommended for the logarithmic-frequency approach, p = 1/3 for the linear one
% cutEdges/logBase
%            - see wav.m for description
% T          - the time window over which the bispectrum is averaged, [tmin tmax] (sec), where tmin <= tmax (only T(1) and T(end) are used)
%
% OUTPUT:
% Bisp       - wavelet bispectrum
% norm       - normalisation for logarithmic frequency approach, 1/D_phi from eq. 41
% freq       - frequencies of the wavelet transform, as returned by wav.m
% Optional:
% wavPar     - see wav.m 
%
    
function [Bisp, norm, freq, varargout] = BispectrumTemp(x, y, z, fs, wavelet, f0, fmin, fmax, nv, p, cutEdges, logBase, T)
tmin = T(1); tmax = T(end); % only the first and last entries of T are used

% Wavelet transforms of x and y on the common frequency grid freq.
[wtx, freq, wavPar] = wav(x, fs, wavelet, f0, fmin, fmax, nv, p, cutEdges, logBase);
% isequal (rather than mean(x == y) == 1) does not error when the signals
% differ in size or orientation; it simply reports them as different.
sameXY = isequal(x, y);
% For the auto-bispectrum (x = y = z) only the upper triangle j >= i is
% computed; the lower triangle is left NaN.
auto = sameXY && isequal(x, z);
if sameXY
    wty = wtx; % same signal and same parameters -> same transform
else
    [wty, freq, wavPar] = wav(y, fs, wavelet, f0, fmin, fmax, nv, p, cutEdges, logBase);
end

% Cut to the requested time window (sample indices). i2 is clamped to the
% number of samples so a window reaching past the end does not index out
% of bounds.
nt = size(wtx, 2);
i1 = max(floor(tmin * fs), 1);
i2 = min(max(floor(tmax * fs), 1), nt);
wtx = wtx(:, i1 : i2);
wty = wty(:, i1 : i2);

nf = length(freq);
Bisp = nan(nf, nf);
for i = 1 : nf
    if auto
        jstart = i;
    else
        jstart = 1;
    end
    for j = jstart : nf
        f3 = freq(i) + freq(j);
        % f3 = f1 + f2 must lie within the analysed frequency range...
        if f3 > freq(end)
            continue;
        end
        % ...and at least one frequency bin must lie strictly between
        % max(f1, f2) and f3 (assumes freq is ascending); otherwise the
        % entry stays NaN.
        idx3 = find(freq >= f3, 1);
        if freq(idx3 - 1) > freq(max(i, j))
            % The transform of z is evaluated exactly at f3, not at the
            % nearest bin of freq.
            wt3 = wtAtf(z, fs, wavelet, f0, f3, p, cutEdges);
            % Time average ignoring NaN samples; 'omitnan' replaces
            % nanmean, which needs the Statistics Toolbox.
            Bisp(i, j) = mean(wtx(i, :) .* wty(j, :) .* conj(wt3(i1 : i2)), 'omitnan');
        end
    end
end

% Normalisation on the same grid; undefined wherever Bisp is undefined.
norm = Dphi(wavPar.fwt, freq, fs, nv, logBase);
norm(isnan(Bisp)) = nan;
if nargout > 3
    varargout{1} = wavPar;
end
end

function [wt, varargout] = wtAtf(x, fs, wavelet, f0, fr, p, cutEdges)
%
% function helper: wavelet transform of x evaluated at the single frequency
% fr. It calls wav with fmin = fmax = fr, nv = 1 and logBase = exp(1).
% The optional second output is the wavPar structure returned by wav.
%
singleFreqArgs = {fr, fr, 1, p, cutEdges, exp(1)};
[wt, ~, params] = wav(x, fs, wavelet, f0, singleFreqArgs{:});
if nargout >= 2
    varargout = {params};
end
end

function [val] = D(fwt, ratio, kmin, kmax, logBase)
%
% function helper: inverse of the double integral used to normalise the
% bispectrum for a given frequency ratio.
%
% On an nk x nk logarithmically spaced grid ksi over [kmin, kmax]^2 it sums
%   fwt(ksi1/(1+ratio)) * fwt(ratio*ksi2/(1+ratio)) * fwt(ksi1*ksi2/(ksi1+ksi2))
%     / (ksi1*ksi2) * dksi1 * dksi2,
% skipping NaN terms, and returns val = 1 / sum (Inf if the sum is 0).
%
% INPUT:
% fwt        - function handle, evaluated elementwise on row vectors
% ratio      - frequency ratio
% kmin, kmax - integration limits (kmin > 0)
% logBase    - base used to build the grid; up to rounding the grid points
%              do not depend on it
%
nk = 500;
ksi = logBase.^linspace(log(kmin) / log(logBase), log(kmax) / log(logBase), nk);
% Quadrature weights: mean of the neighbouring spacings inside the grid,
% the full spacing at the two end points.
step = diff(ksi);
w = [step(1) (step(1 : end - 1) + step(2 : end)) / 2 step(end)];

% Force row orientation in case fwt returns a column.
evalRow = @(v) reshape(fwt(v), 1, []);
% Factors that depend on ksi1 only, or on ksi2 only, are evaluated once
% instead of inside the double loop.
g1 = evalRow(ksi / (1 + ratio)) .* w ./ ksi;
g2 = evalRow(ratio / (ratio + 1) * ksi) .* w ./ ksi;

tsum = 0;
for i = 1 : nk
    mixed = evalRow(ksi(i) * ksi ./ (ksi(i) + ksi)); % fwt(ksi1*ksi2/ksi3)
    terms = g1(i) .* g2 .* mixed;
    % NaN terms are skipped, as nansum did in the original loop.
    tsum = tsum + sum(terms(~isnan(terms)));
end
val = 1 / tsum;
end



function [normD] = Dphi(fwt, freq, fs, nv, logBase)
% finding the normalisation matrix for bispectrum
%
% normD(j, k) is D(fwt, logBase^(|k - j| / nv), rmin, rmax, logBase) for
% every pair with freq(j) + freq(k) <= fs / 2, and NaN otherwise.
% This assumes freq is a logarithmic vector with nv points per unit of
% log_logBase, so that freq(k) / freq(j) = logBase^((k - j) / nv).
%
nr = 1e5;     % number of grid points used to locate the support of fwt
thres = 1e-5; % fwt values below this threshold are treated as negligible
nfreq = length(freq);

% Find the frequency region of interest (for integration over R), fmin to
% fmax, by scanning fwt on a log grid from 1e-10 to fs / 2.
R1 = logBase.^linspace(log(1e-10) / log(logBase), log(fs / 2) / log(logBase), nr);
above = find(fwt(R1) >= thres); % fwt is evaluated on the grid only once
if isempty(above)
    idx1 = 1;
    idx2 = nr;
else
    idx1 = above(1);
    idx2 = above(end);
end
fmin = R1(max(idx1 - 1, 1)); fmax = R1(min(idx2 + 2, nr));

% Build the normalisation matrix as a function of the frequency ratio.
% Diagonal offset i corresponds to ratio logBase^(i / nv).
normD = nan(nfreq, nfreq);
for i = 0 : nfreq - 1
    js = find(freq(1 : nfreq - i) + freq(1 + i : nfreq) <= fs / 2);
    if isempty(js)
        continue; % D is expensive; skip ratios with no admissible pair
    end
    ratio = logBase.^(i / nv);
    rmin = fmin * (1 + 1 / ratio); rmax = fmax * (1 + ratio);
    % The grid base passed to D is logBase (the original passed nr here).
    val = D(fwt, ratio, rmin, rmax, logBase);
    normD(sub2ind([nfreq nfreq], js, js + i)) = val;
    normD(sub2ind([nfreq nfreq], js + i, js)) = val;
end
end
 
