diff --git a/subsetter/sampler.py b/subsetter/sampler.py index db45393..c0f3532 100644 --- a/subsetter/sampler.py +++ b/subsetter/sampler.py @@ -692,6 +692,7 @@ def __init__(self, source: DatabaseConfig, config: SamplerConfig) -> None: self.compact_columns: Dict[Tuple[str, str], Set[str]] = {} self.temp_tables = TempTableCreator() self.passthrough_tables: Set[str] = set() + self.cached_table_sizes: Dict[Tuple[str, str], int] = {} def sample( self, @@ -866,6 +867,7 @@ def _materialize_tables( primary_key=table.primary_key, ) ) + self.cached_table_sizes[(schema, table_name)] = rowcount LOGGER.info( "Materialized %d rows for %s.%s in temporary table", rowcount, @@ -1005,9 +1007,9 @@ def _copy_results( rows = 0 - def _count_rows(result): + def _count_rows(result, total: Optional[int]): nonlocal rows - for row in tqdm(result, desc="row progress", unit="rows"): + for row in tqdm(result, total=total, desc="row progress", unit="rows"): # result_processor rows += 1 yield row @@ -1017,7 +1019,7 @@ def _count_rows(result): schema, table_name, columns, - _count_rows(result), + _count_rows(result, self.cached_table_sizes.get((schema, table_name))), filter_view=filter_view, multiplier=( 1