Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

## [0.9.0] - 2025-11-26

### Added

- `check_correlation_warnings`: New check that flags pairs of numeric columns whose absolute correlation meets or exceeds a configurable threshold (default 0.95).

## [0.8.0] - 2025-11-26

### Added
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Switch to Pull Requests for new features and improvements.

## v0.9.0

- [ ] Correlation Analysis Check
- [x] Correlation Analysis Check
- [ ] Documentation and Examples ready for release

## v1.0.0
Expand Down
5 changes: 4 additions & 1 deletion src/lintdata/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def report(
future_date_reference: Optional[str] = None,
special_chars_threshold: float = 0.1,
threshold_years: float = 50,
correlation_threshold: float = 0.95,
report_format: str = "text",
output: Optional[str] = None,
return_dict: bool = False,
Expand All @@ -114,7 +115,7 @@ def report(
'unique', 'outliers', 'missing_patterns', 'case', 'cardinality', 'skewness',
'duplicate_columns', 'type_consistency', 'negative', 'rare_categories',
'date_format', 'string_length', 'zero_inflation', 'future_dates',
'special_chars', 'date_anomalies'. Use 'all' to run all checks. Defaults to None.
'special_chars', 'date_anomalies', 'correlation'. Use 'all' to run all checks. Defaults to None.
outlier_threshold (float, optional): Outlier detection threshold using the IQR method. Defaults to 1.5.
skewness_threshold (float, optional): Threshold for skewness detection. Defaults to 1.0.
rare_category_threshold (float, optional): Minimum proportion for rare categories. Defaults to 0.01.
Expand All @@ -128,6 +129,7 @@ def report(
future_date_reference (Optional[str], optional): Reference date for future date check (YYYY-MM-DD). Defaults to None (today).
special_chars_threshold (float, optional): Minimum proportion of values with special characters. Defaults to 0.1.
threshold_years (float, optional): Maximum acceptable date range in years. Columns with date ranges exceeding will be flagged. Defaults to 50.
correlation_threshold (float, optional): Threshold for flagging highly correlated columns. Defaults to 0.95.
report_format (str, optional): Output format. Options: 'text', 'html', 'json', 'csv'. Defaults to 'text'.
output (Optional[str], optional): File path to save the report. If None, returns as string. Defaults to None.
return_dict (bool, optional): If True, returns structured dictionary instead of formatted string. Defaults to False.
Expand Down Expand Up @@ -212,6 +214,7 @@ def report(
"date_anomalies": lambda: checks.check_date_range_anomalies(
self._df, columns=future_date_columns, threshold_years=threshold_years
),
"correlation": lambda: checks.check_correlation_warnings(self._df, threshold=correlation_threshold),
}

if checks_to_run is None:
Expand Down
38 changes: 38 additions & 0 deletions src/lintdata/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,3 +1111,41 @@ def check_date_range_anomalies(
f"({min_date.date()} to {max_date.date()})"
)
return warnings


def check_correlation_warnings(df: pd.DataFrame, threshold: float = 0.95) -> List[str]:
    """Flag pairs of numeric columns whose absolute correlation reaches a threshold.

    Only columns selected by ``select_dtypes(include=[np.number])`` are
    considered. The correlation matrix is computed with ``DataFrame.corr()``
    using its default settings, and each unordered pair of columns is examined
    exactly once, in column order. The sign of the correlation is ignored, so
    strong negative relationships are flagged as well.

    Args:
        df (pd.DataFrame): The DataFrame to inspect.
        threshold (float, optional): Minimum absolute correlation for a pair to
            be flagged. Must satisfy ``0 < threshold <= 1``. Defaults to 0.95.

    Returns:
        List[str]: One "[High Correlation]" warning per flagged pair. Empty when
        the DataFrame is empty, has fewer than two numeric columns, or no pair
        reaches the threshold.

    Raises:
        ValueError: If ``threshold`` is outside the range ``(0, 1]``.
    """
    # Validate before any early return so a bad threshold is reported
    # regardless of what data happens to be passed in.
    if not (0 < threshold <= 1):
        raise ValueError("Correlation threshold must be between 0 and 1.")

    warnings: List[str] = []

    if df.empty:
        return warnings

    numeric_df = df.select_dtypes(include=[np.number])
    labels = list(numeric_df.columns)

    if len(labels) < 2:
        return warnings

    # Work on the raw matrix by position: label-based lookups break on
    # duplicate column names, and sorting labels breaks on mixed label types.
    strengths = np.abs(numeric_df.corr().to_numpy(dtype=float))

    for i, col1 in enumerate(labels):
        # Upper triangle only, so every unordered pair is visited once.
        for j in range(i + 1, len(labels)):
            corr_value = strengths[i, j]

            # NaN arises when a correlation is undefined (e.g. a constant column).
            if np.isnan(corr_value) or corr_value < threshold:
                continue

            col2 = labels[j]
            percent = corr_value * 100
            warnings.append(f"[High Correlation] Columns '{col1}' and '{col2}' are {percent:.1f}% correlated.")

    return warnings
166 changes: 166 additions & 0 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,3 +1379,169 @@ def test_performance_all_null_columns():

warnings = checks.check_outliers(df)
assert warnings == []


# ==== Correlation Warnings Tests ====


def test_check_correlation_warnings_no_correlation():
    """Independently drawn random columns should not be flagged."""
    np.random.seed(42)  # fixed seed keeps the sample, and hence the outcome, reproducible
    samples = {name: np.random.randn(100) for name in ("a", "b", "c")}
    frame = pd.DataFrame(samples)
    assert checks.check_correlation_warnings(frame) == []


def test_check_correlation_warnings_detects_high_correlation():
    """A column and its unit-converted twin are reported as a single pair."""
    centimetres = [170, 180, 175, 165, 185]
    inches = [66.9, 70.9, 68.9, 65.0, 72.8]  # the same heights, rounded to one decimal
    frame = pd.DataFrame({"height_cm": centimetres, "height_inches": inches})

    found = checks.check_correlation_warnings(frame, threshold=0.95)

    assert len(found) == 1
    message = found[0]
    assert "height_cm" in message
    assert "height_inches" in message
    assert "correlated" in message.lower()


def test_check_correlation_warnings_perfect_correlation():
    """Exact linear relationships are reported at 100%."""
    base = [1, 2, 3, 4, 5]
    frame = pd.DataFrame(
        {
            "a": base,
            "b": [2 * value for value in base],  # b = 2 * a
            "c": [value + 4 for value in base],  # c = a + 4
        }
    )

    found = checks.check_correlation_warnings(frame, threshold=0.95)

    # All three columns are mutually linear: a-b, a-c and b-c.
    assert len(found) == 3
    first = found[0]
    for fragment in ("'a'", "'b'", "100.0%"):
        assert fragment in first


def test_check_correlation_warnings_custom_threshold():
    """The same data is flagged or not depending on the threshold passed in."""
    frame = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1.2, 2.1, 2.5, 3.9, 5.2]})

    # Strongly but not perfectly linear, so a 0.99 cut-off is too strict to flag it.
    strict = checks.check_correlation_warnings(frame, threshold=0.99)
    assert strict == []

    # A 0.90 cut-off is loose enough for the pair to be reported.
    lenient = checks.check_correlation_warnings(frame, threshold=0.90)
    assert len(lenient) >= 1


def test_check_correlation_warnings_empty_dataframe():
    """A DataFrame with no rows or columns produces no warnings."""
    assert checks.check_correlation_warnings(pd.DataFrame()) == []


def test_check_correlation_warnings_single_numeric_column():
    """With only one numeric column there is no pair to compare."""
    lone_column = pd.DataFrame({"a": [1, 2, 3, 4, 5]})
    assert checks.check_correlation_warnings(lone_column) == []


def test_check_correlation_warnings_no_numeric_columns():
    """A frame made up solely of string columns yields no warnings."""
    text_only = pd.DataFrame(
        {
            "name": ["Alice", "Bob", "Charlie"],
            "category": ["A", "B", "C"],
        }
    )
    assert checks.check_correlation_warnings(text_only) == []


def test_check_correlation_warnings_with_nan():
    """Missing values do not stop a correlated pair from being found."""
    frame = pd.DataFrame(
        {
            "a": [1, 2, np.nan, 4, 5],
            "b": [2, 4, np.nan, 8, 10],
        }
    )

    found = checks.check_correlation_warnings(frame, threshold=0.95)

    assert len(found) == 1
    message = found[0]
    assert "'a'" in message
    assert "'b'" in message


def test_check_correlation_warnings_negative_correlation():
    """An exact inverse relationship is flagged just like a positive one."""
    rising = [1, 2, 3, 4, 5]
    falling = [-2, -4, -6, -8, -10]  # b = -2 * a
    frame = pd.DataFrame({"a": rising, "b": falling})

    found = checks.check_correlation_warnings(frame, threshold=0.95)

    assert len(found) == 1
    message = found[0]
    for fragment in ("'a'", "'b'", "100.0%"):
        assert fragment in message


def test_check_correlation_warnings_multiple_pairs():
    """Every correlated pair is reported, and an uncorrelated column is left out."""
    df = pd.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [2, 4, 6, 8, 10],  # b = 2 * a
            "c": [10, 20, 30, 40, 50],  # c = 10 * a
            # Symmetric about the middle row, so its covariance with the
            # linearly increasing columns is zero. (The previous values
            # 100..104 were themselves perfectly linear in ``a``.)
            "d": [2, 1, 0, 1, 2],
        }
    )
    warnings = checks.check_correlation_warnings(df, threshold=0.95)

    # Exactly the three pairs among a, b and c: a-b, a-c, b-c.
    assert len(warnings) == 3
    assert all("'d'" not in message for message in warnings)


def test_check_correlation_warnings_invalid_threshold_low():
    """A threshold of 0 is rejected with ValueError."""
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # Keep the try body down to the single call that is expected to raise.
    try:
        checks.check_correlation_warnings(df, threshold=0)
    except ValueError as e:
        assert "between 0 and 1" in str(e).lower()
    else:
        # Raised explicitly: a bare ``assert False`` is removed under ``python -O``.
        raise AssertionError("Should have raised ValueError")


def test_check_correlation_warnings_invalid_threshold_high():
    """A threshold above 1 is rejected with ValueError."""
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # Keep the try body down to the single call that is expected to raise.
    try:
        checks.check_correlation_warnings(df, threshold=1.5)
    except ValueError as e:
        assert "between 0 and 1" in str(e).lower()
    else:
        # Raised explicitly: a bare ``assert False`` is removed under ``python -O``.
        raise AssertionError("Should have raised ValueError")


def test_check_correlation_warnings_constant_columns():
    """Constant columns (no variance) produce no warnings and no errors."""
    df = pd.DataFrame({"const1": [5, 5, 5, 5], "const2": [10, 10, 10, 10], "varying": [1, 2, 3, 4]})
    warnings = checks.check_correlation_warnings(df)
    # Every pair here involves a zero-variance column, so each correlation is
    # NaN and must be skipped. Checking only the return type could never fail.
    assert warnings == []


def test_check_correlation_warnings_mixed_types():
    """String columns sitting next to numeric ones do not affect the result."""
    numeric_part = {
        "height_cm": [170, 180, 175],
        "height_inches": [66.9, 70.9, 68.9],
    }
    text_part = {
        "name": ["Alice", "Bob", "Charlie"],
        "category": ["A", "B", "C"],
    }
    frame = pd.DataFrame({**numeric_part, **text_part})

    found = checks.check_correlation_warnings(frame, threshold=0.95)

    # Only the two height columns can form a pair.
    assert len(found) == 1
    message = found[0]
    assert "height_cm" in message
    assert "height_inches" in message