diff --git a/controlcap/common/evaluation/meteor/__init__.py b/controlcap/common/evaluation/meteor/__init__.py new file mode 100644 index 0000000..3f7d85b --- /dev/null +++ b/controlcap/common/evaluation/meteor/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' diff --git a/controlcap/common/evaluation/meteor/data/paraphrase-en.gz b/controlcap/common/evaluation/meteor/data/paraphrase-en.gz new file mode 100644 index 0000000..88033c8 Binary files /dev/null and b/controlcap/common/evaluation/meteor/data/paraphrase-en.gz differ diff --git a/controlcap/common/evaluation/meteor/meteor-1.5.jar b/controlcap/common/evaluation/meteor/meteor-1.5.jar new file mode 100644 index 0000000..a833bc0 Binary files /dev/null and b/controlcap/common/evaluation/meteor/meteor-1.5.jar differ diff --git a/controlcap/common/evaluation/meteor/meteor.py b/controlcap/common/evaluation/meteor/meteor.py new file mode 100644 index 0000000..fda124a --- /dev/null +++ b/controlcap/common/evaluation/meteor/meteor.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python + +# Python wrapper for METEOR implementation, by Xinlei Chen +# Acknowledge Michael Denkowski for the generous discussion and help +from __future__ import division + +import atexit +import logging +import os +import re +import subprocess +import sys +import threading + +import psutil + +# Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. +METEOR_JAR = 'meteor-1.5.jar' + + +def enc(s): + return s.encode('utf-8') + + +def dec(s): + return s.decode('utf-8') + + +class Meteor: + + def __init__(self): + # Used to guarantee thread safety + self.lock = threading.Lock() + + mem = '2G' + mem_available_G = psutil.virtual_memory().available / 1E9 + if mem_available_G < 2: + logging.warning("There is less than 2GB of available memory.\n" + "Will try with limiting Meteor to 1GB of memory but this might cause issues.\n" + "If you have problems using Meteor, " + "then you can try to lower the `mem` variable in meteor.py") + mem = '1G' + + meteor_cmd = ['java', '-jar', '-Xmx{}'.format(mem), METEOR_JAR, + '-', '-', '-stdio', '-l', 'en', '-norm'] + env = os.environ.copy() + env['LC_ALL'] = "C" + self.meteor_p = subprocess.Popen(meteor_cmd, + cwd=os.path.dirname(os.path.abspath(__file__)), + env=env, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + atexit.register(self.close) + + def close(self): + with self.lock: + if self.meteor_p: + self.meteor_p.kill() + self.meteor_p.wait() + self.meteor_p = None + # if the user calls close() manually, remove the + # reference from atexit so the object can be garbage-collected. + if atexit is not None and atexit.unregister is not None: + atexit.unregister(self.close) + + def compute_score(self, gts, res): + assert (gts.keys() == res.keys()) + imgIds = gts.keys() + scores = [] + + eval_line = 'EVAL' + with self.lock: + for i in imgIds: + assert (len(res[i]) == 1) + stat = self._stat(res[i][0], gts[i]) + eval_line += ' ||| {}'.format(stat) + + self.meteor_p.stdin.write(enc('{}\n'.format(eval_line))) + self.meteor_p.stdin.flush() + for i in range(0, len(imgIds)): + v = self.meteor_p.stdout.readline() + try: + scores.append(float(dec(v.strip()))) + except: + sys.stderr.write("Error handling value: {}\n".format(v)) + sys.stderr.write("Decoded value: {}\n".format(dec(v.strip()))) + sys.stderr.write("eval_line: {}\n".format(eval_line)) + # You can try uncommenting the next code line to show stderr from the Meteor JAR. + # If the Meteor JAR is not writing to stderr, then the line will just hang. + # sys.stderr.write("Error from Meteor:\n{}".format(self.meteor_p.stderr.read())) + raise + score = float(dec(self.meteor_p.stdout.readline()).strip()) + + return score, scores + + def method(self): + return "METEOR" + + def _stat(self, hypothesis_str, reference_list): + # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words + hypothesis_str = hypothesis_str.replace('|||', '') + score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) + score_line = re.sub(r'\s+', ' ', score_line) + self.meteor_p.stdin.write(enc(score_line)) + self.meteor_p.stdin.write(enc('\n')) + self.meteor_p.stdin.flush() + return dec(self.meteor_p.stdout.readline()).strip() + + def _score(self, hypothesis_str, reference_list): + with self.lock: + # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words + hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') + score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) + self.meteor_p.stdin.write(enc('{}\n'.format(score_line))) + self.meteor_p.stdin.flush() + stats = dec(self.meteor_p.stdout.readline()).strip() + eval_line = 'EVAL ||| {}'.format(stats) + # EVAL ||| stats + self.meteor_p.stdin.write(enc('{}\n'.format(eval_line))) + self.meteor_p.stdin.flush() + score = float(dec(self.meteor_p.stdout.readline()).strip()) + # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice + # thanks for Andrej for pointing this out + score = float(dec(self.meteor_p.stdout.readline()).strip()) + return score + + def __del__(self): + self.close() diff --git a/controlcap/common/evaluation/meteor/tests/test_meteor.py b/controlcap/common/evaluation/meteor/tests/test_meteor.py new file mode 100644 index 0000000..a9965f3 --- /dev/null +++ b/controlcap/common/evaluation/meteor/tests/test_meteor.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +import unittest + +from nlgeval.pycocoevalcap.meteor.meteor import Meteor + + +class TestMeteor(unittest.TestCase): + def test_compute_score(self): + m = Meteor() + + s = m.compute_score({0: ["test"]}, {0: ["test"]}) + self.assertEqual(s, (1.0, [1.0])) + + s = m.compute_score({0: ["テスト"]}, {0: ["テスト"]}) + self.assertEqual(s, (1.0, [1.0])) diff --git a/demo/demo.yaml b/demo/demo.yaml index 7bc5218..0f8486b 100644 --- a/demo/demo.yaml +++ b/demo/demo.yaml @@ -22,4 +22,4 @@ run: batch_size_eval: 64 output_dir: "output/eval/vg1.2" eval_dataset_name: "vg_reg" - load_ckpt_path: "/Workspace/ZhaoYuzhong/ControlCap/output/train/vg_refcocog_5e/20240316102/checkpoint_4.pth" \ No newline at end of file + load_ckpt_path: "ckpts/vg1.2_refcocog_5e.pth"