function [Omega_app, Mu, Sigma, generated_samples, generated_sample_groups] = get_overlap_metric(training_data, training_groups, num_samples_to_generate, classifier_type, plot_distributions, dim_selection_matrix)
% function [Omega_app, Mu, Sigma, generated_samples, generated_sample_groups] = get_overlap_metric(training_data, training_groups, num_samples_to_generate, classifier_type, plot_distributions, dim_selection_matrix)
%
% © 2007 Geoffrey Stewart Morrison
% version of 10 March 2007
%
% Calculates distribution-overlap metric described in: 
%   Morrison, G. S. (2007 in prep). Comment on "A geometric representation of spectral and temporal vowel 
%   features: Quantification of vowel overlap in three linguistic varieties" [J. Acoust Soc. Am. 119, 2334–2350 
%   (2006)] (L)
%
% Metric designed to measure degree of overlap in vowel distributions in multiple dimensions 
%   such as normalised F1, F2, and duration.
%
% Means and covariances of categories in training data are calculated, and used to generate samples.
% Discriminant analysis model is used to calculate a posteriori probabilities of sample points from 
%   categories ¬A as members of category A.
% Overlap metric based on mean a posteriori probabilities.
%
% OBLIGATORY INPUT ARGUMENTS:
%   training_data               2D matrix of data, e.g., from experimental measurements of vowel F1, F2, duration
%                                   rows represent cases, columns dimensions
%   training_groups             indices corresponding to vowel category
%                                   may be a numerical vector, a string array, or a cell array of strings
% OPTIONAL INPUT ARGUMENTS (pass [] to select the default):
%   num_samples_to_generate     number of samples to generate per category, e.g., 500000
%                                   default 0 = use 'training_data' directly
%   classifier_type             type of discriminant analysis model, see classifier help
%                                   default 'quadratic'
%   plot_distributions          set to 'plot' to plot the first 5000 sample points per category
%                                   in 3-D, 2-D, or 1-D space
%                                   'plot only' to skip calculation of overlap metric
%                                   default 'no plot' 
%                                   (categories may have unequal numbers of cases; plot styles are
%                                   reused cyclically when there are more than 7 categories)
%   dim_selection_matrix        a matrix specifying which dimensions to include in the overlap calculation
%                                   all calculations are based on the same generated sample set
%                                   for example, to calculate the overlap on (dim1+dim2), (dim3), (dim1+dim2+dim3):
%                                       [1 1 0;
%                                        0 0 1;
%                                        1 1 1]
%                                   for example, to calculate the overlap on (dim1=x)+(dim2=y), (dim3=x)+(dim2=y):
%                                       [1 2;
%                                        3 2]       (if necessary, can pad rows with zeros to make matrix square)
%                                   a row whose maximum value is greater than 1 is read as a list of 
%                                   column numbers, any other row is read as a 0/1 mask
%                                   default is a single overlap metric calculation 
%                                   using all dimensions in 'training_data'
%
% OBLIGATORY OUTPUT ARGUMENT:
%   Omega_app                   overlap metric, one row per row of 'dim_selection_matrix'
%                                   (NaN when 'plot only' is requested)
% OPTIONAL OUTPUT ARGUMENT:
%   Mu                          cell array containing mean vectors for each group
%   Sigma                       cell array containing covariance matrices for each group
%   generated_samples           matrix of generated sample points, rows represent cases, columns dimensions
%                                   ('training_data' itself when num_samples_to_generate is 0)
%   generated_sample_groups     numeric vector of group ids (1..number of groups, in the order returned 
%                                   by grp2idx) of same length as 'generated_samples'
%
% REQUIRED TOOLBOX:
%   Statistics toolbox


% default input arguments
if nargin<3 || isempty(num_samples_to_generate), num_samples_to_generate=0; end
if nargin<4 || isempty(classifier_type), classifier_type='quadratic'; end
if nargin<5 || isempty(plot_distributions), plot_distributions='no plot'; end
num_dims = size(training_data,2);
if nargin<6 || isempty(dim_selection_matrix), dim_selection_matrix = ones(1,num_dims); end

% get groups: group_index maps every case onto 1..num_groups whatever the label type
[group_index,group_names] = grp2idx(training_groups);
group_index = group_index(:);
num_groups = length(group_names);
if length(group_index)~=size(training_data,1)
    error('get_overlap_metric:sizeMismatch', ...
        '''training_groups'' must have one entry per row of ''training_data''.');
end

% means and covariances of each group (always calculated so that the outputs are populated
% even when the training data are used directly)
Mu=cell(1,num_groups);
Sigma=cell(1,num_groups);
for Igroup=1:num_groups
    group_data = training_data(group_index==Igroup,:);
    % explicit dimension argument: a single-case group must still give a 1-by-num_dims mean
    Mu{Igroup} = mean(group_data,1);
    if size(group_data,1)>1
        Sigma{Igroup} = cov(group_data);
    else
        % cov would treat a single row as a vector of observations; a single case has no spread
        Sigma{Igroup} = zeros(num_dims);
    end
end

% generate samples
if num_samples_to_generate>0
    generated_samples = zeros(num_groups*num_samples_to_generate,num_dims);
    generated_sample_groups = zeros(num_groups*num_samples_to_generate,1);
    for Igroup=1:num_groups
        IIrows = (Igroup-1)*num_samples_to_generate + (1:num_samples_to_generate);
        generated_samples(IIrows,:) = mvnrnd(Mu{Igroup},Sigma{Igroup},num_samples_to_generate);
        generated_sample_groups(IIrows) = Igroup;
    end
else
    generated_samples = training_data;
    % use the numeric indices, not the raw labels: the comparisons with Igroup below require
    % ids 1..num_groups and the raw labels may be strings or arbitrary numbers
    generated_sample_groups = group_index;
end

% cycle through dim_selection_matrix
num_rows_dim_selection_matrix = size(dim_selection_matrix,1);
Omega_app=NaN(num_rows_dim_selection_matrix,1);

do_plot = strcmp(plot_distributions,'plot') || strcmp(plot_distributions,'plot only');
do_metric = ~strcmp(plot_distributions,'plot only');
max_points_per_group = 5000; % plotting limit per category

for Irow=1:num_rows_dim_selection_matrix
    
    % dimensions to use this cycle, always expressed as explicit column numbers so that
    % dim_columns(k) is a valid column index in the plotting code
    dim_selection_row=dim_selection_matrix(Irow,:);
    if max(dim_selection_row)>1
        % list of column numbers, possibly padded with zeros
        dim_columns=dim_selection_row(dim_selection_row~=0);
    else
        % 0/1 mask
        dim_columns=find(dim_selection_row);
    end
    num_dims_selected=length(dim_columns);
    if num_dims_selected==0
        error('get_overlap_metric:emptyDimSelection', ...
            'Row %d of ''dim_selection_matrix'' does not select any dimension.', Irow);
    end
    
    if do_metric
        % get a posteriori probabilities from discriminant analysis model
        selected_samples = generated_samples(:,dim_columns);
        [predicted_class, error_rate, app] = classify(selected_samples, selected_samples, generated_sample_groups, classifier_type); %#ok<ASGLU>

        % calculate mean a posteriori probabilities of sample points from categories ¬A as members of category A
        mean_app_notA_as_A=zeros(1,num_groups);
        for Igroup=1:num_groups
            mean_app_notA_as_A(Igroup) = mean(app(generated_sample_groups~=Igroup, Igroup));
        end

        % calculate overlap metric
        Omega_app(Irow) = sum(mean_app_notA_as_A)/(num_groups-1);
    end %if do_metric

    % plot distributions
    if do_plot
        shape_colour={'+b' '+r' '+g' '+c' '+m' '+y' '+k'};
        num_styles=length(shape_colour);
        
        % rows to plot for each group (at most max_points_per_group) and the style of each group
        IIgroup=cell(1,num_groups);
        num_points=zeros(1,num_groups);
        group_style=cell(1,num_groups);
        for Igroup=1:num_groups
            IIgroup{Igroup}=find(generated_sample_groups==Igroup);
            if length(IIgroup{Igroup})>max_points_per_group
                IIgroup{Igroup}=IIgroup{Igroup}(1:max_points_per_group);
            end
            num_points(Igroup)=length(IIgroup{Igroup});
            % cycle through the available styles when there are more groups than styles
            group_style{Igroup}=shape_colour{mod(Igroup-1,num_styles)+1};
        end
        
        figure;
        if num_dims_selected>2 % 3-D plot (first three selected dimensions)
            for Igroup=1:num_groups
                plot3(generated_samples(IIgroup{Igroup},dim_columns(1)),generated_samples(IIgroup{Igroup},dim_columns(2)),generated_samples(IIgroup{Igroup},dim_columns(3)),group_style{Igroup});
                hold on
            end
            xlabel('x'); ylabel('y'); zlabel('z'); axis equal
        elseif num_dims_selected>1 % 2-D plot with intercalated symbols (takes a while)
            for Isymbol=1:max(num_points)
                for Igroup=1:num_groups
                    % groups with fewer cases simply stop contributing symbols
                    if Isymbol<=num_points(Igroup)
                        plot(generated_samples(IIgroup{Igroup}(Isymbol),dim_columns(1)),generated_samples(IIgroup{Igroup}(Isymbol),dim_columns(2)),group_style{Igroup});
                        hold on
                    end
                end        
            end
            xlabel('x'); ylabel('y'); axis equal
        else % 1-D plot
            for Igroup=1:num_groups
                h=histfit(generated_samples(IIgroup{Igroup},dim_columns));
                hold on
                delete(h(1)); % keep only the fitted curve, not the histogram bars
                set(h(2),'Color',group_style{Igroup}(2));
            end
        end
        legend(group_names);
        hold off
    end %if do_plot
    
end %for Irow=1:num_rows_dim_selection_matrix

return