-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_docker_functionality.py
More file actions
146 lines (115 loc) · 5.09 KB
/
test_docker_functionality.py
File metadata and controls
146 lines (115 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python3
"""
Smoke tests for the Docker deployment.

Checks that:
- the application's modules import;
- a synthetic dataset can be built;
- a small VAE trains on that dataset;
- materials can be generated from the trained model.

Run the file directly. The process exit status is 0 when every test passes
and 1 otherwise.
"""
import sys
import os
# Put /app first on the import path so the application modules
# (synthesizability_predictor, centralized_field_mapping, gradio_app) resolve.
# NOTE(review): /app is presumably the code directory inside the image;
# confirm against the Dockerfile.
sys.path.insert(0, '/app')
def test_imports():
    """Check that every module the deployment depends on can be imported.

    Returns:
        True when all imports succeed. False on the first failure, after
        printing the exception message.
    """
    try:
        # Application modules.
        from synthesizability_predictor import (
            SynthesizabilityClassifier,
            LLMSynthesizabilityPredictor,
        )
        from centralized_field_mapping import apply_field_mapping_to_generation
        from gradio_app import (
            create_synthetic_dataset,
            generate_materials,
            train_vae_model,
        )
        # Third-party libraries (order kept so the first failure reported is unchanged).
        from sklearn.preprocessing import StandardScaler
        import torch
        import pandas as pd
        import numpy as np

        print("✅ All imports successful")
        return True
    except Exception as exc:
        print(f"❌ Import error: {exc}")
        return False
def test_synthetic_dataset():
    """Build a 100-row synthetic dataset and confirm it has the expected columns.

    Returns:
        True if ``create_synthetic_dataset(100)`` succeeds and the result has
        every required column. False otherwise, after printing the reason.
    """
    try:
        from gradio_app import create_synthetic_dataset

        frame = create_synthetic_dataset(100)

        expected = (
            'formation_energy_per_atom',
            'energy_above_hull',
            'band_gap',
            'nsites',
            'density',
            'electronegativity',
            'atomic_radius',
        )
        absent = []
        for column in expected:
            if column not in frame.columns:
                absent.append(column)

        if absent:
            print(f"❌ Missing required fields in synthetic dataset: {absent}")
            return False

        print(f"✅ Synthetic dataset created with {len(frame)} materials")
        print(f"✅ All required fields present: {list(frame.columns)}")
        return True
    except Exception as exc:
        print(f"❌ Synthetic dataset test failed: {exc}")
        return False
def test_vae_training():
    """Train a small VAE on synthetic data as a smoke test.

    Steps:
    - build a 100-row synthetic dataset;
    - standardise six feature columns with ``StandardScaler``;
    - train for 10 epochs with a 5-dimensional latent space.

    Returns:
        True if training completes and returns a model. False otherwise,
        after printing the reason.
    """
    try:
        from gradio_app import create_synthetic_dataset, train_vae_model
        from sklearn.preprocessing import StandardScaler

        # Create dataset and select the training features.
        dataset = create_synthetic_dataset(100)
        feature_cols = [
            'composition_1',
            'composition_2',
            'formation_energy_per_atom',
            'density',
            'electronegativity',
            'atomic_radius',
        ]
        features = dataset[feature_cols].values

        # Scale features to zero mean / unit variance before training.
        features_scaled = StandardScaler().fit_transform(features)

        vae_model = train_vae_model(features_scaled, latent_dim=5, epochs=10)
        # Inspect the result. Otherwise a trainer that silently returned
        # None would still be counted as a pass.
        if vae_model is None:
            print("❌ VAE training returned no model")
            return False

        print("✅ VAE training successful")
        return True
    except Exception as e:
        print(f"❌ VAE training test failed: {e}")
        return False
def test_material_generation():
    """Train a small VAE and check the materials generated from it.

    Steps:
    - build a 100-row synthetic dataset and standardise six feature columns;
    - train a VAE (latent_dim=5, 10 epochs);
    - request 10 samples from ``generate_materials``.

    Returns:
        True if generation succeeds, yields at least one row, and the result
        has every required column. False otherwise, after printing the reason.
    """
    try:
        from gradio_app import create_synthetic_dataset, train_vae_model, generate_materials
        from sklearn.preprocessing import StandardScaler

        # Create dataset and train VAE.
        dataset = create_synthetic_dataset(100)
        feature_cols = [
            'composition_1',
            'composition_2',
            'formation_energy_per_atom',
            'density',
            'electronegativity',
            'atomic_radius',
        ]
        features = dataset[feature_cols].values
        scaler = StandardScaler()
        features_scaled = scaler.fit_transform(features)
        vae_model = train_vae_model(features_scaled, latent_dim=5, epochs=10)

        # The fitted scaler is passed along. Presumably this maps generated
        # samples back to original feature units; confirm in gradio_app.
        generated_df = generate_materials(vae_model, scaler, num_samples=10)

        # Check required fields.
        required_fields = [
            'formation_energy_per_atom',
            'energy_above_hull',
            'band_gap',
            'nsites',
            'density',
            'electronegativity',
            'atomic_radius',
        ]
        missing_fields = [field for field in required_fields if field not in generated_df.columns]
        if missing_fields:
            print(f"❌ Missing required fields in generated materials: {missing_fields}")
            return False

        # An empty result has the right columns but no materials. Without
        # this check it would be reported as a success.
        if len(generated_df) == 0:
            print("❌ Material generation produced no materials")
            return False

        print(f"✅ Material generation successful: {len(generated_df)} materials generated")
        print("✅ All required fields present in generated materials")
        return True
    except Exception as e:
        print(f"❌ Material generation test failed: {e}")
        return False
def main():
    """Run every smoke test in order and print a summary.

    Returns:
        Process exit status: 0 when every test passed, 1 otherwise.
    """
    separator = "=" * 60
    print("🧪 Testing Docker deployment functionality...")
    print(separator)

    suite = (
        ("Import Test", test_imports),
        ("Synthetic Dataset Test", test_synthetic_dataset),
        ("VAE Training Test", test_vae_training),
        ("Material Generation Test", test_material_generation),
    )

    succeeded = 0
    for label, check in suite:
        print(f"\n📋 {label}")
        print("-" * 40)
        ok = check()
        if not ok:
            print(f"❌ {label} failed")
            continue
        succeeded += 1

    print("\n" + separator)
    print(f"📊 Test Results: {succeeded}/{len(suite)} tests passed")

    if succeeded != len(suite):
        print("⚠️ Some tests failed. Please check the errors above.")
        return 1
    print("🎉 All tests passed! Docker deployment is working correctly.")
    return 0
if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status.
    raise SystemExit(main())