diff --git a/evaluate_recognition.m b/evaluate_recognition.m index ed586fb..05b8c3e 100644 --- a/evaluate_recognition.m +++ b/evaluate_recognition.m @@ -18,36 +18,52 @@ %lexicon.phocs = [lexicon.phocs;encodeWordsLength(lexicon.words,10)]; -% Embed the test attributes representation (attRepreTe_emb) +% Embed the test representations matx = emb.rndmatx(1:emb.M,:); tmp = matx*data.attReprTe; attReprTe_emb = 1/sqrt(emb.M) * [ cos(tmp); sin(tmp)]; -attReprTe_emb = bsxfun(@minus, attReprTe_emb, emb.matts); +attReprTe_emb=bsxfun(@minus, attReprTe_emb, emb.matts); attReprTe_emb = emb.Wx(:,1:emb.K)' * attReprTe_emb; attReprTe_emb = (bsxfun(@rdivide, attReprTe_emb, sqrt(sum(attReprTe_emb.*attReprTe_emb)))); -% Embed the dictionary -phocs = single(lexicon.phocs); +% Embed the lexicon dictionary +lexicon_phocs = single(lexicon.phocs); maty = emb.rndmaty(1:emb.M,:); -tmp = maty*phocs; -phocs_cca = 1/sqrt(emb.M) * [ cos(tmp); sin(tmp)]; -phocs_cca=bsxfun(@minus, phocs_cca, emb.mphocs); -phocs_cca = emb.Wy(:,1:emb.K)' * phocs_cca; -phocs_cca = (bsxfun(@rdivide, phocs_cca, sqrt(sum(phocs_cca.*phocs_cca)))); -phocs_cca(isnan(phocs_cca)) = 0; +tmp = maty*lexicon_phocs; +lexicon_phocs_emb = 1/sqrt(emb.M) * [ cos(tmp); sin(tmp)]; +lexicon_phocs_emb=bsxfun(@minus, lexicon_phocs_emb, emb.mphocs); +lexicon_phocs_emb = emb.Wy(:,1:emb.K)' * lexicon_phocs_emb; +lexicon_phocs_emb = (bsxfun(@rdivide, lexicon_phocs_emb, sqrt(sum(lexicon_phocs_emb.*lexicon_phocs_emb)))); +lexicon_phocs_emb(isnan(lexicon_phocs_emb)) = 0; words = lexicon.words; -N = size(attReprTe_emb,2); + +% Get all the valid queries. For most datasets, that is all of them. 
+qidx = find(~strcmp({data.wordsTe.gttext},'-'));
+N = length(qidx);
+
+% Get the scores: p1, cer, and wer
 p1small = zeros(N,1);
+cersmall = zeros(N,1);
+wersmall = zeros(N,1);
 p1medium = zeros(N,1);
-p1large = zeros(N,1);
+cermedium = zeros(N,1);
+wermedium = zeros(N,1);
+p1full = zeros(N,1);
+cerfull = zeros(N,1);
+werfull = zeros(N,1);
+
 for i=1:N
-    feat = attReprTe_emb(:,i);
-    gt = data.wordsTe(i).gttext;
-    if ~strcmpi(opts.dataset, 'LP')
+    % Get actual idx, feature vector, and gt transcription
+    pos = qidx(i);
+    feat = attReprTe_emb(:,pos);
+    gt = data.wordsTe(pos).gttext;
+
+    % Small lexicon available
+    if isfield(data.wordsTe,'sLexi')
-        smallLexicon = unique(data.wordsTe(i).sLexi);
+        smallLexicon = unique(data.wordsTe(pos).sLexi); % index with pos (entry in wordsTe), not the loop counter i
         [~,~,ind] = inters(smallLexicon,words,'stable');
-        scores = feat'*phocs_cca(:,ind);
+        scores = feat'*lexicon_phocs_emb(:,ind);
         randInd = randperm(length(scores));
         scores = scores(randInd);
         [scores,I] = sort(scores,'descend');
@@ -57,12 +73,15 @@
             p1small(i) = 1;
         else
             p1small(i) = 0;
+            cersmall(i) = levenshtein_c(gt, smallLexicon{I(1)});
         end
     end
-    if strcmpi(opts.dataset,'IIIT5K')
-        mediumLexicon = unique(data.wordsTe(i).mLexi);
+
+    % Medium lexicon available
+    if isfield(data.wordsTe,'mLexi')
+        mediumLexicon = unique(data.wordsTe(pos).mLexi);
         [~,~,ind] = inters(mediumLexicon,words,'stable');
-        scores = feat'*phocs_cca(:,ind);
+        scores = feat'*lexicon_phocs_emb(:,ind);
         randInd = randperm(length(scores));
         scores = scores(randInd);
         [scores,I] = sort(scores,'descend');
@@ -71,35 +90,80 @@
             p1medium(i) = 1;
         else
             p1medium(i) = 0;
+            cermedium(i) = levenshtein_c(gt, mediumLexicon{I(1)});
         end
     end
-    scores = feat'*phocs_cca;
+    % Full lexicon always available
+    scores = feat'*lexicon_phocs_emb;
     randInd = randperm(length(scores));
     scores = scores(randInd);
     [scores,I] = sort(scores,'descend');
     I = randInd(I);
     if strcmpi(gt,words{I(1)})
-        p1large(i) = 1;
+        p1full(i) = 1;
     else
-        p1large(i) = 0;
+        p1full(i) = 0;
+        cerfull(i) = levenshtein_c(gt, words{I(1)});
+    end
+end
+
+% Compute wer if there is
line info available +if isfield(data.wordsTe,'lineId') + linesTe = {data.wordsTe.lineId}'; + linesTe = linesTe(qidx); + if isfield(data.wordsTe,'sLexi') + recognition.wersmall = compute_wer(linesTe,p1small); end - + if isfield(data.wordsTe,'mLexi') + recognition.wermedium = compute_wer(linesTe,p1medium); + end + recognition.werfull = compute_wer(linesTe,p1full); end + + +% Display stuff +fprintf('\n'); +disp('**************************************'); +disp('************ Recognition ***********'); +disp('**************************************'); disp('------------------------------------'); -recognition.small = 100*mean(p1small); -fprintf('lexicon small -- p@1: %.2f\n', recognition.small); -if strcmpi(opts.dataset,'IIIT5K') - recognition.medium = 100*mean(p1medium); - fprintf('lexicon medium -- p@1: %.2f\n', recognition.medium); +if isfield(data.wordsTe,'sLexi') + recognition.p1small = 100*mean(p1small); + recognition.cersmall = 100*mean(cersmall); + if isfield(recognition,'wersmall') + fprintf('lexicon small -- p@1: %.2f. cer: %.2f. wer: %.2f\n', recognition.p1small, recognition.cersmall, recognition.wersmall); + else + fprintf('lexicon small -- p@1: %.2f. cer: %.2f. wer: N/A\n', recognition.p1small, recognition.cersmall); + end end -recognition.large = 100*mean(p1large); -fprintf('lexicon large -- p@1: %.2f\n', recognition.large); -disp('------------------------------------'); +if isfield(data.wordsTe,'mLexi') + recognition.p1medium = 100*mean(p1medium); + recognition.cermedium = 100*mean(cermedium); + if isfield(recognition,'wermedium') + fprintf('lexicon medium -- p@1: %.2f. cer: %.2f. wer: %.2f\n', recognition.p1medium, recognition.cermedium, recognition.wermedium); + else + fprintf('lexicon medium -- p@1: %.2f. cer: %.2f. wer: N/A\n', recognition.p1medium, recognition.cermedium); + end end +recognition.p1full = 100*mean(p1full); +recognition.cerfull = 100*mean(cerfull); +if isfield(recognition,'werfull') + fprintf('lexicon full -- p@1: %.2f. cer: %.2f. 
wer: %.2f\n', recognition.p1full, recognition.cerfull, recognition.werfull); +else + fprintf('lexicon full -- p@1: %.2f. cer: %.2f. wer: N/A\n', recognition.p1full, recognition.cerfull); +end +disp('------------------------------------'); +end + + + + + + % Ugly hack to deal with the lack of stable intersection in old versions of % matlab function [empty1, empty2, ind] = stableintersection(a, b, varargin) diff --git a/evaluate_retrieval.m b/evaluate_retrieval.m index ed52d67..8c94492 100644 --- a/evaluate_retrieval.m +++ b/evaluate_retrieval.m @@ -6,6 +6,7 @@ mAP.reg = []; mAP.cca = []; mAP.kcca = []; +mAP.hybrid = []; if opts.TestFV mAP.fv = evaluateFV(opts, data); @@ -31,15 +32,8 @@ mAP.kcca = evaluateKCCA(opts, data, embedding.kcca); end -% % Hybrid spotting not implemented yet -% if opts.evalHybrid -% alpha = 0:0.1:1; -% hybrid_test_map = zeros(length(alpha),1); -% for i=1:length(alpha) -% attRepr_hybrid = attReprTe_cca*alpha(i) + phocsTe_cca*(1-alpha(i)); -% [p1,mAPEucl,q] = eval_dp_asymm(opts,attRepr_hybrid, attReprTe_cca,DATA.queriesClassesTe,DATA.wordsTe); -% hybrid_test_map(i) = mean(mAPEucl); -% end -% end +if opts.TestHybrid + mAP.hybrid = evaluateHybrid(opts, data, embedding.kcca); +end end \ No newline at end of file diff --git a/extract_lexicon.m b/extract_lexicon.m index d01a891..242144b 100644 --- a/extract_lexicon.m +++ b/extract_lexicon.m @@ -1,20 +1,44 @@ function lexicon = extract_lexicon(opts,data) -words = data.words(data.idxTest); +% Small fix for versions of matlab older than 2012b ('8') that do not support stable intersection +if verLessThan('matlab', '8') + inters=@stableintersection; +else + inters=@intersect; +end % Extracts the unique set of words in the lexicon if strcmpi(opts.dataset,'IIIT5K') - words=[words(:).sLexi words(:).mLexi]; + wordsTe = data.wordsTe; + words=[wordsTe(:).sLexi wordsTe(:).mLexi]; words = unique(words); elseif strcmpi(opts.dataset,'ICDAR11') - words=[words(:).sLexi words(:).mLexi]; + wordsTe = data.wordsTe; + 
words=[wordsTe(:).sLexi wordsTe(:).mLexi]; words = unique(words); elseif strcmpi(opts.dataset,'SVT') - words=[words(:).sLexi]; + wordsTe = data.wordsTe; + words=[wordsTe(:).sLexi]; words = unique(words); elseif strcmpi(opts.dataset, 'LP') - words = unique({words.gttext})'; + wordsTe = data.words; + words = unique({wordsTe.gttext})'; +elseif strcmpi(opts.dataset, 'IAM') + wordsTe = data.wordsTe; + words = unique({wordsTe.gttext})'; + words(ismember(words, '-')) = []; +elseif strcmpi(opts.dataset, 'GW') + wordsTe = data.words; + words = unique({wordsTe.gttext})'; +else + error('Dataset not supported'); end +% Extracts the class of every word in the lexicon +% Class is equal to 0 if the word does not appear in the dataset +class_words = zeros(length(words),1); +[~,ia,ib] = inters(words,{wordsTe.gttext},'stable'); +class_words(ia) = [wordsTe(ib).class]; + % Extracts the PHOC embedding for every word in the lexicon voc = opts.unigrams; if opts.considerDigits @@ -32,7 +56,17 @@ lexicon.words = words; lexicon.phocs = phocs; +lexicon.class_words = class_words; save(opts.fileLexicon,'lexicon'); +end + +% Ugly hack to deal with the lack of stable intersection in old versions of +% matlab +function [empty1, ia, ib] = stableintersection(a, b, varargin) +empty1=0; +[~,ia,ib] = intersect(a,b); +[ia, tmp2] = sort(ia); +ib = ib(tmp2); end \ No newline at end of file diff --git a/load_dataset.m b/load_dataset.m index d5722f6..73dcbb2 100644 --- a/load_dataset.m +++ b/load_dataset.m @@ -17,6 +17,8 @@ data = load_ESP(opts); elseif strcmpi(opts.dataset,'LP') data = load_LP(opts); + else + error('Dataset not supported'); end save(opts.fileData,'data'); else diff --git a/prepare_opts.m b/prepare_opts.m index 80beadf..558e20c 100644 --- a/prepare_opts.m +++ b/prepare_opts.m @@ -1,5 +1,13 @@ function opts = prepare_opts() +% Adjustable paths +% Select the disk location of your datasets +opts.path_datasets = 'datasets'; +% Path where the generated files will be saved +opts.pathData = 
'~/watts/data'; +% Select the dataset +opts.dataset = 'GW'; + % Adding all the necessary libraries and paths addpath('util/'); if ~exist('util/bin','dir') @@ -14,6 +22,9 @@ if ~exist('phoc_mex','file') mex -o util/bin/phoc_mex -O -largeArrayDims util/phoc_mex.cpp end +if ~exist('levenshtein_c','file') + mex -o util/bin/levenshtein_c -O -largeArrayDims util/levenshtein_c.c +end if ~exist('util/vlfeat-0.9.18/toolbox/mex','dir') if isunix cd 'util/vlfeat-0.9.18/'; @@ -39,10 +50,6 @@ % Set random seed to default rng('default'); -% Select the dataset -opts.dataset = 'SVT'; - -opts.path_datasets = 'datasets'; opts.pathDataset = sprintf('%s/%s/',opts.path_datasets,opts.dataset); opts.pathImages = sprintf('%s/%s/images/',opts.path_datasets,opts.dataset); opts.pathDocuments = sprintf('%s/%s/documents/',opts.path_datasets,opts.dataset); @@ -111,16 +118,17 @@ opts.KCCA.verbose = 1; opts.evalRecog = 1; +opts.TestHybrid = 1; % Specific dataset options if strcmp(opts.dataset,'GW') opts.fold = 1; - opts.evalRecog = 0; + opts.minH = 80; + opts.maxH = 80; elseif strcmp(opts.dataset,'IAM') opts.PCADIM = 30; opts.RemoveStopWords = 1; opts.swFile = 'data/swIAM.txt'; - opts.evalRecog = 0; opts.minH = 80; opts.maxH = 80; elseif strcmp(opts.dataset,'IIIT5K') @@ -143,7 +151,7 @@ opts.FVdim = (opts.PCADIM+2)*opts.numSpatialX*opts.numSpatialY*opts.G*2; -if opts.evalRecog +if opts.evalRecog || opts.TestHybrid opts.TestKCCA = 1; end @@ -172,7 +180,6 @@ opts.tagFeatures = sprintf('%s%s%s%s',tagFeats,tagPCA,tagGMM,tagFold); % Paths and files -opts.pathData = '~/watts/data'; opts.pathFiles = sprintf('%s/files',opts.pathData); if ~exist(opts.pathData,'dir') mkdir(opts.pathData); diff --git a/util/compute_wer.m b/util/compute_wer.m new file mode 100644 index 0000000..a97a563 --- /dev/null +++ b/util/compute_wer.m @@ -0,0 +1,23 @@ +function [ wer ] = compute_wer(linesTe ,p1) +%UNTITLED Summary of this function goes here +% Detailed explanation goes here + + +% Create dictionary +dict 
=java.util.Hashtable; + +for i=1:length(linesTe) + dict.put(linesTe{i}, [dict.get(linesTe{i}); i]); +end + +entries = dict.entrySet.toArray; + +wer = zeros(1, length(entries)); + +for i=1:length(entries) + idxL = entries(i).getValue; + wer(i) = mean(p1(idxL)); +end +wer = 100*(1-mean(wer)); +end + diff --git a/util/evaluateHybrid.m b/util/evaluateHybrid.m new file mode 100644 index 0000000..ae8531a --- /dev/null +++ b/util/evaluateHybrid.m @@ -0,0 +1,55 @@ +function hybrid_map = evaluateHybrid(opts,DATA,embedding) + +fprintf('\n'); +disp('**************************************'); +disp('************ Hybrid KCSR ***********'); +disp('**************************************'); + +matx = embedding.rndmatx(1:embedding.M,:); +maty = embedding.rndmaty(1:embedding.M,:); + +tmp = matx*DATA.attReprTe; +attReprTe_emb = 1/sqrt(embedding.M) * [ cos(tmp); sin(tmp)]; +tmp = maty*DATA.phocsTe; +phocsTe_emb = 1/sqrt(embedding.M) * [ cos(tmp); sin(tmp)]; + +% Mean center +attReprTe_emb=bsxfun(@minus, attReprTe_emb, embedding.matts); +phocsTe_emb=bsxfun(@minus, phocsTe_emb, embedding.mphocs); + +% Embed test +attReprTe_cca = embedding.Wx(:,1:embedding.K)' * attReprTe_emb; +phocsTe_cca = embedding.Wy(:,1:embedding.K)' * phocsTe_emb; + +% L2 normalize (critical) +attReprTe_cca = (bsxfun(@rdivide, attReprTe_cca, sqrt(sum(attReprTe_cca.*attReprTe_cca)))); +phocsTe_cca = (bsxfun(@rdivide, phocsTe_cca, sqrt(sum(phocsTe_cca.*phocsTe_cca)))); + +% Evaluate +alpha = 0:0.1:1; +hybrid_map = zeros(length(alpha),1); +hybrid_p1 = zeros(length(alpha),1); +for i=1:length(alpha) + attRepr_hybrid = attReprTe_cca*alpha(i) + phocsTe_cca*(1-alpha(i)); + [p1,mAPEucl,q] = eval_dp_asymm(opts,attRepr_hybrid,attReprTe_cca,DATA.wordClsTe,DATA.labelsTe); + hybrid_map(i) = mean(mAPEucl)*100; + hybrid_p1(i) = mean(p1)*100; +end + +[best_map,idx] = max(hybrid_map); +best_p1 = hybrid_p1(idx); +best_alpha = alpha(idx); + +% Display info +disp('------------------------------------'); +fprintf('alpha: %.2f reg: %.8f. 
k: %d\n', best_alpha, embedding.reg, embedding.K);
+fprintf('hybrid -- test: (map: %.2f. p@1: %.2f)\n', best_map, best_p1);
+disp('------------------------------------');
+
+plot(alpha,hybrid_map,'.-','MarkerSize',16)
+title(opts.dataset)
+xlabel('alpha')
+ylabel('Mean Average Precision (%)')
+grid on
+
+end
diff --git a/util/learnCCA.m b/util/learnCCA.m
index a63a9bc..e9110da 100644
--- a/util/learnCCA.m
+++ b/util/learnCCA.m
@@ -4,7 +4,7 @@
 %% Part 1: Crosvalidate to find the best parameters in the config range
 fprintf('\n');
 disp('**************************************');
-disp('************* CV CCA *************');
+disp('************* CV CSR *************');
 disp('**************************************');

 % A) L2 normalize and mean center. Not critical, but helps a bit.
diff --git a/util/learnKCCA.m b/util/learnKCCA.m
index f862d6b..b9d888e 100644
--- a/util/learnKCCA.m
+++ b/util/learnKCCA.m
@@ -4,7 +4,7 @@
 %% Part 1: Crosvalidate to find the best parameters in the config range
 fprintf('\n');
 disp('**************************************');
-disp('************* CV KCCA ************');
+disp('************* CV KCSR ************');
 disp('**************************************');


diff --git a/util/levenshtein_c.c b/util/levenshtein_c.c
new file mode 100755
index 0000000..b8534b5
--- /dev/null
+++ b/util/levenshtein_c.c
@@ -0,0 +1,92 @@
+/*
+ * levenshtein_c.c - MEX function returning the Levenshtein (edit) distance
+ * between two strings, divided by the length of the longer string.
+ *
+ * MATLAB usage: dist = levenshtein_c(w1, w2)
+ */
+#include <string.h>
+
+#include "mex.h"
+
+/* Smallest of three values (arguments may be evaluated more than once). */
+#define MIN3(a, b, c) ((a) < (b) ? ((a) < (c) ? (a) : (c)) : ((b) < (c) ? (b) : (c)))
+/* Larger of two values; the whole expansion is parenthesized so the macro
+ * stays correct inside a larger expression. */
+#define MAX(a, b) (((a) >= (b)) ? (a) : (b))
+
+/*
+ * Levenshtein distance between the first s1len characters of s1 and the
+ * first s2len characters of s2, computed with one rolling column of
+ * s1len + 1 entries.
+ */
+int levenshtein(char *s1, char *s2, int s1len, int s2len) {
+    unsigned int x, y, lastdiag, olddiag, dist;
+    unsigned int rows = (unsigned int)s1len;
+    unsigned int cols = (unsigned int)s2len;
+    /* Heap buffer rather than a variable-length array (VLAs are optional
+     * since C11). */
+    unsigned int *column = mxMalloc(((size_t)rows + 1) * sizeof *column);
+
+    /* Start at y = 0 so column[0] is defined even when s2len is 0. */
+    for (y = 0; y <= rows; y++)
+        column[y] = y;
+    for (x = 1; x <= cols; x++) {
+        column[0] = x;
+        for (y = 1, lastdiag = x - 1; y <= rows; y++) {
+            olddiag = column[y];
+            column[y] = MIN3(column[y] + 1, column[y-1] + 1, lastdiag + (s1[y-1] == s2[x-1] ? 0 : 1));
+            lastdiag = olddiag;
+        }
+    }
+    dist = column[rows];
+    mxFree(column);
+    return (int)dist;
+}
+
+/*
+ * MEX gateway.
+ *
+ * Input parameters
+ *   [0]: w1
+ *   [1]: w2
+ *
+ * Output parameters
+ *   [0]: dist, single-precision value: levenshtein(w1, w2) divided by the
+ *        length of the longer string, or 0 when both strings are empty
+ */
+void mexFunction (int nlhs, mxArray *plhs[],
+        int nrhs, const mxArray*prhs[]) {
+
+    int N1, N2, longest;
+    char *w1;
+    char *w2;
+    mwSize dims[1];
+    float *d;
+
+    (void)nlhs;
+
+    if (nrhs != 2)
+        mexErrMsgIdAndTxt("levenshtein_c:nrhs", "Two inputs required: levenshtein_c(w1, w2).");
+
+    /* Read Data */
+    w1 = mxArrayToString(prhs[0]);
+    w2 = mxArrayToString(prhs[1]);
+    /* mxArrayToString returns NULL on failure; never hand that to strlen. */
+    if (w1 == NULL || w2 == NULL)
+        mexErrMsgIdAndTxt("levenshtein_c:input", "Could not read both inputs as strings (character arrays expected).");
+    N1 = (int)strlen(w1);
+    N2 = (int)strlen(w2);
+    longest = MAX(N1, N2);
+
+    /* Prepare output */
+    dims[0] = 1;
+    plhs[0] = mxCreateNumericArray(1, dims, mxSINGLE_CLASS, mxREAL);
+    d = (float*)mxGetData(plhs[0]);
+
+    /* Two empty strings are identical; this also avoids computing 0/0. */
+    *d = (longest == 0) ? 0.0f : levenshtein(w1, w2, N1, N2) / (float)longest;
+
+    /* Release the string copies made by mxArrayToString. */
+    mxFree(w1);
+    mxFree(w2);
+}