Skip to content

Commit 975e790

Browse files
committed
Fix pyre errors: cli.internal
1 parent 0c0d751 commit 975e790

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

src/psykoda/cli/internal.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from dataclasses import dataclass
1111
from datetime import datetime, timedelta
12-
from typing import Dict, List, Optional, Tuple
12+
from typing import Dict, List, Optional, Tuple, cast
1313

1414
import numpy as np
1515
import pandas as pd
@@ -55,7 +55,7 @@ def configure_logging(debug: bool):
5555
stderr_handler.addFilter(stderr_filter)
5656
stderr_handler.setLevel(logging.INFO)
5757
stderr_handler.setFormatter(logging.Formatter("%(message)s"))
58-
handlers = [stderr_handler]
58+
handlers: list[logging.Handler] = [stderr_handler]
5959

6060
logfile_handler = logging.FileHandler(PATH_LOG)
6161
logfile_handler.setLevel(logging.DEBUG)
@@ -405,7 +405,7 @@ def main_detection_skip_or_detect(
405405
logger.info("outputting detection reports")
406406
anomaly_score = detector.compute_anomaly_score(x_test, scale=True)
407407
num_anomaly = min(
408-
sum(anomaly_score > anomaly_detection_config.threshold.min_score),
408+
np.count_nonzero(anomaly_score > anomaly_detection_config.threshold.min_score),
409409
anomaly_detection_config.threshold.num_anomaly,
410410
)
411411

@@ -523,6 +523,7 @@ def report_all(path_list_stats: List[str], path_save: str):
523523
[], columns=["datetime_rounded", "src_ip", "subnet", "service"]
524524
)
525525
idx = 0
526+
results_shaps = pd.DataFrame()
526527
for path in path_list_stats:
527528
# Load stats
528529
stats = utils.load_json(path)
@@ -545,7 +546,7 @@ def report_all(path_list_stats: List[str], path_save: str):
545546
results_pd.loc[idx] = [dt, src_ip, subnet, service]
546547

547548
if idx == 0:
548-
results_shaps = pd.DataFrame([], columns=report.columns)
549+
results_shaps.columns = report.columns
549550
results_shaps.loc[idx] = report.loc[(dt, src_ip)]
550551

551552
idx += 1
@@ -564,13 +565,14 @@ def report_all(path_list_stats: List[str], path_save: str):
564565
ret = pd.concat([ret, results_pd_group.get_group(key)])
565566

566567
ret.round(4).to_csv(path_save, index=False)
568+
num_anomaly_ipaddr = len(keys)
567569
else:
568570
# Anomaly not found
569571
pd.DataFrame([["no anomaly found"]]).to_csv(path_save, index=False)
572+
num_anomaly_ipaddr = 0
570573

571574
logger.info("[RESULT]", extra=to_stderr)
572575
logger.info("Detection summary file: %s", path_save, extra=to_stderr)
573-
num_anomaly_ipaddr = len(keys) if anomaly_found else 0
574576
logger.info(
575577
"Number of unique anomaly IP addresses: %s", num_anomaly_ipaddr, extra=to_stderr
576578
)
@@ -719,7 +721,9 @@ def detect_per_unit(
719721
label_value=1,
720722
)
721723
log_labeled = labeled.factory(config.io.previous.log)[0].load_previous_log(
722-
entries=known_normal.index,
724+
entries=cast(known_normal.index),
725+
# we can safely assume that known_normal.Index is MultiIndex
726+
# since it is empty otherwise.
723727
)
724728
log_labeled = apply_exclude_lists(log_labeled, config.preprocess.exclude_lists)
725729
log_labeled = preprocess.extract_log(
@@ -784,12 +788,12 @@ def _load_log_catch(load, r):
784788

785789

786790
def load_previous(
787-
config: LoadPreviousConfigItem, date_to: datetime, label_value: float
791+
config: Optional[LoadPreviousConfigItem], date_to: datetime, label_value: float
788792
) -> pd.Series:
789793
from psykoda.preprocess import round_datetime
790794
from psykoda.utils import DateRange
791795

792-
if config.list is None:
796+
if config is None or config.list is None:
793797
return pd.Series()
794798

795799
def date_filter(row):

0 commit comments

Comments
 (0)