diff --git a/pix2struct/metrics.py b/pix2struct/metrics.py index 59e39be..6d0e9ae 100644 --- a/pix2struct/metrics.py +++ b/pix2struct/metrics.py @@ -56,11 +56,10 @@ def cider( def anls_metric(target: str, prediction: str, theta: float = 0.5): - """Calculates ANLS for DocVQA. + """Calculates ANLS for DocVQA and InfographicVQA. - There does not seem to be an official evaluation script. - Public implementation on which this implementation is based: - https://github.com/herobd/layoutlmv2/blob/main/eval_docvqa.py#L92 + Official evaluation script at https://rrc.cvc.uab.es/?ch=17&com=downloads + (Infographics VQA Evaluation scripts). Original paper (see Eq 1): https://arxiv.org/pdf/1907.00490.pdf @@ -75,7 +74,7 @@ def anls_metric(target: str, prediction: str, theta: float = 0.5): edit_distance = editdistance.eval(target, prediction) normalized_ld = edit_distance / max(len(target), len(prediction)) - return 1 - normalized_ld if normalized_ld < theta else 0 + return 1 - normalized_ld if normalized_ld <= theta else 0 def relaxed_correctness(target: str,