Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,35 @@ set_tests_properties(test_multithreading_benchmark PROPERTIES
TIMEOUT 300
LABELS "benchmark;multithreading;performance"
)

# Test 7: Verification mode test
# Tests --verify flag for checking model accuracy
add_test(
NAME test_verify_mode
COMMAND ${CMAKE_COMMAND}
-DNNETS_EXE=$<TARGET_FILE:NNets>
-DCONFIG_DIR=${CMAKE_SOURCE_DIR}/configs
-DWORK_DIR=${CMAKE_BINARY_DIR}
-P ${CMAKE_SOURCE_DIR}/cmake/test_verify.cmake
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
)
set_tests_properties(test_verify_mode PROPERTIES
TIMEOUT 180
LABELS "verification;inference"
)

# Test 8: Retraining mode test
# Tests -r flag for retraining existing models with new classes
add_test(
NAME test_retraining_mode
COMMAND ${CMAKE_COMMAND}
-DNNETS_EXE=$<TARGET_FILE:NNets>
-DCONFIG_DIR=${CMAKE_SOURCE_DIR}/configs
-DWORK_DIR=${CMAKE_BINARY_DIR}
-P ${CMAKE_SOURCE_DIR}/cmake/test_retraining.cmake
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
)
set_tests_properties(test_retraining_mode PROPERTIES
TIMEOUT 600
LABELS "retraining;training"
)
67 changes: 66 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ g++ -o NNets main.cpp

## Usage

The program supports two main modes: **Training** and **Inference**.
The program supports several modes: **Training**, **Retraining**, **Inference**, and **Verification**.

### Command Line Options

Expand All @@ -56,19 +56,35 @@ Usage: ./NNets [options]
MODES:
Training mode (default): Train network and optionally save to file
Inference mode: Load trained network and classify inputs
Retraining mode: Load existing network and continue training with new data

TRAINING OPTIONS:
-c, --config <file> Load training configuration from JSON file
-s, --save <file> Save trained network to JSON file after training
-t, --test Run automated test after training (no interactive mode)
-b, --benchmark Run benchmark to measure training speed

RETRAINING OPTIONS:
-r, --retrain <file> Load existing network and continue training (retraining mode)
Combines -l (load) with training mode. Requires -c for new data.
New classes in config (without output_neuron) will be trained.

INFERENCE OPTIONS:
-l, --load <file> Load trained network from JSON file (inference mode)
-i, --input <text> Classify single input text and exit (non-interactive)
--verify Verify accuracy of loaded model on training config (-c required)

PERFORMANCE OPTIONS:
-j, --threads <n> Number of threads to use (0 = auto, default)
--single-thread Disable multithreading (use single thread)

GENERAL OPTIONS:
-h, --help Show help message

INTERRUPTION:
Press Ctrl+C during training to interrupt gracefully.
The network will be saved if -s is specified.
Training can be continued later with -r option.
```

### Training Mode
Expand Down Expand Up @@ -105,6 +121,55 @@ Load a pre-trained network and classify inputs:
./NNets -l model.json -i "yes"
```

### Retraining Mode

Continue training an existing network with new classes or additional training data:

```bash
# Add new classes to an existing model
# 1. First, train initial model with classes yes/no
./NNets -c configs/simple.json -s model_v1.json

# 2. Create a new config with additional classes (e.g., adding "maybe")
# 3. Retrain the model with new data
./NNets -r model_v1.json -c configs/extended.json -s model_v2.json
```

Retraining automatically detects which classes are already trained (have `output_neuron`) and only trains new classes. This is useful for:
- Adding new recognition classes without retraining from scratch
- Continuing interrupted training sessions
- Incrementally improving the model

### Verification Mode

Check the accuracy of a trained model on test data:

```bash
# Verify model accuracy on training data
./NNets -l model.json -c configs/test.json --verify
```

This mode loads the trained network and tests it against all samples in the configuration file, reporting accuracy statistics.

### Training Interruption

Training can be interrupted at any time by pressing Ctrl+C:
- The first Ctrl+C requests graceful interruption (finishes current iteration)
- The second Ctrl+C forces immediate exit
- If `-s` option is specified, the network state is saved automatically
- Training can be continued later using the `-r` (retrain) option

```bash
# Start long training with auto-save
./NNets -c configs/large.json -s checkpoint.json

# Press Ctrl+C to interrupt...
# Network saved to checkpoint.json

# Continue training later
./NNets -r checkpoint.json -c configs/large.json -s final_model.json
```

### Training Configuration Format

Training configurations are JSON files that define the classes and training images:
Expand Down
141 changes: 141 additions & 0 deletions cmake/test_retraining.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# CMake script to test retraining functionality
# This script:
# 1. Trains a model with simple.json (yes/no/empty classes)
# 2. Saves the model
# 3. Retrains with extended.json (adds cat/dog/bird/fish classes)
# 4. Verifies the retrained model works for both old and new classes

# Check required variables
if(NOT DEFINED NNETS_EXE)
message(FATAL_ERROR "NNETS_EXE not defined")
endif()

if(NOT DEFINED CONFIG_DIR)
message(FATAL_ERROR "CONFIG_DIR not defined")
endif()

if(NOT DEFINED WORK_DIR)
message(FATAL_ERROR "WORK_DIR not defined")
endif()

set(MODEL_V1 "${WORK_DIR}/test_retrain_v1.json")
set(MODEL_V2 "${WORK_DIR}/test_retrain_v2.json")
set(SIMPLE_CONFIG "${CONFIG_DIR}/simple.json")
set(EXTENDED_CONFIG "${CONFIG_DIR}/extended.json")

message(STATUS "=== Testing Retraining Mode ===")
message(STATUS "Executable: ${NNETS_EXE}")
message(STATUS "Simple config: ${SIMPLE_CONFIG}")
message(STATUS "Extended config: ${EXTENDED_CONFIG}")

# Step 1: Train initial model with simple config (yes/no)
message(STATUS "Step 1: Training initial model with simple config...")
execute_process(
COMMAND "${NNETS_EXE}" -c "${SIMPLE_CONFIG}" -s "${MODEL_V1}" -t
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE TRAIN1_RESULT
OUTPUT_VARIABLE TRAIN1_OUTPUT
ERROR_VARIABLE TRAIN1_ERROR
TIMEOUT 120
)

if(NOT TRAIN1_RESULT EQUAL 0)
message(FATAL_ERROR "Initial training failed with code ${TRAIN1_RESULT}:\nOutput: ${TRAIN1_OUTPUT}\nError: ${TRAIN1_ERROR}")
endif()
message(STATUS "Initial training completed successfully")

# Verify model file was created
if(NOT EXISTS "${MODEL_V1}")
message(FATAL_ERROR "Model file was not created: ${MODEL_V1}")
endif()
message(STATUS "Initial model saved: ${MODEL_V1}")

# Step 2: Verify the initial model using --verify
message(STATUS "Step 2: Verifying initial model accuracy...")
execute_process(
COMMAND "${NNETS_EXE}" -l "${MODEL_V1}" -c "${SIMPLE_CONFIG}" --verify
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE VERIFY1_RESULT
OUTPUT_VARIABLE VERIFY1_OUTPUT
ERROR_VARIABLE VERIFY1_ERROR
TIMEOUT 60
)

if(NOT VERIFY1_RESULT EQUAL 0)
message(FATAL_ERROR "Initial model verification failed with code ${VERIFY1_RESULT}:\nOutput: ${VERIFY1_OUTPUT}\nError: ${VERIFY1_ERROR}")
endif()

# Check that verification output contains accuracy info
string(FIND "${VERIFY1_OUTPUT}" "Accuracy:" ACCURACY_FOUND)
if(ACCURACY_FOUND EQUAL -1)
message(FATAL_ERROR "Verification output doesn't contain accuracy info:\n${VERIFY1_OUTPUT}")
endif()
message(STATUS "Initial model verification passed")
message(STATUS "Verification output:\n${VERIFY1_OUTPUT}")

# Step 3: Test inference with initial model
message(STATUS "Step 3: Testing inference with initial model...")
execute_process(
COMMAND "${NNETS_EXE}" -l "${MODEL_V1}" -i "yes"
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE INFER1_RESULT
OUTPUT_VARIABLE INFER1_OUTPUT
ERROR_VARIABLE INFER1_ERROR
TIMEOUT 30
)

if(NOT INFER1_RESULT EQUAL 0)
message(FATAL_ERROR "Initial inference failed with code ${INFER1_RESULT}:\nOutput: ${INFER1_OUTPUT}\nError: ${INFER1_ERROR}")
endif()
message(STATUS "Initial inference passed")

# Step 4: Retrain with extended config (adding more classes)
# Note: For simplicity, we create a new config that can be used for retraining
# The retrain mode will detect that some classes are already trained
message(STATUS "Step 4: Retraining model with extended config...")
execute_process(
COMMAND "${NNETS_EXE}" -r "${MODEL_V1}" -c "${EXTENDED_CONFIG}" -s "${MODEL_V2}" -t
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE RETRAIN_RESULT
OUTPUT_VARIABLE RETRAIN_OUTPUT
ERROR_VARIABLE RETRAIN_ERROR
TIMEOUT 300
)

# Note: Retrain might show warnings about config mismatch, which is expected
# since simple.json has different classes than extended.json
# The actual retraining will train the new classes
message(STATUS "Retrain output: ${RETRAIN_OUTPUT}")

# Verify retrained model file was created
if(NOT EXISTS "${MODEL_V2}")
message(FATAL_ERROR "Retrained model file was not created: ${MODEL_V2}")
endif()
message(STATUS "Retrained model saved: ${MODEL_V2}")

# Step 5: Test inference with retrained model for new classes
message(STATUS "Step 5: Testing inference with retrained model...")
execute_process(
COMMAND "${NNETS_EXE}" -l "${MODEL_V2}" -i "cat"
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE INFER2_RESULT
OUTPUT_VARIABLE INFER2_OUTPUT
ERROR_VARIABLE INFER2_ERROR
TIMEOUT 30
)

if(NOT INFER2_RESULT EQUAL 0)
message(FATAL_ERROR "Retrained inference failed with code ${INFER2_RESULT}:\nOutput: ${INFER2_OUTPUT}\nError: ${INFER2_ERROR}")
endif()

# Check that output contains new class
string(FIND "${INFER2_OUTPUT}" "cat" CAT_FOUND)
if(CAT_FOUND EQUAL -1)
message(FATAL_ERROR "Retrained model doesn't recognize 'cat' class:\n${INFER2_OUTPUT}")
endif()
message(STATUS "Retrained inference passed")

# Cleanup
file(REMOVE "${MODEL_V1}")
file(REMOVE "${MODEL_V2}")
message(STATUS "=== Retraining Test PASSED ===")
82 changes: 82 additions & 0 deletions cmake/test_verify.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# CMake script to test verification mode (--verify)
# This script trains a model and then verifies its accuracy

# Check required variables
if(NOT DEFINED NNETS_EXE)
message(FATAL_ERROR "NNETS_EXE not defined")
endif()

if(NOT DEFINED CONFIG_DIR)
message(FATAL_ERROR "CONFIG_DIR not defined")
endif()

if(NOT DEFINED WORK_DIR)
message(FATAL_ERROR "WORK_DIR not defined")
endif()

set(MODEL_FILE "${WORK_DIR}/test_verify_model.json")
set(CONFIG_FILE "${CONFIG_DIR}/simple.json")

message(STATUS "=== Testing Verification Mode ===")
message(STATUS "Executable: ${NNETS_EXE}")
message(STATUS "Config: ${CONFIG_FILE}")
message(STATUS "Model: ${MODEL_FILE}")

# Step 1: Train and save model
message(STATUS "Step 1: Training model...")
execute_process(
COMMAND "${NNETS_EXE}" -c "${CONFIG_FILE}" -s "${MODEL_FILE}" -t
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE TRAIN_RESULT
OUTPUT_VARIABLE TRAIN_OUTPUT
ERROR_VARIABLE TRAIN_ERROR
TIMEOUT 120
)

if(NOT TRAIN_RESULT EQUAL 0)
message(FATAL_ERROR "Training failed with code ${TRAIN_RESULT}:\nOutput: ${TRAIN_OUTPUT}\nError: ${TRAIN_ERROR}")
endif()
message(STATUS "Training completed successfully")

# Verify model file was created
if(NOT EXISTS "${MODEL_FILE}")
message(FATAL_ERROR "Model file was not created: ${MODEL_FILE}")
endif()

# Step 2: Verify model accuracy using --verify
message(STATUS "Step 2: Verifying model accuracy...")
execute_process(
COMMAND "${NNETS_EXE}" -l "${MODEL_FILE}" -c "${CONFIG_FILE}" --verify
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE VERIFY_RESULT
OUTPUT_VARIABLE VERIFY_OUTPUT
ERROR_VARIABLE VERIFY_ERROR
TIMEOUT 60
)

if(NOT VERIFY_RESULT EQUAL 0)
message(FATAL_ERROR "Verification failed with code ${VERIFY_RESULT}:\nOutput: ${VERIFY_OUTPUT}\nError: ${VERIFY_ERROR}")
endif()

# Check that verification output contains expected elements
string(FIND "${VERIFY_OUTPUT}" "Verifying model accuracy" VERIFY_HEADER)
if(VERIFY_HEADER EQUAL -1)
message(FATAL_ERROR "Verification output missing header:\n${VERIFY_OUTPUT}")
endif()

string(FIND "${VERIFY_OUTPUT}" "Accuracy:" ACCURACY_FOUND)
if(ACCURACY_FOUND EQUAL -1)
message(FATAL_ERROR "Verification output missing accuracy:\n${VERIFY_OUTPUT}")
endif()

string(FIND "${VERIFY_OUTPUT}" "Passed:" PASSED_FOUND)
if(PASSED_FOUND EQUAL -1)
message(FATAL_ERROR "Verification output missing passed count:\n${VERIFY_OUTPUT}")
endif()

message(STATUS "Verification output:\n${VERIFY_OUTPUT}")
message(STATUS "Verification mode test passed")

# Cleanup
file(REMOVE "${MODEL_FILE}")
message(STATUS "=== Verification Mode Test PASSED ===")
Loading