diff --git a/setup.cfg b/setup.cfg index 254cdd2c..039a6e7c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = pytest-testmon -version = 1.3.6 +version = 1.3.7 license = AGPL author_email = tibor.arpas@infinit.sk author = Tibor Arpas, Tomas Matlovic, Daniel Hahler, Martin Racak @@ -46,4 +46,5 @@ pytest11 = testmon = testmon.pytest_testmon tox = testmon = testmon.tox_testmon - +console_scripts = + testmon-merge-db = testmon:merge_db diff --git a/testmon/__init__.py b/testmon/__init__.py index e69de29b..d44dc1f8 100644 --- a/testmon/__init__.py +++ b/testmon/__init__.py @@ -0,0 +1,18 @@ + +def merge_db(): + import argparse + from testmon.db import DB, merge_dbs + + parser = argparse.ArgumentParser() + parser.add_argument('dbs', metavar='N', type=str, nargs='+') + parser.add_argument('--output', metavar='N', type=str, nargs='?', default="merged") + parser.add_argument('--environment', metavar='N', type=str, nargs='?', default="default") + + args = parser.parse_args() + databases = args.dbs + output_db = args.output + env = args.environment + + db_1 = DB(datafile=databases[0], environment=env) + db_2 = DB(datafile=databases[1], environment=env) + merge_dbs(merged_datafile=output_db, db_1=db_1, db_2=db_2) diff --git a/testmon/db.py b/testmon/db.py index 8020cba8..cc0c83b2 100644 --- a/testmon/db.py +++ b/testmon/db.py @@ -3,6 +3,8 @@ import sqlite3 from collections import namedtuple +from sqlite3 import Binary +from typing import List, Optional from testmon.process_code import ( blob_to_checksums, @@ -22,7 +24,7 @@ class TestmonDbException(Exception): pass -def connect(datafile): +def connect(datafile: os.PathLike): connection = sqlite3.connect(datafile) connection.execute("PRAGMA synchronous = OFF") @@ -31,9 +33,31 @@ def connect(datafile): connection.row_factory = sqlite3.Row return connection +def merge_dbs(merged_datafile, db_1: "DB", db_2: "DB") -> "DB": + if db_1.env != db_2.env: + raise + + memory_db = DB(":memory:", environment=db_1.env) + file_db = DB(merged_datafile, environment=db_1.env) + with memory_db: + with db_1: + for data in db_1.all_data(): + memory_db.insert_node_fingerprints(data["name"], fingerprints=[data], failed=data["failed"], + duration=data["duration"]) + + with db_2: + for data in db_2.all_data(): + memory_db.insert_node_fingerprints(data["name"], fingerprints=[data], failed=data["failed"], + duration=data["duration"]) + + with file_db: + memory_db.con.backup(file_db.con) + + return file_db + class DB: - def __init__(self, datafile, environment="default"): + def __init__(self, datafile: os.PathLike, environment: str = "default"): new_db = not os.path.exists(datafile) connection = connect(datafile) @@ -45,7 +69,7 @@ def __init__(self, datafile, environment="default"): if new_db or old_format: self.init_tables() - def _check_data_version(self, datafile): + def _check_data_version(self, datafile: os.PathLike) -> bool: stored_data_version = self._fetch_data_version() if int(stored_data_version) == DATA_VERSION: @@ -56,14 +80,14 @@ def _check_data_version(self, datafile): self.con = connect(datafile) return True - def __enter__(self): + def __enter__(self) -> "DB": self.con = self.con.__enter__() return self def __exit__(self, *args, **kwargs): self.con.__exit__(*args, **kwargs) - def update_mtimes(self, new_mtimes): + def update_mtimes(self, new_mtimes: float): with self.con as con: con.executemany( "UPDATE fingerprint SET mtime=?, checksum=? WHERE id = ?", new_mtimes @@ -80,7 +104,7 @@ def remove_unused_fingerprints(self): """ ) - def fetch_or_create_fingerprint(self, filename, mtime, checksum, method_checksums): + def fetch_or_create_fingerprint(self, filename: str, mtime: float, checksum: str, method_checksums: Binary) -> int: cursor = self.con.cursor() try: cursor.execute( @@ -113,7 +137,7 @@ def fetch_or_create_fingerprint(self, filename, mtime, checksum, method_checksum return fingerprint_id def insert_node_fingerprints( - self, nodeid, fingerprints, failed=False, duration=None + self, nodeid: str, fingerprints: Fingerprints, failed: bool = False, duration: Optional[float] = None ): with self.con as con: cursor = con.cursor() @@ -151,7 +175,7 @@ def _fetch_data_version(self): return con.execute("PRAGMA user_version").fetchone()[0] - def _write_attribute(self, attribute, data, environment=None): + def _write_attribute(self, attribute: str, data: dict, environment: Optional[str] = None): dataid = (environment or self.env) + ":" + attribute with self.con as con: con.execute( @@ -159,7 +183,7 @@ def _write_attribute(self, attribute, data, environment=None): [dataid, json.dumps(data)], ) - def _fetch_attribute(self, attribute, default=None, environment=None): + def _fetch_attribute(self, attribute: str, default=None, environment=None): cursor = self.con.execute( "SELECT data FROM metadata WHERE dataid=?", [(environment or self.env) + ":" + attribute], @@ -214,7 +238,7 @@ def init_tables(self): connection.execute(f"PRAGMA user_version = {DATA_VERSION}") - def get_changed_file_data(self, changed_fingerprints): + def get_changed_file_data(self, changed_fingerprints: Fingerprints): in_clause_questionsmarks = ", ".join("?" * len(changed_fingerprints)) result = [] for row in self.con.execute( @@ -300,3 +324,25 @@ def filenames_fingerprints(self): ) return [dict(row) for row in cursor] + + def all_data(self) -> List[dict]: + cursor = self.con.execute( + """ + SELECT + n.name, + n.duration, + n.failed, + f.filename, + f.method_checksums, + f.mtime, + f.checksum + FROM + node n, node_fingerprint nfp, fingerprint f + WHERE + n.id = nfp.node_id AND + nfp.fingerprint_id = f.id AND + environment = ? + """, + (self.env,), + ) + return [dict(row) for row in cursor]