-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcache_nodes.py
More file actions
197 lines (173 loc) · 6.25 KB
/
cache_nodes.py
File metadata and controls
197 lines (173 loc) · 6.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import comfy.model_patcher

# Prefer the commercial ``pruna_pro`` package; fall back to the open-source
# ``pruna`` package when it is unavailable. NOTE(review): if neither package
# is installed, ``SmashConfig``/``smash`` stay undefined and the caching
# nodes below will raise ``NameError`` when used — the prints only warn.
try:
    from pruna_pro import SmashConfig, smash
except ImportError:
    print("pruna_pro not installed, skipping")
    try:
        from pruna import SmashConfig, smash
    except ImportError:
        print("Neither pruna_pro nor pruna are installed, skipping")
class CacheModelMixin:
    """Shared plumbing for the Pruna caching nodes in this module.

    Subclasses build a cacher-specific hyperparameter dict and delegate to
    :meth:`_apply_common_caching`, which smashes the model's diffusion
    backbone and returns a patched clone.
    """

    def _clone_patcher(self, model):
        """Return a cloned ``ModelPatcher`` for *model*.

        Accepts either a ``comfy.model_patcher.ModelPatcher`` directly, or a
        model object (BaseModel) that carries one as its ``patcher``
        attribute.
        """
        if isinstance(model, comfy.model_patcher.ModelPatcher):
            return model.clone()
        # model is a BaseModel — clone the patcher it wraps.
        return model.patcher.clone()

    def _apply_common_caching(self, model, caching_method, hyperparams):
        """Apply *caching_method* to *model* and return the patched clone.

        Args:
            model: A ``ModelPatcher`` or BaseModel to patch.
            caching_method: Name of the Pruna cacher (e.g. ``"adaptive"``).
            hyperparams: Extra ``SmashConfig`` entries for that cacher.

        Returns:
            A cloned model patcher whose ``diffusion_model`` has been
            object-patched with the smashed backbone.

        Raises:
            ValueError: If the installed pruna package does not support
                *caching_method* (pro-only cachers on open-source pruna).
        """
        # Work on a clone so the caller's model is left untouched.
        model_patcher = self._clone_patcher(model)

        smash_config = SmashConfig()
        try:
            smash_config["cacher"] = caching_method
        except KeyError as err:
            # Open-source pruna rejects pro-only cachers with KeyError;
            # chain it so the original cause stays visible in tracebacks.
            raise ValueError(
                f"{caching_method} caching requires pruna_pro to be installed"
            ) from err

        # Merge the caching-specific hyperparameters into the smash config.
        for key, value in hyperparams.items():
            smash_config[key] = value
        smash_config._prepare_saving = False

        # Flag the backbone so pruna knows it originates from ComfyUI.
        model_patcher.model.diffusion_model.is_comfy = True

        # Smash the backbone and patch the smashed module back into the
        # clone via ComfyUI's object-patch mechanism.
        smashed_model = smash(model_patcher.model.diffusion_model, smash_config)
        model_patcher.add_object_patch(
            "diffusion_model",
            smashed_model.__getattribute__("_PrunaProModel__internal_model_ref"),
        )
        return model_patcher
# CacheModelAdaptive now simply supplies its specific config parameters
class CacheModelAdaptive(CacheModelMixin):
    """ComfyUI node that applies Pruna adaptive caching to a model."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_caching"
    CATEGORY = "Pruna"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the ComfyUI front end."""
        threshold_spec = {"default": 0.01, "step": 0.001, "min": 0.001, "max": 0.2}
        skip_spec = {"default": 4, "step": 1, "min": 1, "max": 5}
        mode_spec = {
            "default": "default",
            "options": ["default", "taylor", "ab", "bdf"],
        }
        compiler_spec = {
            "default": "torch_compile",
            "options": ["torch_compile", "stable_fast", "none"],
        }
        return {
            "required": {
                "model": ("MODEL",),
                "threshold": ("FLOAT", threshold_spec),
                "max_skip_steps": ("INT", skip_spec),
                "cache_mode": ("STRING", mode_spec),
                "compiler": ("STRING", compiler_spec),
            }
        }

    def apply_caching(self, model, threshold, max_skip_steps, cache_mode, compiler):
        """Smash *model* with adaptive caching and return it as a 1-tuple."""
        hyperparams = {
            "adaptive_threshold": threshold,
            "adaptive_max_skip_steps": max_skip_steps,
            "adaptive_cache_mode": cache_mode,
        }
        # "none" means: do not attach a compiler at all.
        if compiler != "none":
            hyperparams["compiler"] = compiler
        patched = self._apply_common_caching(
            model,
            caching_method="adaptive",
            hyperparams=hyperparams,
        )
        return (patched,)
# CacheModelPeriodic also supplies its own configuration parameters
class CacheModelPeriodic(CacheModelMixin):
    """ComfyUI node that applies Pruna periodic caching to a model."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_caching"
    CATEGORY = "Pruna"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the ComfyUI front end."""
        mode_spec = {
            "default": "default",
            "options": ["default", "taylor", "ab", "bdf"],
        }
        compiler_spec = {
            "default": "torch_compile",
            "options": ["torch_compile", "stable_fast", "none"],
        }
        return {
            "required": {
                "model": ("MODEL",),
                "cache_interval": ("INT", {"default": 2, "min": 1, "max": 7}),
                "start_step": ("INT", {"default": 2, "min": 0, "max": 10}),
                "cache_mode": ("STRING", mode_spec),
                "compiler": ("STRING", compiler_spec),
            }
        }

    def apply_caching(self, model, cache_interval, start_step, cache_mode, compiler):
        """Smash *model* with periodic caching and return it as a 1-tuple."""
        hyperparams = {
            "periodic_cache_interval": cache_interval,
            "periodic_start_step": start_step,
            "periodic_cache_mode": cache_mode,
        }
        # "none" means: do not attach a compiler at all.
        if compiler != "none":
            hyperparams["compiler"] = compiler
        patched = self._apply_common_caching(
            model,
            caching_method="periodic",
            hyperparams=hyperparams,
        )
        return (patched,)
class CacheModelAuto(CacheModelMixin):
    """ComfyUI node that applies Pruna auto caching to a model."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_caching"
    CATEGORY = "Pruna"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the ComfyUI front end."""
        mode_spec = {
            "default": "default",
            "options": ["default", "taylor", "ab", "bdf"],
        }
        compiler_spec = {
            "default": "torch_compile",
            "options": ["torch_compile", "stable_fast", "none"],
        }
        return {
            "required": {
                "model": ("MODEL",),
                "speed_factor": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0}),
                "cache_mode": ("STRING", mode_spec),
                "compiler": ("STRING", compiler_spec),
            }
        }

    # NOTE(review): the parameter order here differs from the sibling nodes
    # (compiler comes second). Presumably ComfyUI invokes FUNCTION with
    # keyword arguments so this is harmless — confirm before reordering.
    def apply_caching(self, model, compiler, speed_factor, cache_mode):
        """Smash *model* with auto caching and return it as a 1-tuple."""
        hyperparams = {
            "auto_speed_factor": speed_factor,
            "auto_cache_mode": cache_mode,
        }
        # "none" means: do not attach a compiler at all.
        if compiler != "none":
            hyperparams["compiler"] = compiler
        patched = self._apply_common_caching(
            model,
            caching_method="auto",
            hyperparams=hyperparams,
        )
        return (patched,)