diff --git a/hotkey_tagger.py b/hotkey_tagger.py index c7cadd0..7a5dc5e 100644 --- a/hotkey_tagger.py +++ b/hotkey_tagger.py @@ -26,7 +26,7 @@ import json from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from PyQt5.QtCore import Qt, pyqtSignal from PyQt5.QtGui import QFont, QPixmap, QKeySequence @@ -336,6 +336,30 @@ def _init_ui(self, hotkey_map: Dict[str, str]) -> None: btn_row.addStretch() layout.addLayout(btn_row) + # Sorting and import controls + sort_row = QHBoxLayout() + sort_az_btn = QPushButton("Sort A→Z (Tag)", self) + sort_az_btn.setToolTip("Sort rows alphabetically by tag name") + sort_az_btn.clicked.connect(self._sort_by_tag) + sort_key_btn = QPushButton("Sort by Key (Keyboard Order)", self) + sort_key_btn.setToolTip( + "Sort rows by keyboard position of their first key " + "(numbers row → QWERTY row → ASDF row → ZXCV row)" + ) + sort_key_btn.clicked.connect(self._sort_by_key) + load_btn = QPushButton("Load Hotkeys from File…", self) + load_btn.setToolTip( + "Import hotkey bindings from a hotkeys.json file " + "(e.g. copied from a previous image folder)" + ) + load_btn.clicked.connect(self._load_from_file) + sort_row.addWidget(sort_az_btn) + sort_row.addWidget(sort_key_btn) + sort_row.addSpacing(16) + sort_row.addWidget(load_btn) + sort_row.addStretch() + layout.addLayout(sort_row) + # OK / Cancel buttons = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel, self) buttons.accepted.connect(self.accept) @@ -385,6 +409,105 @@ def _remove_selected_row(self) -> None: if 0 <= row < self.table.rowCount(): self.table.removeRow(row) + # ------------------------------------------------------------------ + # Table helpers: extract, populate, sort, and import + # ------------------------------------------------------------------ + + def _extract_rows(self) -> List[Tuple[str, str]]: + """Return all table rows as a list of (keys_text, tag_text) tuples.""" + rows: List[Tuple[str, str]] = [] + for row in range(self.table.rowCount()): + key_item = self.table.item(row, 0) + tag_item = self.table.item(row, 1) + rows.append(( + key_item.text() if key_item else "", + tag_item.text() if tag_item else "", + )) + return rows + + def _populate_table(self, rows: List[Tuple[str, str]]) -> None: + """Replace all table contents with the given (keys_text, tag_text) rows.""" + self.table.blockSignals(True) + self.table.setRowCount(0) + for keys, tag in rows: + self._add_row(keys, tag) + self.table.blockSignals(False) + self._refresh_keyboard_preview() + + def _sort_by_tag(self) -> None: + """Sort table rows alphabetically by tag name (case-insensitive).""" + rows = self._extract_rows() + rows.sort(key=lambda r: r[1].lower()) + self._populate_table(rows) + + def _sort_by_key(self) -> None: + """Sort table rows by the keyboard position of their first assigned key. + + The order follows: numbers row (1–0), QWERTY row, ASDF row, ZXCV row, + then any remaining characters, then rows with no recognised keys last. + """ + rows = self._extract_rows() + + def _key_order(row: Tuple[str, str]) -> Tuple[int, str]: + keys = parse_keys_field(row[0]) + if keys: + # Use the position of the lowest-order key in the row + order = min(KEYBOARD_ORDER.get(k, len(KEYBOARD_ORDER)) for k in keys) + else: + order = len(KEYBOARD_ORDER) # sort unrecognised keys to the end + return (order, row[1].lower()) + + rows.sort(key=_key_order) + self._populate_table(rows) + + def _load_from_file(self) -> None: + """Import hotkey bindings from a user-chosen hotkeys.json file. + + The current table contents are *replaced* by the imported bindings. + """ + path_str, _ = QFileDialog.getOpenFileName( + self, + "Load Hotkeys from File", + "", + "JSON Files (*.json);;All Files (*)", + ) + if not path_str: + return + + try: + data = json.loads(Path(path_str).read_text(encoding="utf-8")) + except Exception as exc: + QMessageBox.warning(self, "Load Hotkeys", f"Could not read file:\n{exc}") + return + + if not isinstance(data, dict): + QMessageBox.warning( + self, "Load Hotkeys", "Invalid format – expected a JSON object." + ) + return + + # Normalise to lowercase single-char keys (same logic as _load_folder_hotkeys) + hotkey_map: Dict[str, str] = { + str(k).lower(): str(v) + for k, v in data.items() + if len(str(k)) == 1 + } + + if not hotkey_map: + QMessageBox.information( + self, "Load Hotkeys", "No valid single-key bindings found in the file." + ) + return + + # Rebuild the table from the imported map + by_tag = group_keys_by_tag(hotkey_map) + rows: List[Tuple[str, str]] = [(",".join(keys), tag) for tag, keys in by_tag.items()] + self._populate_table(rows) + + # ------------------------------------------------------------------ + # Hotkey map extraction + # ------------------------------------------------------------------ + def get_hotkey_map(self) -> Dict[str, str]: """ Return a flattened key->tag map. @@ -496,20 +619,25 @@ def _init_ui(self) -> None: # ---- Navigation row ---- nav = QHBoxLayout() - prev_btn = QPushButton("◀ Prev") - prev_btn.clicked.connect(self.prev_image) + self.first_btn = QPushButton("⏮ First") + self.first_btn.clicked.connect(self.first_image) + + # Keep a reference so the label can be toggled by _on_next_mode_changed + self.prev_btn = QPushButton("◀ Prev") + self.prev_btn.clicked.connect(self.prev_action) self.next_btn = QPushButton("Next ▶") # keep a reference; we change its label self.next_btn.clicked.connect(self.next_action) - # NEW: checkbox to toggle "next untagged" mode + # Checkbox to toggle "untagged" mode for both Prev and Next self.chk_next_untagged = QCheckBox("Next = untagged") self.chk_next_untagged.stateChanged.connect(self._on_next_mode_changed) self.progress_label = QLabel("0 / 0") self.progress_label.setAlignment(Qt.AlignCenter) - nav.addWidget(prev_btn) + nav.addWidget(self.first_btn) + nav.addWidget(self.prev_btn) nav.addWidget(self.progress_label, stretch=1) nav.addWidget(self.chk_next_untagged) # <-- add checkbox in the row nav.addWidget(self.next_btn) @@ -557,7 +685,7 @@ def add(seq: str, fn): self._shortcuts.append(sc) # Previous / Next - add("Left", self.prev_image) + add("Left", self.prev_action) add("Right", self.next_action) # Also advance with Space and Enter/Return @@ -738,12 +866,46 @@ def _show_current_image(self) -> None: # Navigation # ------------------------------------------------------------------ # + def first_image(self) -> None: + if self.image_files and self.current_index != 0: + self.current_index = 0 + self._show_current_image() + self.settings.last_image_index = self.current_index + def prev_image(self) -> None: if self.image_files and self.current_index > 0: self.current_index -= 1 self._show_current_image() self.settings.last_image_index = self.current_index + def prev_action(self) -> None: + """Delegate to sequential prev or prev-untagged based on checkbox.""" + if self.chk_next_untagged.isChecked(): + self.prev_untagged() + else: + self.prev_image() + + def prev_untagged(self) -> None: + """Jump to the previous image (before current) that has no tags.""" + if not self.image_files: + return + + end = self.current_index - 1 + found = None + for i in range(end, -1, -1): + rel = self._relpath_for_index(i) + if rel is not None and not self.tags_dict.get(rel, []): + found = i + break + + if found is None: + self.status_bar.showMessage("No untagged images before current") + return + + self.current_index = found + self._show_current_image() + self.settings.last_image_index = self.current_index + def next_image(self) -> None: if self.image_files and self.current_index < len(self.image_files) - 1: self.current_index += 1 @@ -751,11 +913,13 @@ def next_image(self) -> None: self.settings.last_image_index = self.current_index def _on_next_mode_changed(self, state) -> None: - """Update the Next button label when the mode changes.""" + """Update the Prev and Next button labels when the mode changes.""" if self.chk_next_untagged.isChecked(): + self.prev_btn.setText("◀ Prev Untagged") self.next_btn.setText("Next Untagged ▶") self.status_bar.showMessage("Next mode: jump to next untagged image") else: + self.prev_btn.setText("◀ Prev") self.next_btn.setText("Next ▶") self.status_bar.showMessage("Next mode: sequential") diff --git a/requirements.txt b/requirements.txt index dcd21e8..cfb4c4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,4 @@ PyQt5>=5.15 +numpy>=1.21 +Pillow>=9.0 +scikit-learn>=1.0 diff --git a/tests/test_csv_manager.py b/tests/test_csv_manager.py index be3fbe5..471f347 100644 --- a/tests/test_csv_manager.py +++ b/tests/test_csv_manager.py @@ -3,6 +3,7 @@ import csv import os import tempfile +from pathlib import Path import pytest @@ -18,21 +19,21 @@ def test_get_all_tags_empty(): def test_get_all_tags_single_file(): - assert get_all_tags({"a.jpg": ["galaxy", "bright"]}) == ["bright", "galaxy"] + assert get_all_tags({Path("a.jpg"): ["galaxy", "bright"]}) == ["bright", "galaxy"] def test_get_all_tags_multiple_files(): tags_dict = { - "a.jpg": ["galaxy", "star"], - "b.jpg": ["nebula", "star"], - "c.jpg": [], + Path("a.jpg"): ["galaxy", "star"], + Path("b.jpg"): ["nebula", "star"], + Path("c.jpg"): [], } result = get_all_tags(tags_dict) assert result == ["galaxy", "nebula", "star"] def test_get_all_tags_returns_sorted(): - tags_dict = {"img.jpg": ["z_tag", "a_tag", "m_tag"]} + tags_dict = {Path("img.jpg"): ["z_tag", "a_tag", "m_tag"]} assert get_all_tags(tags_dict) == ["a_tag", "m_tag", "z_tag"] @@ -42,27 +43,27 @@ def test_get_all_tags_returns_sorted(): def test_save_and_load_roundtrip(): tags_dict = { - "image1.jpg": ["galaxy", "bright"], - "image2.jpg": ["star"], - "image3.jpg": [], + Path("image1.jpg"): ["galaxy", "bright"], + Path("image2.jpg"): ["star"], + Path("image3.jpg"): [], } with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as f: - path = f.name + path = Path(f.name) try: save_tags(path, tags_dict) loaded = load_tags(path) - assert loaded["image1.jpg"] == ["bright", "galaxy"] - assert loaded["image2.jpg"] == ["star"] - assert loaded["image3.jpg"] == [] + assert loaded[Path("image1.jpg")] == ["bright", "galaxy"] + assert loaded[Path("image2.jpg")] == ["star"] + assert loaded[Path("image3.jpg")] == [] finally: os.unlink(path) def test_save_tags_csv_structure(): """The CSV must have a 'filename' header followed by sorted tag columns.""" - tags_dict = {"img.png": ["b_tag", "a_tag"]} + tags_dict = {Path("img.png"): ["b_tag", "a_tag"]} with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as f: - path = f.name + path = Path(f.name) try: save_tags(path, tags_dict) with open(path, newline="") as fh: @@ -78,11 +79,11 @@ def test_save_tags_csv_structure(): def test_save_tags_binary_values(): """Files that don't have a tag should get 0; those that do should get 1.""" tags_dict = { - "a.jpg": ["galaxy"], - "b.jpg": [], + Path("a.jpg"): ["galaxy"], + Path("b.jpg"): [], } with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as f: - path = f.name + path = Path(f.name) try: save_tags(path, tags_dict) loaded_raw: list = [] @@ -99,13 +100,13 @@ def test_save_tags_binary_values(): def test_load_tags_missing_file(): - assert load_tags("/nonexistent/path/tags.csv") == {} + assert load_tags(Path("/nonexistent/path/tags.csv")) == {} def test_load_tags_empty_csv(): with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as f: f.write("filename\n") - path = f.name + path = Path(f.name) try: result = load_tags(path) assert result == {} @@ -119,7 +120,7 @@ def test_load_tags_empty_csv(): def test_repair_csv_noop_when_file_missing(): """repair_csv must not raise when the file does not exist.""" - repair_csv("/nonexistent/path/tags.csv") # should not raise + repair_csv(Path("/nonexistent/path/tags.csv")) # should not raise def test_repair_csv_adds_missing_columns(): @@ -130,7 +131,7 @@ def test_repair_csv_adds_missing_columns(): with tempfile.NamedTemporaryFile( suffix=".csv", delete=False, mode="w", newline="" ) as f: - path = f.name + path = Path(f.name) writer = csv.writer(f) # Header already knows about 'nebula', but old_image.jpg row doesn't writer.writerow(["filename", "galaxy", "nebula"]) @@ -158,26 +159,26 @@ def test_repair_csv_adds_missing_columns(): def test_repair_csv_idempotent(): """Calling repair_csv twice produces the same result as calling it once.""" tags_dict = { - "a.jpg": ["galaxy", "bright"], - "b.jpg": ["star"], + Path("a.jpg"): ["galaxy", "bright"], + Path("b.jpg"): ["star"], } with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as f: - path = f.name + path = Path(f.name) try: save_tags(path, tags_dict) repair_csv(path) repair_csv(path) loaded = load_tags(path) - assert set(loaded["a.jpg"]) == {"galaxy", "bright"} - assert loaded["b.jpg"] == ["star"] + assert set(loaded[Path("a.jpg")]) == {"galaxy", "bright"} + assert loaded[Path("b.jpg")] == ["star"] finally: os.unlink(path) def test_repair_csv_preserves_all_rows(): - tags_dict = {f"img{i}.jpg": ["tag_a"] for i in range(10)} + tags_dict = {Path(f"img{i}.jpg"): ["tag_a"] for i in range(10)} with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as f: - path = f.name + path = Path(f.name) try: save_tags(path, tags_dict) repair_csv(path) diff --git a/tests/test_keyboard_widget.py b/tests/test_keyboard_widget.py index df1e129..5373582 100644 --- a/tests/test_keyboard_widget.py +++ b/tests/test_keyboard_widget.py @@ -153,3 +153,97 @@ def test_keyboard_widget_total_button_count(): kb = KeyboardWidget({}) total_keys = sum(len(row) for row in KEYBOARD_LAYOUT) assert len(kb._buttons) == total_keys + + +# --------------------------------------------------------------------------- +# HotkeyConfigDialog – sort and load-from-file features +# --------------------------------------------------------------------------- + +import json +import tempfile +import os as _os + +from hotkey_tagger import HotkeyConfigDialog, parse_keys_field + + +def test_sort_by_tag_alphabetical(): + """_sort_by_tag reorders rows alphabetically by tag name.""" + dlg = HotkeyConfigDialog({"z": "zebra", "a": "apple", "m": "mango"}) + dlg._sort_by_tag() + tags = [dlg.table.item(r, 1).text() for r in range(dlg.table.rowCount())] + assert tags == sorted(tags, key=str.lower) + + +def test_sort_by_key_keyboard_order(): + """_sort_by_key places '1'-keyed tag before 'q'-keyed tag before 'a'-keyed tag.""" + hotkey_map = {"a": "asdf_tag", "q": "qwerty_tag", "1": "number_tag"} + dlg = HotkeyConfigDialog(hotkey_map) + dlg._sort_by_key() + tags = [dlg.table.item(r, 1).text() for r in range(dlg.table.rowCount())] + # number row ('1') < qwerty row ('q') < asdf row ('a') in KEYBOARD_ORDER + assert tags.index("number_tag") < tags.index("qwerty_tag") + assert tags.index("qwerty_tag") < tags.index("asdf_tag") + + +def test_sort_by_key_zxcv_after_asdf(): + """Keys in the ZXCV row sort after the ASDF row.""" + hotkey_map = {"z": "zxcv_tag", "f": "asdf_tag"} + dlg = HotkeyConfigDialog(hotkey_map) + dlg._sort_by_key() + tags = [dlg.table.item(r, 1).text() for r in range(dlg.table.rowCount())] + assert tags.index("asdf_tag") < tags.index("zxcv_tag") + + +def test_sort_by_tag_case_insensitive(): + """_sort_by_tag treats uppercase and lowercase tag names equivalently.""" + dlg = HotkeyConfigDialog({"c": "Charlie", "a": "alice", "b": "Bob"}) + dlg._sort_by_tag() + tags = [dlg.table.item(r, 1).text() for r in range(dlg.table.rowCount())] + assert tags == sorted(tags, key=str.lower) + + +def test_load_from_file_populates_table(tmp_path): + """_load_from_file replaces the table contents with the imported map.""" + hotkeys_file = tmp_path / "hotkeys.json" + hotkeys_file.write_text(json.dumps({"g": "galaxy", "s": "star"}), encoding="utf-8") + + dlg = HotkeyConfigDialog({}) + assert dlg.table.rowCount() == 0 + + # Simulate the file-load logic directly (bypasses QFileDialog) + data = json.loads(hotkeys_file.read_text(encoding="utf-8")) + from hotkey_tagger import group_keys_by_tag + hotkey_map = {str(k).lower(): str(v) for k, v in data.items() if len(str(k)) == 1} + by_tag = group_keys_by_tag(hotkey_map) + rows = [(",".join(keys), tag) for tag, keys in by_tag.items()] + dlg._populate_table(rows) + + all_tags = {dlg.table.item(r, 1).text() for r in range(dlg.table.rowCount())} + assert all_tags == {"galaxy", "star"} + + +def test_load_from_file_replaces_existing_rows(tmp_path): + """_load_from_file (via _populate_table) replaces existing rows entirely.""" + hotkeys_file = tmp_path / "hotkeys.json" + hotkeys_file.write_text(json.dumps({"n": "nebula"}), encoding="utf-8") + + dlg = HotkeyConfigDialog({"g": "galaxy", "s": "star"}) + assert dlg.table.rowCount() == 2 + + data = json.loads(hotkeys_file.read_text(encoding="utf-8")) + from hotkey_tagger import group_keys_by_tag + hotkey_map = {str(k).lower(): str(v) for k, v in data.items() if len(str(k)) == 1} + by_tag = group_keys_by_tag(hotkey_map) + rows = [(",".join(keys), tag) for tag, keys in by_tag.items()] + dlg._populate_table(rows) + + assert dlg.table.rowCount() == 1 + assert dlg.table.item(0, 1).text() == "nebula" + + +def test_extract_rows_returns_all_cells(): + """_extract_rows reflects the current table contents accurately.""" + dlg = HotkeyConfigDialog({"g": "galaxy", "s": "star"}) + rows = dlg._extract_rows() + all_tags = {r[1] for r in rows} + assert all_tags == {"galaxy", "star"} diff --git a/tests/test_settings.py b/tests/test_settings.py index f3c722c..09e6476 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -3,6 +3,7 @@ import json import os import tempfile +from pathlib import Path import pytest @@ -33,7 +34,7 @@ def test_save_and_load_roundtrip(): s.last_csv_path = "/data/images/tags.csv" with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as f: - path = f.name + path = Path(f.name) try: s.save(path) @@ -52,7 +53,7 @@ def test_save_and_load_roundtrip(): def test_load_returns_false_for_missing_file(): s = HotkeySettings() - assert s.load("/nonexistent/path/settings.json") is False + assert s.load(Path("/nonexistent/path/settings.json")) is False def test_load_returns_false_for_corrupt_json(): @@ -60,7 +61,7 @@ def test_load_returns_false_for_corrupt_json(): suffix=".json", delete=False, mode="w" ) as f: f.write("not valid json {{") - path = f.name + path = Path(f.name) try: s = HotkeySettings() assert s.load(path) is False @@ -72,7 +73,7 @@ def test_save_produces_valid_json(): s = HotkeySettings() s.hotkey_map = {"n": "nebula"} with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as f: - path = f.name + path = Path(f.name) try: s.save(path) with open(path) as fh: @@ -91,7 +92,7 @@ def test_load_partial_json_uses_defaults(): suffix=".json", delete=False, mode="w" ) as f: json.dump({"hotkey_map": {"x": "xray"}}, f) - path = f.name + path = Path(f.name) try: s = HotkeySettings() assert s.load(path) is True @@ -107,7 +108,7 @@ def test_hotkey_map_survives_multiple_saves(): s = HotkeySettings() s.hotkey_map = {"a": "asteroid", "b": "binary_star"} with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as f: - path = f.name + path = Path(f.name) try: s.save(path) s.hotkey_map["c"] = "comet"