Skip to content

Commit

Permalink
Merge branch 'eval-protocols'
Browse files Browse the repository at this point in the history
Conflicts:
	load_dataset.m
  • Loading branch information
almazan committed Mar 18, 2014
2 parents 7cace11 + c1d32e3 commit 5d713cf
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 56 deletions.
126 changes: 95 additions & 31 deletions evaluate_recognition.m
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,52 @@

%lexicon.phocs = [lexicon.phocs;encodeWordsLength(lexicon.words,10)];

% Embed the test attributes representation (attRepreTe_emb)
% Embed the test representations
matx = emb.rndmatx(1:emb.M,:);
tmp = matx*data.attReprTe;
attReprTe_emb = 1/sqrt(emb.M) * [ cos(tmp); sin(tmp)];
attReprTe_emb = bsxfun(@minus, attReprTe_emb, emb.matts);
attReprTe_emb=bsxfun(@minus, attReprTe_emb, emb.matts);
attReprTe_emb = emb.Wx(:,1:emb.K)' * attReprTe_emb;
attReprTe_emb = (bsxfun(@rdivide, attReprTe_emb, sqrt(sum(attReprTe_emb.*attReprTe_emb))));

% Embed the dictionary
phocs = single(lexicon.phocs);
% Embed the lexicon dictionary
lexicon_phocs = single(lexicon.phocs);
maty = emb.rndmaty(1:emb.M,:);
tmp = maty*phocs;
phocs_cca = 1/sqrt(emb.M) * [ cos(tmp); sin(tmp)];
phocs_cca=bsxfun(@minus, phocs_cca, emb.mphocs);
phocs_cca = emb.Wy(:,1:emb.K)' * phocs_cca;
phocs_cca = (bsxfun(@rdivide, phocs_cca, sqrt(sum(phocs_cca.*phocs_cca))));
phocs_cca(isnan(phocs_cca)) = 0;
tmp = maty*lexicon_phocs;
lexicon_phocs_emb = 1/sqrt(emb.M) * [ cos(tmp); sin(tmp)];
lexicon_phocs_emb=bsxfun(@minus, lexicon_phocs_emb, emb.mphocs);
lexicon_phocs_emb = emb.Wy(:,1:emb.K)' * lexicon_phocs_emb;
lexicon_phocs_emb = (bsxfun(@rdivide, lexicon_phocs_emb, sqrt(sum(lexicon_phocs_emb.*lexicon_phocs_emb))));
lexicon_phocs_emb(isnan(lexicon_phocs_emb)) = 0;
words = lexicon.words;

N = size(attReprTe_emb,2);

% Get all the valid queries. For most datasets, that is all of them.
qidx = find(~strcmp({data.wordsTe.gttext},'-'));
N = length(qidx);

% Get the scores: p1, cer, and wer
p1small = zeros(N,1);
cersmall = zeros(N,1);
wersmall = zeros(N,1);
p1medium = zeros(N,1);
p1large = zeros(N,1);
cermedium = zeros(N,1);
wermedium = zeros(N,1);
p1full = zeros(N,1);
cerfull = zeros(N,1);
werfull = zeros(N,1);

for i=1:N
feat = attReprTe_emb(:,i);
gt = data.wordsTe(i).gttext;
if ~strcmpi(opts.dataset, 'LP')
% Get actual idx, feature vector, and gt transcription
pos = qidx(i);
feat = attReprTe_emb(:,pos);
gt = data.wordsTe(pos).gttext;

% Small lexicon available
if isfield(data.wordsTe,'sLexi')
smallLexicon = unique(data.wordsTe(i).sLexi);
[~,~,ind] = inters(smallLexicon,words,'stable');
scores = feat'*phocs_cca(:,ind);
scores = feat'*lexicon_phocs_emb(:,ind);
randInd = randperm(length(scores));
scores = scores(randInd);
[scores,I] = sort(scores,'descend');
Expand All @@ -57,12 +73,15 @@
p1small(i) = 1;
else
p1small(i) = 0;
cersmall(i) = levenshtein_c(gt, smallLexicon{I(1)});
end
end
if strcmpi(opts.dataset,'IIIT5K')
mediumLexicon = unique(data.wordsTe(i).mLexi);

% Medium lexicon available
if isfield(data.wordsTe,'mLexi')
mediumLexicon = unique(data.wordsTe(pos).mLexi);
[~,~,ind] = inters(mediumLexicon,words,'stable');
scores = feat'*phocs_cca(:,ind);
scores = feat'*lexicon_phocs_emb(:,ind);
randInd = randperm(length(scores));
scores = scores(randInd);
[scores,I] = sort(scores,'descend');
Expand All @@ -71,35 +90,80 @@
p1medium(i) = 1;
else
p1medium(i) = 0;
cermedium(i) = levenshtein_c(gt, mediumLexicon{I(1)});
end
end

scores = feat'*phocs_cca;
% Full lexicon always available
scores = feat'*lexicon_phocs_emb;
randInd = randperm(length(scores));
scores = scores(randInd);
[scores,I] = sort(scores,'descend');
I = randInd(I);

if strcmpi(gt,words{I(1)})
p1large(i) = 1;
p1full(i) = 1;
else
p1large(i) = 0;
p1full(i) = 0;
cerfull(i) = levenshtein_c(gt, words{I(1)});
end
end

% Compute wer if there is line info available
if isfield(data.wordsTe,'lineId')
linesTe = {data.wordsTe.lineId}';
linesTe = linesTe(qidx);
if isfield(data.wordsTe,'sLexi')
recognition.wersmall = compute_wer(linesTe,p1small);
end

if isfield(data.wordsTe,'mLexi')
recognition.wermedium = compute_wer(linesTe,p1medium);
end
recognition.werfull = compute_wer(linesTe,p1full);
end


% Display stuff
fprintf('\n');
disp('**************************************');
disp('************ Recognition ***********');
disp('**************************************');
disp('------------------------------------');
recognition.small = 100*mean(p1small);
fprintf('lexicon small -- p@1: %.2f\n', recognition.small);
if strcmpi(opts.dataset,'IIIT5K')
recognition.medium = 100*mean(p1medium);
fprintf('lexicon medium -- p@1: %.2f\n', recognition.medium);
if isfield(data.wordsTe,'sLexi')
recognition.p1small = 100*mean(p1small);
recognition.cersmall = 100*mean(cersmall);
if isfield(recognition,'wersmall')
fprintf('lexicon small -- p@1: %.2f. cer: %.2f. wer: %.2f\n', recognition.p1small, recognition.cersmall, recognition.wersmall);
else
fprintf('lexicon small -- p@1: %.2f. cer: %.2f. wer: N/A\n', recognition.p1small, recognition.cersmall);
end
end
recognition.large = 100*mean(p1large);
fprintf('lexicon large -- p@1: %.2f\n', recognition.large);
disp('------------------------------------');

if isfield(data.wordsTe,'mLexi')
recognition.p1medium = 100*mean(p1medium);
recognition.cermedium = 100*mean(cermedium);
if isfield(recognition,'wermedium')
fprintf('lexicon medium -- p@1: %.2f. cer: %.2f. wer: %.2f\n', recognition.p1medium, recognition.cermedium, recognition.wermedium);
else
fprintf('lexicon medium -- p@1: %.2f. cer: %.2f. wer: N/A\n', recognition.p1medium, recognition.cermedium);
end
end

recognition.p1full = 100*mean(p1full);
recognition.cerfull = 100*mean(cerfull);
if isfield(recognition,'werfull')
fprintf('lexicon full -- p@1: %.2f. cer: %.2f. wer: %.2f\n', recognition.p1full, recognition.cerfull, recognition.werfull);
else
fprintf('lexicon full -- p@1: %.2f. cer: %.2f. wer: N/A\n', recognition.p1full, recognition.cerfull);
end
disp('------------------------------------');
end






% Ugly hack to deal with the lack of stable intersection in old versions of
% matlab
function [empty1, empty2, ind] = stableintersection(a, b, varargin)
Expand Down
14 changes: 4 additions & 10 deletions evaluate_retrieval.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
mAP.reg = [];
mAP.cca = [];
mAP.kcca = [];
mAP.hybrid = [];

if opts.TestFV
mAP.fv = evaluateFV(opts, data);
Expand All @@ -31,15 +32,8 @@
mAP.kcca = evaluateKCCA(opts, data, embedding.kcca);
end

% % Hybrid spotting not implemented yet
% if opts.evalHybrid
% alpha = 0:0.1:1;
% hybrid_test_map = zeros(length(alpha),1);
% for i=1:length(alpha)
% attRepr_hybrid = attReprTe_cca*alpha(i) + phocsTe_cca*(1-alpha(i));
% [p1,mAPEucl,q] = eval_dp_asymm(opts,attRepr_hybrid, attReprTe_cca,DATA.queriesClassesTe,DATA.wordsTe);
% hybrid_test_map(i) = mean(mAPEucl);
% end
% end
if opts.TestHybrid
mAP.hybrid = evaluateHybrid(opts, data, embedding.kcca);
end

end
44 changes: 39 additions & 5 deletions extract_lexicon.m
Original file line number Diff line number Diff line change
@@ -1,20 +1,44 @@
function lexicon = extract_lexicon(opts,data)
words = data.words(data.idxTest);
% Small fix for versions of matlab older than 2012b ('8') that do not support stable intersection
if verLessThan('matlab', '8')
inters=@stableintersection;
else
inters=@intersect;
end

% Extracts the unique set of words in the lexicon
if strcmpi(opts.dataset,'IIIT5K')
words=[words(:).sLexi words(:).mLexi];
wordsTe = data.wordsTe;
words=[wordsTe(:).sLexi wordsTe(:).mLexi];
words = unique(words);
elseif strcmpi(opts.dataset,'ICDAR11')
words=[words(:).sLexi words(:).mLexi];
wordsTe = data.wordsTe;
words=[wordsTe(:).sLexi wordsTe(:).mLexi];
words = unique(words);
elseif strcmpi(opts.dataset,'SVT')
words=[words(:).sLexi];
wordsTe = data.wordsTe;
words=[wordsTe(:).sLexi];
words = unique(words);
elseif strcmpi(opts.dataset, 'LP')
words = unique({words.gttext})';
wordsTe = data.words;
words = unique({wordsTe.gttext})';
elseif strcmpi(opts.dataset, 'IAM')
wordsTe = data.wordsTe;
words = unique({wordsTe.gttext})';
words(ismember(words, '-')) = [];
elseif strcmpi(opts.dataset, 'GW')
wordsTe = data.words;
words = unique({wordsTe.gttext})';
else
error('Dataset not supported');
end

% Extracts the class of every word in the lexicon
% Class is equal to 0 if the word does not appear in the dataset
class_words = zeros(length(words),1);
[~,ia,ib] = inters(words,{wordsTe.gttext},'stable');
class_words(ia) = [wordsTe(ib).class];

% Extracts the PHOC embedding for every word in the lexicon
voc = opts.unigrams;
if opts.considerDigits
Expand All @@ -32,7 +56,17 @@

lexicon.words = words;
lexicon.phocs = phocs;
lexicon.class_words = class_words;

save(opts.fileLexicon,'lexicon');

end

% Ugly hack to deal with the lack of stable intersection in old versions of
% matlab
function [empty1, ia, ib] = stableintersection(a, b, varargin)
% Stand-in for intersect(a,b,'stable') on MATLAB releases without that option.
%   ia, ib : positions in a and b of the common elements, reordered so that
%            ia is ascending (i.e. matches appear in the order they have in a).
%   empty1 : placeholder for the first output of intersect; always 0.
%   Extra arguments (such as the 'stable' flag) are accepted and ignored.
[~, posInA, posInB] = intersect(a, b);
% Sorting the positions in a restores a's ordering; apply the same
% permutation to the positions in b so the pairs stay aligned.
[ia, order] = sort(posInA);
ib = posInB(order);
empty1 = 0;
end
2 changes: 2 additions & 0 deletions load_dataset.m
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
data = load_ESP(opts);
elseif strcmpi(opts.dataset,'LP')
data = load_LP(opts);
else
error('Dataset not supported');
end
save(opts.fileData,'data');
else
Expand Down
23 changes: 15 additions & 8 deletions prepare_opts.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
function opts = prepare_opts()

% Adjustable paths
% Select the disk location of your datasets
opts.path_datasets = 'datasets';
% Path where the generated files will be saved
opts.pathData = '~/watts/data';
% Select the dataset
opts.dataset = 'GW';

% Adding all the necessary libraries and paths
addpath('util/');
if ~exist('util/bin','dir')
Expand All @@ -14,6 +22,9 @@
if ~exist('phoc_mex','file')
mex -o util/bin/phoc_mex -O -largeArrayDims util/phoc_mex.cpp
end
if ~exist('levenshtein_c','file')
mex -o util/bin/levenshtein_c -O -largeArrayDims util/levenshtein_c.c
end
if ~exist('util/vlfeat-0.9.18/toolbox/mex','dir')
if isunix
cd 'util/vlfeat-0.9.18/';
Expand All @@ -39,10 +50,6 @@
% Set random seed to default
rng('default');

% Select the dataset
opts.dataset = 'SVT';

opts.path_datasets = 'datasets';
opts.pathDataset = sprintf('%s/%s/',opts.path_datasets,opts.dataset);
opts.pathImages = sprintf('%s/%s/images/',opts.path_datasets,opts.dataset);
opts.pathDocuments = sprintf('%s/%s/documents/',opts.path_datasets,opts.dataset);
Expand Down Expand Up @@ -111,16 +118,17 @@
opts.KCCA.verbose = 1;

opts.evalRecog = 1;
opts.TestHybrid = 1;

% Specific dataset options
if strcmp(opts.dataset,'GW')
opts.fold = 1;
opts.evalRecog = 0;
opts.minH = 80;
opts.maxH = 80;
elseif strcmp(opts.dataset,'IAM')
opts.PCADIM = 30;
opts.RemoveStopWords = 1;
opts.swFile = 'data/swIAM.txt';
opts.evalRecog = 0;
opts.minH = 80;
opts.maxH = 80;
elseif strcmp(opts.dataset,'IIIT5K')
Expand All @@ -143,7 +151,7 @@

opts.FVdim = (opts.PCADIM+2)*opts.numSpatialX*opts.numSpatialY*opts.G*2;

if opts.evalRecog
if opts.evalRecog || opts.TestHybrid
opts.TestKCCA = 1;
end

Expand Down Expand Up @@ -172,7 +180,6 @@
opts.tagFeatures = sprintf('%s%s%s%s',tagFeats,tagPCA,tagGMM,tagFold);

% Paths and files
opts.pathData = '~/watts/data';
opts.pathFiles = sprintf('%s/files',opts.pathData);
if ~exist(opts.pathData,'dir')
mkdir(opts.pathData);
Expand Down
23 changes: 23 additions & 0 deletions util/compute_wer.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
function [ wer ] = compute_wer(linesTe ,p1)
%COMPUTE_WER Line-averaged error rate, returned as a percentage.
%   wer = COMPUTE_WER(linesTe, p1) groups the entries of p1 by the line
%   identifier found at the same position in linesTe, takes the mean of p1
%   inside every group, averages those per-line means over all lines, and
%   returns 100*(1 - average).
%
%   Inputs
%     linesTe : cell array holding one line identifier per word (strings,
%               or numeric/logical values).
%     p1      : vector with at least numel(linesTe) entries; p1(i) is the
%               score of the word whose line identifier is linesTe{i}
%               (1 = correct, 0 = wrong when used with p@1 flags).
%
%   Output
%     wer     : scalar percentage. NaN when linesTe is empty, because there
%               is no line to average over.

numWords = numel(linesTe);
if numWords == 0
    % Make the "no data" result explicit instead of relying on mean([]).
    wer = NaN;
    return;
end

% Normalise the identifiers to strings so that unique() can group them even
% when they are not character arrays (e.g. numeric line numbers).
keys = linesTe(:);
if ~iscellstr(keys)
    keys = cellfun(@mat2str, keys, 'UniformOutput', false);
end

% lineOfWord(i) is the index of the line that word i belongs to.
[~, ~, lineOfWord] = unique(keys);

% Only the first numWords scores are paired with an identifier.
scores = double(p1(1:numWords));

% Mean score of every line, then the mean over lines: each line carries
% the same weight regardless of how many words it contains.
perLine = accumarray(lineOfWord(:), scores(:), [], @mean);
wer = 100*(1-mean(perLine));
end

Loading

0 comments on commit 5d713cf

Please sign in to comment.