clc; clear;

%% Dependencies

disp('Adding dependencies...');

%%%% Third-party toolboxes (added recursively, in this order)
depDirs = {'ds_gplvm', 'gpml-matlab-v3.5-2014-12-08', 'minFunc_2012'};
for k = 1:numel(depDirs)
    addpath(genpath(depDirs{k}));
end
clear depDirs k

%% Load data

%%%% Folder holding the MERL binary files. Named binDir (not "path") so it
%%%% does not shadow MATLAB's built-in path function.
binDir = 'binary_files/';

binFiles = dir([binDir '*.binary']);

%%%% The rest of the script assumes exactly 100 materials; fail early with a
%%%% clear message instead of an index error inside the loop below.
if numel(binFiles) < 100
    error('Expected at least 100 *.binary files in %s, found %d.', binDir, numel(binFiles));
end

%%%% Read MERL BRDFs (the first 100 files returned by dir)
disp('Reading MERL database...');
brdfs = zeros(90, 90, 180, 3, 100);
for i = 1:100
    brdfs(:, :, :, :, i) = readMerlBrdf([binDir binFiles(i).name]);
end
clear binDir binFiles i

disp('Pre-processing data...');

%%%% Flatten to one row per (channel, material) pair: rows are ordered
%%%% channel-fastest (ch1/mat1, ch2/mat1, ch3/mat1, ch1/mat2, ...) and the
%%%% columns enumerate the 90x90x180 grid
brdfs = reshape(permute(brdfs, [4 5 1 2 3]), 3*100, []);

%%%% Keep only the grid cells that contain data, judged on the first row
%%%% (channel 1 of material 1): negative entries are dropped
negIdx = find(brdfs(1, :) < 0);
posIdx = find(~(brdfs(1, :) < 0));
nPos = numel(posIdx);

%%%% Drop the empty cells and apply the natural logarithm log(1 + x)
brdfs = log(1 + brdfs(:, posIdx));

%%%% Re-order to one row per material: [100 x 3*nPos], channel-fastest columns
brdfs = reshape(permute(reshape(brdfs, 3, 100, nPos), [2 1 3]), 100, 3*nPos);

%%%% Convert every material to Lab: each row is viewed as an nPos x 1 RGB
%%%% "image", converted, then flattened back the same way
brdfsLab = zeros(size(brdfs));
for k = 1:100
    rgbImg = permute(reshape(brdfs(k, :), 3, nPos), [2 3 1]);
    labImg = RGB2Lab(rgbImg);
    brdfsLab(k, :) = reshape(permute(labImg, [3 1 2]), 1, 3*nPos);
end
clear k rgbImg labImg nPos

%%%% Column indices (into brdfsLab) of lower-dimensional BRDF slices.
%%%% The masks are built directly in channel-first layout
%%%% [3 x 90 x 90 x 180], matching how brdfs was flattened, then restricted
%%%% to the kept cells posIdx and flattened channel-fastest.

%%%% 1D BRDF (i.e. theta_d = 0): second grid index 1, third grid index 90
sliceMask = false(3, 90, 90, 180);
sliceMask(:, :, 1, 90) = true;
sliceMask = reshape(sliceMask, 3, []);
subIdx1 = find(reshape(sliceMask(:, posIdx), [], 1)).';

%%%% 2D BRDF (i.e. phi_d = 90): third grid index 90
sliceMask = false(3, 90, 90, 180);
sliceMask(:, :, :, 90) = true;
sliceMask = reshape(sliceMask, 3, []);
subIdx2 = find(reshape(sliceMask(:, posIdx), [], 1)).';

clear sliceMask

%%%% Load the Spectral Clustering result. Loading into a struct prevents
%%%% any other variables saved in the MAT-file from silently overwriting
%%%% workspace variables created above (brdfsLab, posIdx, ...).
specClust = load('spec_clust.mat');
if ~isfield(specClust, 'idx')
    error('spec_clust.mat must contain a variable named ''idx''.');
end
idx = specClust.idx;
clear specClust

train3D = false;                    % set true to train 3D BRDFs (slower)

%%%% One view per BRDF dimensionality
data_in{1} = brdfsLab(:, subIdx1);  % 1D BRDFs (columns selected by subIdx1)
data_in{2} = brdfsLab(:, subIdx2);  % 2D BRDFs (columns selected by subIdx2)
if train3D
    data_in{3} = brdfsLab;          % 3D BRDFs (all kept columns)
end

%%%% Every view shares the same spectral-clustering labels
labels_full = repmat({idx}, 1, numel(data_in));

%% Specify indices for training, validation, test (leave validation and test empty for training mode only)

disp('Creating training, validation, test set...');

%%%% Re-shuffle the order of the BRDFs (for n-fold cross-validation)
%%%% Folder rndIdx contains some pre-generated random indices
%%%% Feel free to generate your own random indices here
%%%% (loaded into a struct so the MAT-file cannot overwrite other variables)
rndFile = load(['rndIdx/rndIdx' num2str(randi(5))]);
rndIdx = rndFile.rndIdx;
clear rndFile

%%%% Pre-defined folds: each row is {training, validation, testing} and
%%%% indexes into rndIdx. Feel free to define your own indices.
folds = { 1:60,            61:80,   81:100;
          21:80,           81:100,  1:20;
          41:100,          1:20,    21:40;
          [61:100, 1:20],  21:40,   41:60;
          [81:100, 1:40],  41:60,   61:80 };
idxMode = randi(size(folds, 1));
[trIdx, vlIdx, tsIdx] = folds{idxMode, :};
clear folds

train_ind = rndIdx(trIdx);
val_ind = rndIdx(vlIdx);
test_ind = rndIdx(tsIdx);

%%%% Every view is tested on the same samples
test_ind2 = repmat({test_ind}, 1, numel(data_in));

ind = {train_ind, val_ind, test_ind2};

%%%% Validation/testing mode requires both the validation and the test set.
%%%% Check test_ind itself: ind{3} is a cell with one entry per view, so it
%%%% is never empty even when no test samples were given.
validation = double(~isempty(val_ind) && ~isempty(test_ind));

%% Select covariance function and initialize parameters for GPs

disp('Setting model parameters...');

%%%% GP mapping: Gaussian likelihood and a sum of squared-exponential,
%%%% constant and noise covariances (the project's "_mod" variants)
likfunc = @likGauss;
covfunc = {@covSum_mod, {@covSEiso_mod, @covConst_mod, @covNoise_mod}};

%%%% Initial hyper-parameters, stored as logarithms
hyp = struct('cov', log([1; 1; sqrt(.1); sqrt(.1)]), ...
             'lik', log(sqrt(.1)));

%% Create model

model.prior = 50;                   % effect (weight) of the prior
model.prior_type = 'pca';           % possible options: lpp, lda, pca

%%%% Back-projection settings. Only one can be active!
model.bp = 0;                       % standard back projection defined by Lawrence and used in D-GPLVM
model.sbp = 0;                      % SBP setting of the DS-GPLVM
model.ibp = 1;                      % IBP setting of the DS-GPLVM
assert(nnz([model.bp, model.sbp, model.ibp]) <= 1, 'ds_gplvm:bpConfig', ...
    'Only one of model.bp, model.sbp and model.ibp can be active.');

%%%% Validation setting and corresponding train/validation/test indices
model.validation = validation;
model.ind = ind;

%%%% Latent space, constraint and labels
model.X = [];                       % latent space (presumably filled in by initialization/optimization)
model.Laplacian = [];               % Laplacian matrix of the constraint
model.labels_full = labels_full;    % labels of the dataset

%%%% rho and max value for the \mu parameters of the ADM
model.rho = 1.1;
model.max_mu = 1e3;
model.T = 50;                       % No. of ADM cycles (default is 100)

%%%% Parameters for the GP mappings
model.covfunc = covfunc;
model.likfunc = likfunc;
model.hyp = hyp;

%%%% OUTPUT of the model. If validation is not set, predictions only for
%%%% the train set are returned
model.out.ac_val = []; model.out.ac_test = []; model.out.ac_train = [];
model.out.labels_val = []; model.out.labels_test = []; model.out.labels_train = [];

%%%% Do we want plots?
model.verbose = 0;

model = initialize_dsgplvm(data_in, model);

%%%% Run the ADM optimization and report the elapsed wall-clock time
tic
model = ds_gplvm_adm(data_in, model);
toc

%% Test predictions

disp('Inferring test samples...');

%%%% Standardize every view with its training-set statistics and build the
%%%% test kernels. bsxfun expands m/s2 over however many test rows there
%%%% are, so this works whether test_ind is a row or a column vector.
for i = 1:numel(data_in)
    [Y{i}, m{i}, s2{i}] = standardize(data_in{i}(model.ind{1}, :), 1);
    Y_test{i} = bsxfun(@rdivide, bsxfun(@minus, data_in{i}(model.ind{3}{i}, :), m{i}), s2{i});
    outK_test{i} = covSEisoU(log(model.bp_params(i).gamma), Y{i}, Y_test{i});
    % append a row of ones (presumably the bias term of the back-projection A)
    outK_test{i} = [outK_test{i}; ones(1, size(outK_test{i}, 2))];
end

%%%% We always move from the space of 1D BRDFs to the space of 2D/3D BRDFs:
%%%% project the test samples into the latent space through view 1's
%%%% back-projection, then reconstruct all views. Done once, after the loop,
%%%% so reconstruction_error sees every view's Y/Y_test.
m_latent = outK_test{1}' * model.bp_params(1).A;
[smse_test, smse_train, Y_star_test, Y_star, snll] = reconstruction_error(model.hyp, model.covfunc, model.X, Y, m_latent, Y_test);

disp('Visualizing and saving test samples...');

i = 2;                              % 2D BRDFs
nTest = numel(test_ind);            % works whether test_ind is a row or a column
nRows = max(5, ceil(nTest/4));      % 4 ground-truth/prediction pairs per subplot row

%%%% Make sure the output folder exists before writing prediction files
if ~exist('output', 'dir')
    mkdir('output');
end

%%%% Undo the standardization and arrange each sample as a 90x90x3 Lab
%%%% image (4th dimension = test sample)
preds = bsxfun(@plus, bsxfun(@times, Y_star_test{i}, s2{i}), m{i});
preds = reshape(preds, [nTest 3 90 90]);
preds = permute(preds, [3 4 2 1]);
gts = bsxfun(@plus, bsxfun(@times, Y_test{i}, s2{i}), m{i});
gts = reshape(gts, [nTest 3 90 90]);
gts = permute(gts, [3 4 2 1]);

%%%% A title() on the bare figure would be deleted by the first subplot
%%%% call, so label the figure window instead
hFig = figure(1);
set(hFig, 'Name', '2D BRDF predictions');
for ii = 1:nTest
    xx = ceil(ii/4); yy = mod(ii-1,4)+1;

    % Ground truth: Lab -> RGB, undo log(x+1), replicate along the 180 bins
    % of the third grid dimension, then put -1 back at the dropped cells
    gtLab = gts(:, :, :, ii);
    gt = Lab2RGB(gtLab);
    gt = exp(gt) - 1;
    gt = repmat(permute(gt, [1 2 4 3]), [1 1 180 1]);
    gt = reshape(permute(gt, [4 1 2 3]), [3 90*90*180]);
    gt(:, negIdx) = -1;
    gt = permute(reshape(gt, [3 90 90 180]), [2 3 4 1]);
    subplot(nRows,8,(xx-1)*8+(yy-1)*2+1);
    % clamp negatives before the square-root tone curve so imshow never
    % receives complex values
    imshow(2.0*max(permute(gt(:,:,90,:), [1 2 4 3]), 0).^0.5);
    title(['Test sample: ' sprintf('%02d', ii)]);
    ylabel('Ground truth');

    % Prediction: same conversion, then written to disk
    predLab = preds(:, :, :, ii);
    pred = Lab2RGB(predLab);
    pred = exp(pred) - 1;
    pred = repmat(permute(pred, [1 2 4 3]), [1 1 180 1]);
    pred = reshape(permute(pred, [4 1 2 3]), [3 90*90*180]);
    pred(:, negIdx) = -1;
    pred = permute(reshape(pred, [3 90 90 180]), [2 3 4 1]);
    subplot(nRows,8,(xx-1)*8+(yy-1)*2+2);
    imshow(2.0*max(permute(pred(:,:,90,:), [1 2 4 3]), 0).^0.5);
    ylabel('Prediction');
    writeMerlBrdf(pred, ['output/prediction.' sprintf('%02d', ii) '.2d.bin']);
end
    
if train3D
    i = 3;                          % 3D BRDFs
    nTest = numel(test_ind);        % works whether test_ind is a row or a column
    nPos = numel(posIdx);           % number of kept grid cells
    nRows = max(5, ceil(nTest/4));  % 4 ground-truth/prediction pairs per subplot row

    %%%% Make sure the output folder exists before writing prediction files
    if ~exist('output', 'dir')
        mkdir('output');
    end

    %%%% Undo the standardization; each row is one test sample in Lab space
    preds = bsxfun(@plus, bsxfun(@times, Y_star_test{i}, s2{i}), m{i});
    gts = bsxfun(@plus, bsxfun(@times, Y_test{i}, s2{i}), m{i});

    %%%% Label the window (a title() here would be deleted by subplot)
    hFig = figure(2);
    set(hFig, 'Name', '3D BRDF predictions');
    for ii = 1:nTest
        xx = ceil(ii/4); yy = mod(ii-1,4)+1;

        % Ground truth: Lab -> RGB on an nPos x 1 x 3 image, back to
        % 3 x nPos, undo log(x+1) and scatter into the full grid
        % (-1 at the cells dropped during pre-processing)
        rgb = permute(Lab2RGB(permute(reshape(gts(ii, :), [3 nPos]), [2 3 1])), [3 1 2]);
        gt = -1 * ones(3, 90*90*180);
        gt(:, posIdx) = exp(rgb) - 1;
        gt = permute(reshape(gt, [3 90 90 180]), [2 3 4 1]);
        subplot(nRows,8,(xx-1)*8+(yy-1)*2+1);
        % clamp negatives before the square-root tone curve so imshow never
        % receives complex values
        imshow(2.0*max(permute(gt(:,:,90,:), [1 2 4 3]), 0).^0.5);
        title(['Test sample: ' sprintf('%02d', ii)]);
        ylabel('Ground truth');

        % Prediction: same conversion, then written to disk
        rgb = permute(Lab2RGB(permute(reshape(preds(ii, :), [3 nPos]), [2 3 1])), [3 1 2]);
        pred = -1 * ones(3, 90*90*180);
        pred(:, posIdx) = exp(rgb) - 1;
        pred = permute(reshape(pred, [3 90 90 180]), [2 3 4 1]);
        subplot(nRows,8,(xx-1)*8+(yy-1)*2+2);
        imshow(2.0*max(permute(pred(:,:,90,:), [1 2 4 3]), 0).^0.5);
        ylabel('Prediction');
        writeMerlBrdf(pred, ['output/prediction.' sprintf('%02d', ii) '.3d.bin']);
    end
end