diff --git a/CHANGELOG.md b/CHANGELOG.md index 594ded2..23b3a00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,29 @@ + # CHANGELOG -This changelog file outlines a chronologically ordered list of the changes made on this project. + +This changelog file outlines a chronologically ordered list of the changes made to this project. It is organized by version and release date, followed by a list of Enhancements, New Features, Bug Fixes, and/or Breaking Changes. -

+## Version 3.1.1 + +**Released:** December 14, 2025 +**Tag:** v3.1.1 + +### Bug Fixes -## Version 3.1.0 (Latest) -**Released:** April 22, 2025
+- Fixed an IndexError in the wer, wers, werp, werps, summary, and summaryp functions when processing single-string inputs. The issue was caused by incorrect handling of np.vectorize return values, which have different array structures for single versus batch inputs (a 0-dimensional array for a single string, a 1-dimensional object array for a list). The fix unwraps 0-dimensional arrays and uses element-type checking to distinguish single-example vectors from batch results. + +- Improved handling of ragged data structures in batch processing by summing the relevant fields directly instead of transposing the results. This prevents errors when the word error rate breakdown contains a mix of scalar and list data. + +### Enhancements + +- Added benchmarking infrastructure to compare performance against the werx and jiwer packages on the LibriSpeech evaluation dataset. + +- Added optional dependency groups to pyproject.toml for testing and benchmarking workflows, enabling easier development environment setup. + +## Version 3.1.0 + +**Released:** April 22, 2025 **Tag:** v3.1.0 ### Enhancements @@ -24,11 +42,9 @@ It is organized by version and release date followed by a list of Enhancements, - Loop indices and size variables now use Py_ssize_t, matching Python's internal conventions. - Grouped and explicitly typed intermediate variables like inserted_words, deleted_words, and substituted_words for improved readability and static checks. This enhances code quality, reduces reliance on dynamic typing in performance-critical paths, and prepares the function for future optimizations. +## Version 3.0.2 -
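To make the single-versus-batch distinction in the bug-fix notes above concrete, here is a minimal, self-contained sketch. The `breakdown` function is a toy stand-in for werpy's vectorized metric (the real breakdown comes from the compiled `calculations` routine), and `otypes=[object]` is an assumption about how it is vectorized; only the container shapes and the unwrap/type-check dispatch mirror the fix.

```python
import numpy as np

# Toy stand-in for werpy's per-example metric (NOT the real implementation): it returns
# one breakdown vector [wer, ld, m, insertions, deletions, substitutions] per pair.
def breakdown(reference, hypothesis):
    return [0.25, 1.0, 4.0, 0.0, 0.0, 1.0]

vectorized = np.vectorize(breakdown, otypes=[object])

def classify(b):
    """The dispatch pattern the fix uses: unwrap a 0-d container, then check element types."""
    if isinstance(b, np.ndarray) and b.ndim == 0:
        b = b.item()                                   # 0-d object array -> the single breakdown vector
    if isinstance(b, np.ndarray) and b.ndim >= 1 and isinstance(b[0], (list, np.ndarray)):
        return "batch", [list(r) for r in b]           # one breakdown vector per example
    return "single", list(b)                           # a lone breakdown vector

print(classify(vectorized("the cat sat here", "the cat sat there")))
# ('single', [0.25, 1.0, 4.0, 0.0, 0.0, 1.0])
print(classify(vectorized(["a b c d", "e f g h"], ["a b c x", "e f g h"])))
# ('batch', [[0.25, 1.0, 4.0, 0.0, 0.0, 1.0], [0.25, 1.0, 4.0, 0.0, 0.0, 1.0]])
```

Whichever exact shape np.vectorize hands back (a 0-dimensional wrapper for one pair, or a 1-dimensional object array for a list of pairs), `classify` lands in the correct branch, which is the property the fixed modules rely on.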
- -## Version 3.0.2 -**Released:** April 3, 2025
+**Released:** April 3, 2025 **Tag:** v3.0.2 ### Enhancements @@ -54,12 +70,9 @@ It is organized by version and release date followed by a list of Enhancements, - Initial support for static type checking: included py.typed marker file to enable type checkers to recognize the package as typed. Note: full type coverage is not yet guaranteed and will be improved incrementally. +## Version 3.0.1 -
- - -## Version 3.0.1 -**Released:** March 26, 2025
+**Released:** March 26, 2025 **Tag:** v3.0.1 ### Enhancements @@ -71,13 +84,9 @@ It is organized by version and release date followed by a list of Enhancements, - Publishing process is now automated using GitHub Actions and PyPI Trusted Publishing +## Version 3.0.0 - -
- - -## Version 3.0.0 -**Released:** March 20, 2025
+**Released:** March 20, 2025 **Tag:** v3.0.0 ### Breaking Changes @@ -94,7 +103,6 @@ It is organized by version and release date followed by a list of Enhancements, - NumPy 2.x is required for Python 3.12+, which may introduce API changes - ### Enhancements - Official support for Python 3.13: @@ -127,12 +135,9 @@ It is organized by version and release date followed by a list of Enhancements, - Updated Cython Function Type Annotation - Changed cpdef np.ndarray calculations(...) to cpdef cnp.ndarray calculations(...) to properly reference the Cython-level NumPy API, ensuring type safety and compatibility with compiled C extensions. - -
- - ## Version 2.1.3-beta -**Released:** Not released
+ +**Released:** Not released **Tag:** v2.1.3-beta ### Enhancements @@ -147,11 +152,9 @@ All the following changes were incorporated into the major v3.0.0 rollout - Bump sphinx-nefertiti from 0.3.2 to 0.3.4 ([#12](https://github.com/analyticsinmotion/werpy/pull/12)) - Bump certifi from 2023.11.17 to 2024.7.4 in /docs ([#13](https://github.com/analyticsinmotion/werpy/pull/13)) -
- - ## Version 2.1.2 -**Released:** April 5, 2024
+ +**Released:** April 5, 2024 **Tag:** v2.1.2 ### Enhancements @@ -170,7 +173,6 @@ All the following changes were incorporated into the major v3.0.0 rollout - Ensured compliance with the Black code formatting by modifying relevant files. - ### Changed - Bump jinja2 from 3.1.2 to 3.1.3 in /docs. ([#1](https://github.com/analyticsinmotion/werpy/pull/1)) @@ -180,11 +182,9 @@ All the following changes were incorporated into the major v3.0.0 rollout - Bump sphinx-nefertiti from 0.2.3 to 0.3.1 ([#3](https://github.com/analyticsinmotion/werpy/pull/3)) - -

- ## Version 2.1.1 -**Released:** November 27, 2023
+ +**Released:** November 27, 2023 **Tag:** v2.1.1 ### Enhancements @@ -194,22 +194,18 @@ All the following changes were incorporated into the major v3.0.0 rollout - Passed the Cython source code directly to the py.extension_module() definition for improved integration. - Specified the C standard configuration as C11, instructing Meson to use C11 as the designated C standard. - -

- ## Version 2.1.0 -**Released:** November 23, 2023
+ +**Released:** November 23, 2023 **Tag:** v2.1.0 ### New Feature - Enhanced cross-platform support by integrating cibuildwheel, enabling compatibility with macOS and popular Linux distributions. With existing Windows compatibility, the package now spans all major configurations. Feel free to reach out if you have a specific OS configuration you'd like to discuss for potential inclusion. +## Version 2.0.0 -

- -## Version 2.0.0 -**Released:** November 23, 2023
+**Released:** November 23, 2023 **Tag:** v2.0.0 ### New Feature @@ -218,18 +214,15 @@ All the following changes were incorporated into the major v3.0.0 rollout - During the transition to utilizing C optimizations, we opted to switch our Python package build system from Hatchling to Mesonpy. Mesonpy facilitates seamless compilation of C code as an integral part of the package build process. As a result of this transition, you can expect modifications to the pyproject.toml file and the introduction of a new meson.build file. This change under the hood enables us to integrate both Python and C code within the package natively for the performance optimizations. - ### Breaking Changes -- In this significant application update, we are introducing phased support for different operating systems. Initially, this version will exclusively support Windows. However, swift additions for UNIX/Linux and macOS compatibility are already in the pipeline and will be incorporated promptly. This temporary change allows us to roll out the major version upgrade incrementally while ensuring reliability for our user base. +- In this significant application update, we are introducing phased support for different operating systems. Initially, this version will exclusively support Windows. However, swift additions for UNIX/Linux and macOS compatibility are already in the pipeline and will be incorporated promptly. This temporary change allows us to roll out the major version upgrade incrementally while ensuring reliability for our user base. - Certain web applications relying exclusively on pure Python environments might encounter challenges running this package successfully. If your applications are affected, please don't hesitate to get in touch to discuss potential compatibility issues. - - -

## Version 1.1.2 -**Released:** November 17, 2023
+ +**Released:** November 17, 2023 **Tag:** v1.1.2 ### New Feature @@ -238,18 +231,15 @@ All the following changes were incorporated into the major v3.0.0 rollout - Added corresponding new tests for the 'summaryp' function within the 'werpy' package, enhancing test coverage and ensuring robust functionality. The additional tests provide comprehensive validation of the 'summaryp' function, contributing to improved reliability and accuracy in the package's performance. - ### Bug Fixes - Fixed an AttributeError in the 'summary.py' module that occurred when attempting to access the 'size' attribute of a 'float' object. This error happened when the module was provided a single reference and hypothesis string as input. The issue has been resolved. - Fixed an issue with the attributes of DataFrame column name "ld". Resolved the discrepancy in the "dtype" attribute from int32 to int64. - -

+## Version 1.1.1 -## Version 1.1.1 -**Released:** November 13, 2023
+**Released:** November 13, 2023 **Tag:** v1.1.1 ### Enhancements @@ -258,18 +248,15 @@ All the following changes were incorporated into the major v3.0.0 rollout - Added the CircleCI badge to the repository Readme.md file - ### Bug Fixes - Fixed an AttributeError in the wers.py module caused by an non-standard operation on a 'float' object. This only occurred when a single reference and hypothesis input string was entered and has now been rectified. - Resolved an AttributeError in the 'werps.py' module, which was triggered by a 'float' object having a size attribute. This issue specifically arose when a single reference and hypothesis input string was provided, and it has been successfully addressed. - - -

## Version 1.1.0 -**Released:** November 8, 2023
+ +**Released:** November 8, 2023 **Tag:** v1.1.0 ### Enhancements @@ -278,15 +265,14 @@ All the following changes were incorporated into the major v3.0.0 rollout - Added the following unit tests to improve code coverage and validation for the functions in the werpy module. The new tests cover additional use cases with longer input sequences and help ensure the wer calculation works properly in different scenarios. - Added new unit tests for the wer module. - - Added new unit tests for the wers module. - - Added new unit tests for the werp module. - - Added new unit tests for the werps module. + - Added new unit tests for the wers module. + - Added new unit tests for the werp module. + - Added new unit tests for the werps module. - Added new unit tests for the summary module. -

- ## Version 1.0.0 -**Released:** November 2, 2023
+ +**Released:** November 2, 2023 **Tag:** v1.0.0 ### Enhancements @@ -295,10 +281,9 @@ All the following changes were incorporated into the major v3.0.0 rollout - Added new unit tests for the normalize module. These tests focus on improving test coverage, enhancing the reliability of the module, and ensuring the accuracy of the normalization process. By incorporating these tests, we aim to identify and address issues early in the development cycle, making the upcoming release more stable and reliable. -

- ## Version 0.0.5 -**Released:** October 26, 2023
+ +**Released:** October 26, 2023 **Tag:** v0.0.5 ### Enhancements @@ -310,50 +295,44 @@ All the following changes were incorporated into the major v3.0.0 rollout - Added a new method to the "normalization" module: - `remove_whitespace(text)`: This new method efficiently removes all excess spaces in the input text. It replaces consecutive sequences of spaces with a single space and removes any leading or trailing spaces, ensuring a cleaner and more consistent text output. -

+## Version 0.0.4 -## Version 0.0.4 -**Released:** May 4, 2023
+**Released:** May 4, 2023 **Tag:** v0.0.4 ### Enhancements - The code to handle exceptions and errors has been refactored to reduce code duplication across modules. In addition, the changes will make adding and testing errors or exceptions easier to maintain in the future. - ### Bug Fix - Fixed a number of inconsistent return statements (R1710) within the package modules. This ensures that all functions will return a consistent expression when called. -

- ## Version 0.0.3 -**Released:** May 2, 2023
+ +**Released:** May 2, 2023 **Tag:** v0.0.3 ### Bug Fix - Fixed a bug contained within the modules that was causing a Cyclic Import issue (R0401). One of the import statements was missing a period at the start of the module name. The fix has been tested and deployed successfully. -

+## Version 0.0.2 -## Version 0.0.2 -**Released:** May 1, 2023
+**Released:** May 1, 2023 **Tag:** v0.0.2 ### General Changes - Added Module Docstrings - ### Bug Fix - Fixed an unidiomatic-typecheck (C0123) from type() to isinstance(). The idiomatic way to perform an explicit typecheck in Python is to use isinstance(x, y) rather than type(x) == Y. -

- ## Version 0.0.1 (Initial Release) -**Released:** April 28, 2023
+ +**Released:** April 28, 2023 **Tag:** v0.0.1 This is the initial release diff --git a/benchmarks/speed_comparison_librispeech_full.py b/benchmarks/speed_comparison_librispeech_full.py new file mode 100644 index 0000000..e6978ee --- /dev/null +++ b/benchmarks/speed_comparison_librispeech_full.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: 2023 Analytics in Motion +# SPDX-License-Identifier: BSD-3-Clause + +""" +Speed comparison benchmark for WER calculation packages. + +This script compares the performance of werx, werpy, and jiwer on the +LibriSpeech evaluation dataset using timeit for accurate timing measurements. +""" + +from datasets import load_dataset +import werpy +import werx +import jiwer +import timeit + +# Load the consolidated CSV from the Hugging Face Hub +dataset = load_dataset( + "analyticsinmotion/librispeech-eval", + data_files="all_splits.csv", + split="train" +) + +# Specify which splits and model/version to evaluate +splits = ["test-clean", "test-other"] +model_name = "whisper-base" +model_version = "v20240930" + +for split in splits: + print(f"\n{'='*70}") + print(f"Benchmarking: {split}") + print(f"{'='*70}\n") + + # Filter references and hypotheses for the chosen split/model/version + filtered = dataset.filter( + lambda x: x["split"] == split and + x["model_name"] == model_name and + x["model_version"] == model_version + ) + + filtered = list(filtered) + references = [werpy.normalize(row["reference"]) for row in filtered] + hypotheses = [werpy.normalize(row["hypothesis"]) for row in filtered] + + print(f"Loaded {len(references):,} utterances\n") + + # --- WER tools --- + tools = { + "WERX": werx.wer, + "WERPY": werpy.wer, + "JIWER": jiwer.wer, + } + + # --- Run + time each tool using timeit --- + results = [] + n_repeats = 10 # Number of repeats for timeit + + for name, func in tools.items(): + def stmt(): + return func(references, hypotheses) + total_time = timeit.timeit(stmt, number=n_repeats) + avg_time = total_time / n_repeats + wer = func(references, hypotheses) + results.append((name, wer, avg_time)) + + # --- Normalize by fastest average time --- + min_time = min(r[2] for r in results) + normalized_results = [ + (name, wer, t, t / min_time) for name, wer, t in results + ] + + # --- Print CLI-friendly table --- + print("\n Word Error Rate Benchmark:\n") + print(f"{'Tool':<15} {'WER':<8} {'WER (%)':<10} {'Time (s)':<12} {'Norm Time':<18}") + print("-" * 70) + for name, wer, t, norm in normalized_results: + if name == "WERX": + norm_str = "1.00× (baseline)" + else: + norm_str = f"{norm:.2f}× slower" + print(f"{name:<15} {wer:.4f} {wer*100:6.2f}% {t:.6f} {norm_str:<18}") diff --git a/docs/source/conf.py b/docs/source/conf.py index f10543a..f95a820 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,7 @@ project = "werpy" copyright = f'{datetime.now().year} Analytics in Motion' author = "Ross Armstrong" -release = "3.1.0" +release = "3.1.1" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/meson.build b/meson.build index 2328d34..f9033f3 100644 --- a/meson.build +++ b/meson.build @@ -1,7 +1,7 @@ project( 'werpy', 'c', 'cython', - version : '3.1.0', + version : '3.1.1', license: 'BSD-3', meson_version: '>= 1.1.0', default_options : [ diff --git a/pyproject.toml b/pyproject.toml index cf1b2e5..2d558dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ requires = [ [project] name = 'werpy' -version = 
'3.1.0' +version = '3.1.1' description = 'A powerful yet lightweight Python package to calculate and analyze the Word Error Rate (WER).' readme = 'README.md' requires-python = '>=3.10' @@ -67,3 +67,12 @@ docs = [ "sphinx==8.2.3", "sphinx-nefertiti==0.9.1", ] +test = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", +] +benchmarks = [ + "datasets>=4.4.1", + "werx>=0.3.1", + "jiwer>=4.0.0", +] \ No newline at end of file diff --git a/werpy/summary.py b/werpy/summary.py index 676b593..eae8d14 100644 --- a/werpy/summary.py +++ b/werpy/summary.py @@ -55,10 +55,41 @@ def summary(reference, hypothesis) -> pd.DataFrame | None: except (ValueError, AttributeError, ZeroDivisionError) as err: print(f"{type(err).__name__}: {str(err)}") return None - if isinstance(word_error_rate_breakdown[0], np.ndarray): - word_error_rate_breakdown = word_error_rate_breakdown.tolist() + + b = word_error_rate_breakdown + + # Unwrap 0-D container + if isinstance(b, np.ndarray) and b.ndim == 0: + b = b.item() + + if isinstance(b, np.ndarray): + if b.ndim == 2: + # True 2-D numeric batch + word_error_rate_breakdown = b.tolist() + + elif b.ndim == 1: + # Could be either: + # (a) single example row vector, or + # (b) object array of per-example vectors + first = b[0] if b.size else None + + if isinstance(first, (np.ndarray, list, tuple)): + # Batch stored as 1-D object array of per-example vectors (ragged fields exist) + word_error_rate_breakdown = [] + for r in b: + rr = r.tolist() if isinstance(r, np.ndarray) else r + word_error_rate_breakdown.append(rr) + else: + # Single example vector - wrap in list for DataFrame + word_error_rate_breakdown = [b.tolist()] + + else: + raise ValueError(f"Unexpected metrics output ndim: {b.ndim}") + else: - word_error_rate_breakdown = [word_error_rate_breakdown.tolist()] + # Non-numpy fallback (assume [wer, ld, m, ...]) + word_error_rate_breakdown = [b.tolist() if hasattr(b, 'tolist') else b] + columns = [ "wer", "ld", diff --git a/werpy/summaryp.py b/werpy/summaryp.py index 5421fde..9c087fd 100644 --- a/werpy/summaryp.py +++ b/werpy/summaryp.py @@ -69,29 +69,66 @@ def summaryp( except (ValueError, AttributeError, ZeroDivisionError) as err: print(f"{type(err).__name__}: {str(err)}") return None - if isinstance(word_error_rate_breakdown[0], np.ndarray): - word_error_rate_breakdown = word_error_rate_breakdown.tolist() - transform_word_error_rate_breakdown = np.transpose(word_error_rate_breakdown) - weighted_insertions = transform_word_error_rate_breakdown[3] * insertions_weight - weighted_deletions = transform_word_error_rate_breakdown[4] * deletions_weight - weighted_substitutions = ( - transform_word_error_rate_breakdown[5] * substitutions_weight - ) - m = transform_word_error_rate_breakdown[2] - weighted_errors = sum( - (weighted_insertions, weighted_deletions, weighted_substitutions) - ) - werps_result = (weighted_errors / m).tolist() + + b = word_error_rate_breakdown + + # Unwrap 0-D container + if isinstance(b, np.ndarray) and b.ndim == 0: + b = b.item() + + if isinstance(b, np.ndarray): + if b.ndim == 2: + # True 2-D numeric batch + word_error_rate_breakdown = b.tolist() + t = b.T + weighted_insertions = t[3] * insertions_weight + weighted_deletions = t[4] * deletions_weight + weighted_substitutions = t[5] * substitutions_weight + m = t[2] + weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions + werps_result = (weighted_errors / m).tolist() + + elif b.ndim == 1: + # Could be either: + # (a) single example row vector, or + # (b) object array of per-example 
vectors + first = b[0] if b.size else None + + if isinstance(first, (np.ndarray, list, tuple)): + # Batch stored as 1-D object array of per-example vectors (ragged fields exist) + word_error_rate_breakdown = [] + werps_result = [] + for r in b: + rr = r.tolist() if isinstance(r, np.ndarray) else r + word_error_rate_breakdown.append(rr) + w_ins = float(rr[3]) * insertions_weight + w_del = float(rr[4]) * deletions_weight + w_sub = float(rr[5]) * substitutions_weight + m_val = float(rr[2]) + weighted_wer = (w_ins + w_del + w_sub) / m_val if m_val else 0.0 + werps_result.append(weighted_wer) + else: + # Single example vector - wrap in list for DataFrame + word_error_rate_breakdown = [b.tolist()] + weighted_insertions = b[3] * insertions_weight + weighted_deletions = b[4] * deletions_weight + weighted_substitutions = b[5] * substitutions_weight + m = b[2] + weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions + werps_result = float(weighted_errors / m) if m else 0.0 + + else: + raise ValueError(f"Unexpected metrics output ndim: {b.ndim}") + else: - word_error_rate_breakdown = [word_error_rate_breakdown.tolist()] - weighted_insertions = word_error_rate_breakdown[0][3] * insertions_weight - weighted_deletions = word_error_rate_breakdown[0][4] * deletions_weight - weighted_substitutions = word_error_rate_breakdown[0][5] * substitutions_weight - m = word_error_rate_breakdown[0][2] - weighted_errors = sum( - (weighted_insertions, weighted_deletions, weighted_substitutions) - ) - werps_result = weighted_errors / m + # Non-numpy fallback (assume [wer, ld, m, ...]) + word_error_rate_breakdown = [b.tolist() if hasattr(b, 'tolist') else b] + weighted_insertions = b[3] * insertions_weight + weighted_deletions = b[4] * deletions_weight + weighted_substitutions = b[5] * substitutions_weight + m = b[2] + weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions + werps_result = float(weighted_errors / m) if m else 0.0 columns = [ "wer", diff --git a/werpy/wer.py b/werpy/wer.py index ff7a75e..2d52bc8 100644 --- a/werpy/wer.py +++ b/werpy/wer.py @@ -59,13 +59,43 @@ def wer(reference, hypothesis) -> float | np.float64 | None: except (ValueError, AttributeError, ZeroDivisionError) as err: print(f"{type(err).__name__}: {str(err)}") return None - if isinstance(word_error_rate_breakdown[0], np.ndarray): - transform_word_error_rate_breakdown = np.transpose( - word_error_rate_breakdown.tolist() - ) - wer_result = (np.sum(transform_word_error_rate_breakdown[1])) / ( - np.sum(transform_word_error_rate_breakdown[2]) - ) + + b = word_error_rate_breakdown + + # Unwrap 0-D container + if isinstance(b, np.ndarray) and b.ndim == 0: + b = b.item() + + if isinstance(b, np.ndarray): + if b.ndim == 2: + # True 2-D numeric batch + t = b.T + wer_result = float(np.sum(t[1]) / np.sum(t[2])) + + elif b.ndim == 1: + # Could be either: + # (a) single example row vector, or + # (b) object array of per-example vectors + first = b[0] if b.size else None + + if isinstance(first, (np.ndarray, list, tuple)): + # Batch stored as 1-D object array of per-example vectors (ragged fields exist) + total_ld = 0.0 + total_m = 0.0 + for r in b: + rr = r.tolist() if isinstance(r, np.ndarray) else r + total_ld += float(rr[1]) + total_m += float(rr[2]) + wer_result = float(total_ld / total_m) if total_m else 0.0 + else: + # Single example vector + wer_result = float(b[0]) + + else: + raise ValueError(f"Unexpected metrics output ndim: {b.ndim}") + else: - wer_result = word_error_rate_breakdown[0] + 
# Non-numpy fallback (assume [wer, ld, m, ...]) + wer_result = float(b[0]) + return wer_result diff --git a/werpy/werp.py b/werpy/werp.py index d33ebaf..e5cee5b 100644 --- a/werpy/werp.py +++ b/werpy/werp.py @@ -77,23 +77,61 @@ def werp( except (ValueError, AttributeError, ZeroDivisionError) as err: print(f"{type(err).__name__}: {str(err)}") return None - if isinstance(word_error_rate_breakdown[0], np.ndarray): - transform_word_error_rate_breakdown = np.transpose( - word_error_rate_breakdown.tolist() - ) - weighted_insertions = transform_word_error_rate_breakdown[3] * insertions_weight - weighted_deletions = transform_word_error_rate_breakdown[4] * deletions_weight - weighted_substitutions = ( - transform_word_error_rate_breakdown[5] * substitutions_weight - ) - m = np.sum(transform_word_error_rate_breakdown[2]) + + b = word_error_rate_breakdown + + # Unwrap 0-D container + if isinstance(b, np.ndarray) and b.ndim == 0: + b = b.item() + + if isinstance(b, np.ndarray): + if b.ndim == 2: + # True 2-D numeric batch + t = b.T + weighted_insertions = np.sum(t[3]) * insertions_weight + weighted_deletions = np.sum(t[4]) * deletions_weight + weighted_substitutions = np.sum(t[5]) * substitutions_weight + m = np.sum(t[2]) + + elif b.ndim == 1: + # Could be either: + # (a) single example row vector, or + # (b) object array of per-example vectors + first = b[0] if b.size else None + + if isinstance(first, (np.ndarray, list, tuple)): + # Batch stored as 1-D object array of per-example vectors (ragged fields exist) + total_insertions = 0.0 + total_deletions = 0.0 + total_substitutions = 0.0 + total_m = 0.0 + for r in b: + rr = r.tolist() if isinstance(r, np.ndarray) else r + total_insertions += float(rr[3]) + total_deletions += float(rr[4]) + total_substitutions += float(rr[5]) + total_m += float(rr[2]) + weighted_insertions = total_insertions * insertions_weight + weighted_deletions = total_deletions * deletions_weight + weighted_substitutions = total_substitutions * substitutions_weight + m = total_m + else: + # Single example vector + weighted_insertions = b[3] * insertions_weight + weighted_deletions = b[4] * deletions_weight + weighted_substitutions = b[5] * substitutions_weight + m = b[2] + + else: + raise ValueError(f"Unexpected metrics output ndim: {b.ndim}") + else: - weighted_insertions = word_error_rate_breakdown[3] * insertions_weight - weighted_deletions = word_error_rate_breakdown[4] * deletions_weight - weighted_substitutions = word_error_rate_breakdown[5] * substitutions_weight - m = np.sum(word_error_rate_breakdown[2]) - weighted_errors = np.sum( - [weighted_insertions, weighted_deletions, weighted_substitutions] - ) - werp_result = weighted_errors / m + # Non-numpy fallback (assume [wer, ld, m, ...]) + weighted_insertions = b[3] * insertions_weight + weighted_deletions = b[4] * deletions_weight + weighted_substitutions = b[5] * substitutions_weight + m = b[2] + + weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions + werp_result = float(weighted_errors / m) if m else 0.0 return werp_result diff --git a/werpy/werps.py b/werpy/werps.py index 145704a..4714af9 100644 --- a/werpy/werps.py +++ b/werpy/werps.py @@ -71,28 +71,60 @@ def werps( except (ValueError, AttributeError, ZeroDivisionError) as err: print(f"{type(err).__name__}: {str(err)}") return None - if isinstance(word_error_rate_breakdown[0], np.ndarray): - transform_word_error_rate_breakdown = np.transpose( - word_error_rate_breakdown.tolist() - ) - weighted_insertions = 
transform_word_error_rate_breakdown[3] * insertions_weight - weighted_deletions = transform_word_error_rate_breakdown[4] * deletions_weight - weighted_substitutions = ( - transform_word_error_rate_breakdown[5] * substitutions_weight - ) - m = transform_word_error_rate_breakdown[2] - else: - weighted_insertions = word_error_rate_breakdown[3] * insertions_weight - weighted_deletions = word_error_rate_breakdown[4] * deletions_weight - weighted_substitutions = word_error_rate_breakdown[5] * substitutions_weight - m = word_error_rate_breakdown[2] - weighted_errors = sum( - (weighted_insertions, weighted_deletions, weighted_substitutions) - ) - werps_result = weighted_errors / m + b = word_error_rate_breakdown + + # Unwrap 0-D container + if isinstance(b, np.ndarray) and b.ndim == 0: + b = b.item() + + if isinstance(b, np.ndarray): + if b.ndim == 2: + # True 2-D numeric batch + t = b.T + weighted_insertions = t[3] * insertions_weight + weighted_deletions = t[4] * deletions_weight + weighted_substitutions = t[5] * substitutions_weight + m = t[2] + weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions + werps_result = (weighted_errors / m).tolist() + + elif b.ndim == 1: + # Could be either: + # (a) single example row vector, or + # (b) object array of per-example vectors + first = b[0] if b.size else None - if isinstance(word_error_rate_breakdown[0], float): - return werps_result + if isinstance(first, (np.ndarray, list, tuple)): + # Batch stored as 1-D object array of per-example vectors (ragged fields exist) + werps_result = [] + for r in b: + rr = r.tolist() if isinstance(r, np.ndarray) else r + w_ins = float(rr[3]) * insertions_weight + w_del = float(rr[4]) * deletions_weight + w_sub = float(rr[5]) * substitutions_weight + m_val = float(rr[2]) + weighted_wer = (w_ins + w_del + w_sub) / m_val if m_val else 0.0 + werps_result.append(weighted_wer) + else: + # Single example vector + weighted_insertions = b[3] * insertions_weight + weighted_deletions = b[4] * deletions_weight + weighted_substitutions = b[5] * substitutions_weight + m = b[2] + weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions + werps_result = float(weighted_errors / m) if m else 0.0 + + else: + raise ValueError(f"Unexpected metrics output ndim: {b.ndim}") + + else: + # Non-numpy fallback (assume [wer, ld, m, ...]) + weighted_insertions = b[3] * insertions_weight + weighted_deletions = b[4] * deletions_weight + weighted_substitutions = b[5] * substitutions_weight + m = b[2] + weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions + werps_result = float(weighted_errors / m) if m else 0.0 - return werps_result.tolist() + return werps_result diff --git a/werpy/wers.py b/werpy/wers.py index a7dc97a..c52a986 100644 --- a/werpy/wers.py +++ b/werpy/wers.py @@ -52,11 +52,40 @@ def wers(reference, hypothesis): except (ValueError, AttributeError, ZeroDivisionError) as err: print(f"{type(err).__name__}: {str(err)}") return None - if isinstance(word_error_rate_breakdown[0], np.ndarray): - transform_word_error_rate_breakdown = np.transpose( - word_error_rate_breakdown.tolist() - ) - wers_result = transform_word_error_rate_breakdown[0].tolist() + + b = word_error_rate_breakdown + + # Unwrap 0-D container + if isinstance(b, np.ndarray) and b.ndim == 0: + b = b.item() + + if isinstance(b, np.ndarray): + if b.ndim == 2: + # True 2-D numeric batch + t = b.T + wers_result = t[0].tolist() + + elif b.ndim == 1: + # Could be either: + # (a) single example row 
vector, or + # (b) object array of per-example vectors + first = b[0] if b.size else None + + if isinstance(first, (np.ndarray, list, tuple)): + # Batch stored as 1-D object array of per-example vectors (ragged fields exist) + wers_result = [] + for r in b: + rr = r.tolist() if isinstance(r, np.ndarray) else r + wers_result.append(float(rr[0])) + else: + # Single example vector + wers_result = float(b[0]) + + else: + raise ValueError(f"Unexpected metrics output ndim: {b.ndim}") + else: - wers_result = word_error_rate_breakdown[0] + # Non-numpy fallback (assume [wer, ld, m, ...]) + wers_result = float(b[0]) + return wers_result
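Taken together, the module changes above are meant to give consistent scalar-versus-list behavior at the public API. The usage sketch below shows that intent as read off the new code paths (not verified against a published release); `werpy.normalize`, `werpy.wer`, and `werpy.wers` are the same functions exercised by the benchmark script and the changed modules.

```python
import werpy

# Single reference/hypothesis pair: the 0-d np.vectorize result is unwrapped,
# so a plain float comes back instead of raising an IndexError.
ref = werpy.normalize("it is consumed domestically")
hyp = werpy.normalize("it is consumed dramatically")
print(werpy.wer(ref, hyp))     # float, e.g. 0.25 (1 substitution over 4 reference words)

# Batch of pairs: wer() pools the breakdown (total edit distance / total reference words),
# while wers() keeps one score per pair.
refs = ["no one could observe", "it is a tested approach"]
hyps = ["no one could observe", "it is a tested approach yes"]
print(werpy.wer(refs, hyps))   # single pooled float for the whole batch
print(werpy.wers(refs, hyps))  # list with one WER per pair
```

The werp/werps variants follow the same dispatch but weight insertions, deletions, and substitutions before dividing by the reference word count m, as in the new code paths above. Separately, if the `test` and `benchmarks` groups added to pyproject.toml are exposed as standard extras, something along the lines of `pip install -e ".[test,benchmarks]"` would set up a development environment, though the exact names to use are whatever ends up published.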