Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ continue.md
.uv/
.analysis_archive/

# Claude Code
.claude/

# Rust
med_core_rs/target/
med_core_rs/Cargo.lock
Expand Down
1 change: 1 addition & 0 deletions .nvmrc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
22
56 changes: 28 additions & 28 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = 'MedFusion'
copyright = '2024-2026, MedFusion Team'
author = 'MedFusion Team'
release = '0.2.0'
project = "MedFusion"
copyright = "2024-2026, MedFusion Team"
author = "MedFusion Team"
release = "0.2.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.githubpages',
'myst_parser',
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.coverage",
"sphinx.ext.mathjax",
"sphinx.ext.githubpages",
"myst_parser",
]

# Napoleon settings
Expand All @@ -51,14 +51,14 @@

# Autodoc settings
autodoc_default_options = {
'members': True,
'member-order': 'bysource',
'special-members': '__init__',
'undoc-members': True,
'exclude-members': '__weakref__'
"members": True,
"member-order": "bysource",
"special-members": "__init__",
"undoc-members": True,
"exclude-members": "__weakref__",
}
autodoc_typehints = 'description'
autodoc_typehints_description_target = 'documented'
autodoc_typehints = "description"
autodoc_typehints_description_target = "documented"

# Autosummary settings
autosummary_generate = True
Expand All @@ -80,14 +80,14 @@
"tasklist",
]

templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = 'furo'
html_static_path = ['_static']
html_theme = "furo"
html_static_path = ["_static"]

html_theme_options = {
"light_css_variables": {
Expand All @@ -112,9 +112,9 @@

# Intersphinx mapping
intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
'torch': ('https://pytorch.org/docs/stable/', None),
'numpy': ('https://numpy.org/doc/stable/', None),
"python": ("https://docs.python.org/3", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
}

# Todo extension
Expand Down
9 changes: 5 additions & 4 deletions examples/advanced_attention_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def demo_factory_function():
print(f" • {attn_type:12s}: {x.shape} -> {out.shape}")

print("\n使用示例:")
code = '''
code = """
from med_core.attention_supervision import create_attention_module

# 创建注意力模块
Expand All @@ -202,7 +202,7 @@ def forward(self, x):
x = self.conv(x)
x = self.attention(x) # 应用注意力
return x
'''
"""
print(code)


Expand Down Expand Up @@ -267,7 +267,7 @@ def demo_integration_example():
print("=" * 60)

print("\n完整的模型集成示例:")
code = '''
code = """
import torch.nn as nn
from med_core.attention_supervision import (
SEAttention,
Expand Down Expand Up @@ -349,7 +349,7 @@ def compute_loss(self, x, y):
loss, loss_dict = model.compute_loss(x, y)
print(f"Total loss: {loss.item():.4f}")
print(f"Loss components: {loss_dict}")
'''
"""
print(code)


Expand Down Expand Up @@ -403,6 +403,7 @@ def main():
except Exception as e:
print(f"\n❌ 错误: {e}")
import traceback

traceback.print_exc()


Expand Down
13 changes: 7 additions & 6 deletions examples/benchmark_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def fast_function():
slow_function()
slow_time = time.time() - start
print(f" 耗时: {slow_time:.3f}s")
print(f" 吞吐量: {100/slow_time:.1f} ops/s")
print(f" 吞吐量: {100 / slow_time:.1f} ops/s")

# 测试快速实现
print("\n2. 测试快速实现:")
Expand All @@ -44,7 +44,7 @@ def fast_function():
fast_function()
fast_time = time.time() - start
print(f" 耗时: {fast_time:.3f}s")
print(f" 吞吐量: {100/fast_time:.1f} ops/s")
print(f" 吞吐量: {100 / fast_time:.1f} ops/s")

# 比较
speedup = slow_time / fast_time
Expand All @@ -65,7 +65,7 @@ def demo_benchmark_suite():
print(" • 检测性能回归")

print("\n使用示例:")
code = '''
code = """
from med_core.utils.benchmark import BenchmarkSuite, PerformanceBenchmark

# 1. 创建测试套件
Expand All @@ -86,7 +86,7 @@ def test_model_inference():

# 5. 与基线比较
suite.compare_with("baseline.json")
'''
"""
print(code)


Expand Down Expand Up @@ -204,7 +204,7 @@ def demo_ci_integration():
print("=" * 60)

print("\nGitHub Actions 示例:")
yaml = '''
yaml = """
name: Performance Benchmarks

on: [push, pull_request]
Expand Down Expand Up @@ -239,7 +239,7 @@ def demo_ci_integration():
with:
name: benchmark-results
path: benchmarks/
'''
"""
print(yaml)


Expand Down Expand Up @@ -289,6 +289,7 @@ def main():
except Exception as e:
print(f"\n❌ 错误: {e}")
import traceback

traceback.print_exc()


Expand Down
5 changes: 3 additions & 2 deletions examples/cache_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def demo_usage_example():
print("实际使用示例")
print("=" * 60)

code = '''
code = """
# 1. 创建原始数据集
from med_core.datasets import MedicalDataset

Expand Down Expand Up @@ -156,7 +156,7 @@ def demo_usage_example():
if hasattr(cached_dataset, 'get_cache_stats'):
stats = cached_dataset.get_cache_stats()
print(f"缓存命中率: {stats['hit_rate']:.2%}")
'''
"""

print("\n代码示例:")
print(code)
Expand Down Expand Up @@ -193,6 +193,7 @@ def main():
except Exception as e:
print(f"\n错误: {e}")
import traceback

traceback.print_exc()


Expand Down
13 changes: 7 additions & 6 deletions examples/cache_demo_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def demo_lru_cache():

# 获取数据
print("\n2. 从缓存获取数据:")
result1 = cache.get('image_001')
result1 = cache.get("image_001")
print(f" image_001: {result1} (命中)")
result2 = cache.get('image_002')
result2 = cache.get("image_002")
print(f" image_002: {result2} (命中)")
result3 = cache.get('image_999')
result3 = cache.get("image_999")
print(f" image_999: {result3} (未命中)")

# 添加新数据(触发淘汰)
Expand All @@ -84,7 +84,7 @@ def demo_lru_cache():
print("\n4. 缓存统计:")
stats = cache.get_stats()
for key, value in stats.items():
if key == 'hit_rate':
if key == "hit_rate":
print(f" {key}: {value:.2%}")
else:
print(f" {key}: {value}")
Expand Down Expand Up @@ -126,7 +126,7 @@ def slow_load_data(idx):
if time_with_cache > 0:
speedup = time_no_cache / time_with_cache
print(f"\n3. 加速比: {speedup:.1f}x")
print(f" 性能提升: {(1 - time_with_cache/time_no_cache) * 100:.1f}%")
print(f" 性能提升: {(1 - time_with_cache / time_no_cache) * 100:.1f}%")


def demo_access_patterns():
Expand Down Expand Up @@ -185,7 +185,7 @@ def demo_cache_size_impact():
cache.get(idx)

stats = cache.get_stats()
hit_rate = stats['hit_rate']
hit_rate = stats["hit_rate"]

if hit_rate < 0.3:
desc = "太小,效果差"
Expand Down Expand Up @@ -264,6 +264,7 @@ def main():
except Exception as e:
print(f"\n❌ 错误: {e}")
import traceback

traceback.print_exc()


Expand Down
67 changes: 38 additions & 29 deletions examples/distributed_training_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def train_epoch(model, dataloader, criterion, optimizer, device, rank):
total += target.size(0)

if batch_idx % 10 == 0:
print(f"Rank {rank}: Batch {batch_idx}/{len(dataloader)}, "
f"Loss: {loss.item():.4f}")
print(
f"Rank {rank}: Batch {batch_idx}/{len(dataloader)}, "
f"Loss: {loss.item():.4f}"
)

avg_loss = total_loss / len(dataloader)
accuracy = 100.0 * correct / total
Expand Down Expand Up @@ -135,8 +137,8 @@ def train_ddp(args):
optimizer,
epoch + 1,
f"outputs/checkpoint_ddp_epoch_{epoch + 1}.pt",
loss=avg_metrics['loss'].item(),
accuracy=avg_metrics['accuracy'].item(),
loss=avg_metrics["loss"].item(),
accuracy=avg_metrics["accuracy"].item(),
)

# 清理
Expand Down Expand Up @@ -219,8 +221,8 @@ def train_fsdp(args):
optimizer,
epoch + 1,
f"outputs/checkpoint_fsdp_epoch_{epoch + 1}.pt",
loss=avg_metrics['loss'].item(),
accuracy=avg_metrics['accuracy'].item(),
loss=avg_metrics["loss"].item(),
accuracy=avg_metrics["accuracy"].item(),
)

# 清理
Expand Down Expand Up @@ -273,33 +275,40 @@ def main():
parser = argparse.ArgumentParser(description="分布式训练示例")

# 训练参数
parser.add_argument("--strategy", type=str, default="ddp",
choices=["ddp", "fsdp"],
help="分布式策略")
parser.add_argument("--backend", type=str, default="nccl",
choices=["nccl", "gloo", "mpi"],
help="分布式后端")
parser.add_argument("--epochs", type=int, default=5,
help="训练轮数")
parser.add_argument("--batch_size", type=int, default=32,
help="批次大小")
parser.add_argument("--lr", type=float, default=0.001,
help="学习率")
parser.add_argument("--num_samples", type=int, default=1000,
help="样本数量")
parser.add_argument("--save_interval", type=int, default=2,
help="保存间隔")
parser.add_argument(
"--strategy",
type=str,
default="ddp",
choices=["ddp", "fsdp"],
help="分布式策略",
)
parser.add_argument(
"--backend",
type=str,
default="nccl",
choices=["nccl", "gloo", "mpi"],
help="分布式后端",
)
parser.add_argument("--epochs", type=int, default=5, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=32, help="批次大小")
parser.add_argument("--lr", type=float, default=0.001, help="学习率")
parser.add_argument("--num_samples", type=int, default=1000, help="样本数量")
parser.add_argument("--save_interval", type=int, default=2, help="保存间隔")

# FSDP 参数
parser.add_argument("--sharding_strategy", type=str, default="FULL_SHARD",
choices=["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"],
help="FSDP 分片策略")
parser.add_argument("--min_num_params", type=int, default=1000,
help="自动包装的最小参数数量")
parser.add_argument(
"--sharding_strategy",
type=str,
default="FULL_SHARD",
choices=["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"],
help="FSDP 分片策略",
)
parser.add_argument(
"--min_num_params", type=int, default=1000, help="自动包装的最小参数数量"
)

# 其他
parser.add_argument("--demo", action="store_true",
help="显示使用示例")
parser.add_argument("--demo", action="store_true", help="显示使用示例")

args = parser.parse_args()

Expand Down
Loading
Loading