Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions python/src/magika/magika.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,43 @@ def identify_paths(

return self._get_results_from_paths(paths_)

def scan_directory(
self, directory: Union[str, os.PathLike], recursive_scan: bool = False
) -> List[MagikaResult]:
"""Identify the content type of all files in a directory given its path."""
path_obj = Path(directory)

# Guard clause: check if directory exists
if not path_obj.exists() or not path_obj.is_dir():
raise FileNotFoundError(
f"The directory '{directory}' does not exist or is not a directory."
)

collected_paths: List[Union[str, os.PathLike]] = []

# Use rglob('*') for recursive scan, glob('*') for single directory
glob_pattern = (
sorted(path_obj.rglob("*"))
if recursive_scan
else sorted(path_obj.glob("*"))
)

for item in glob_pattern:
# We only want files, not sub-directories themselves
if item.is_file():
collected_paths.append(item)

paths_ = []
for path in collected_paths:
if isinstance(path, str) or isinstance(path, os.PathLike):
paths_.append(Path(path))
else:
raise TypeError(
f"Input '{path}' is invalid: input path should be of type `Union[str, os.PathLike]`"
)

return self._get_results_from_paths(paths_)

def identify_bytes(self, content: bytes) -> MagikaResult:
"""Identify the content type of raw bytes."""
if not isinstance(content, bytes):
Expand Down
17 changes: 17 additions & 0 deletions python/tests/test_magika_python_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,23 @@ def test_magika_module_with_explicit_model_dir() -> None:
_ = m.identify_stream(f)


def test_magika_module_with_basic_tests_by_directory() -> None:
tests_paths = utils.get_directory_tests_files_dir()

m = Magika()

# Only scan direct children of tests_data/directory.
# Expected output is "directory" content type.
results = m.scan_directory(tests_paths)
direct_children = sorted([p for p in tests_paths.glob("*")])
check_results_vs_expected_results(direct_children, results)

# Scan all files recursively. Expected output is content type of each file.
results = m.scan_directory(tests_paths, recursive_scan=True)
all_files = sorted([p for p in tests_paths.rglob("*") if p.is_file()])
check_results_vs_expected_results(all_files, results)


def test_magika_module_with_basic_tests_by_paths() -> None:
tests_paths = utils.get_basic_test_files_paths()

Expand Down
6 changes: 6 additions & 0 deletions python/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def get_basic_tests_files_dir() -> Path:
return tests_files_dir


def get_directory_tests_files_dir() -> Path:
tests_files_dir = get_tests_data_dir() / "directory"
assert tests_files_dir.is_dir()
return tests_files_dir


def get_mitra_tests_files_dir() -> Path:
tests_files_dir = get_tests_data_dir() / "mitra"
assert tests_files_dir.is_dir()
Expand Down
3 changes: 3 additions & 0 deletions rust/cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ Options:
--no-dereference
Identifies symbolic links as is instead of identifying their content by following them

--summary
Prints a summary of file types at the end of the output

--colors
Prints with colors regardless of terminal support

Expand Down
54 changes: 52 additions & 2 deletions rust/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;

use anyhow::{bail, ensure, Result};
use clap::{Args, Parser};
use colored::ColoredString;
use colored::{ColoredString, Colorize};
use magika_lib::{
self as magika, ContentType, Features, FeaturesOrRuled, FileType, InferredType,
OverwriteReason, Session, TypeInfo,
Expand All @@ -48,6 +48,10 @@ struct Flags {
#[arg(long)]
no_dereference: bool,

/// Prints a summary of file types at the end of the output.
#[arg(long)]
summary: bool,

#[clap(flatten)]
colors: Colors,

Expand Down Expand Up @@ -211,12 +215,26 @@ async fn main() -> Result<()> {
if flags.format.json {
print!("[");
}
// Initializes counter for file types.
let mut type_counts: HashMap<(String, String), usize> = HashMap::new();
let mut reorder = Reorder::default();
let mut errors = false;
// Prints results, reordering as needed to match input order.
while let Some(response) = result_receiver.recv().await {
reorder.push(response?);
while let Some(response) = reorder.pop() {
errors |= response.result.is_err();
// Counts file type for the final summary.
if flags.summary {
// If result is ok, extracts the description (e.g., "Python source").
if let Ok(file_type) = &response.result {
let type_label = file_type.info().description.to_string();
let group = file_type.info().group.to_string();
// Increments the count in the HashMap, inserting 0 if the key does not exist.
*type_counts.entry((type_label, group)).or_insert(0) += 1;
}
}
// Prints output.
if flags.format.json {
if reorder.next != 1 {
print!(",");
Expand All @@ -239,6 +257,25 @@ async fn main() -> Result<()> {
if errors {
std::process::exit(1);
}
// Prints summary if requested (only if there were no errors).
if flags.summary && !flags.format.json && !flags.format.jsonl {
println!("--- Summary ---");
// Sorts by count (descending).
let mut sorted_counts: Vec<_> = type_counts.into_iter().collect();
sorted_counts.sort_by(|a, b| {
let count_cmp = b.1.cmp(&a.1);
if count_cmp == std::cmp::Ordering::Equal {
// Sorts alphabetically (case-insensitive).
// a.0 .0 accesses the type_label.
a.0 .0.to_lowercase().cmp(&b.0 .0.to_lowercase())
} else {
count_cmp
}
});
for ((type_label, group), count) in sorted_counts {
println!("{}: {}", color_type_label(&type_label, &group), count);
}
}
Ok(())
}

Expand Down Expand Up @@ -554,7 +591,6 @@ impl Response {
}

fn color(&self, result: ColoredString) -> ColoredString {
use colored::Colorize as _;
// We only use true colors (except for errors). If the terminal doesn't support true colors,
// the colored crate will automatically choose the closest one.
match &self.result {
Expand Down Expand Up @@ -587,3 +623,17 @@ fn join<T: AsRef<str>>(xs: impl IntoIterator<Item = T>) -> String {
result.push(']');
result
}

fn color_type_label(type_label: &str, group: &str) -> colored::ColoredString {
match group {
"application" => type_label.truecolor(0xf4, 0x3f, 0x5e),
"archive" => type_label.truecolor(0xf5, 0x9e, 0x0b),
"audio" => type_label.truecolor(0x84, 0xcc, 0x16),
"code" => type_label.truecolor(0x8b, 0x5c, 0xf6),
"document" => type_label.truecolor(0x3b, 0x82, 0xf6),
"executable" => type_label.truecolor(0xec, 0x48, 0x99),
"image" => type_label.truecolor(0x06, 0xb6, 0xd4),
"video" => type_label.truecolor(0x10, 0xb9, 0x81),
_ => type_label.bold().truecolor(0xcc, 0xcc, 0xcc),
}
}
1 change: 1 addition & 0 deletions tests_data/directory/txt/complex-sentence.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This is yet another simple test, it includes one simple sentence, but it is not as trivial as other simpler tests.
Loading