forked from lzzcd001/MeshDiffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_diffusion.py
More file actions
30 lines (23 loc) · 869 Bytes
/
main_diffusion.py
File metadata and controls
30 lines (23 loc) · 869 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Training and evaluation"""
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import lib.diffusion.trainer as trainer
import lib.diffusion.evaler as evaler
import lib.diffusion.eval_neighbors as eval_gen
# Process-wide absl flag registry; values are parsed when app.run() starts.
FLAGS = flags.FLAGS

# --config: config file loaded through ml_collections.config_flags.
# lock_config=False is passed so the loaded config is not locked
# (see ml_collections config_flags docs for exact semantics).
config_flags.DEFINE_config_file(
    "config",
    None,
    "diffusion configs",
    lock_config=False,
)

# --mode: selects which pipeline main() dispatches to.
flags.DEFINE_enum(
    "mode",
    None,
    ["train", "uncond_gen", "cond_gen", "eval_gen"],
    "Running mode",
)

# Both flags default to None, so the user must supply each one explicitly.
flags.mark_flag_as_required("config")
flags.mark_flag_as_required("mode")
def main(argv):
    """Run the training or generation pipeline selected by ``--mode``.

    Invoked by ``app.run`` after absl has parsed the command line, so
    ``FLAGS.config`` (loaded from ``--config``) and ``FLAGS.mode`` are set.

    Args:
      argv: Command-line arguments left over after flag parsing; ``argv[0]``
        is the program name.

    Raises:
      app.UsageError: If positional arguments beyond the program name are
        given, or if ``--mode`` holds a value this dispatcher does not handle.
    """
    # Anything after the program name was not consumed as a flag. It is
    # almost certainly a typo (e.g. a flag written with spaces around "="),
    # so fail loudly instead of silently ignoring it.
    if len(argv) > 1:
        raise app.UsageError(
            "Too many command-line arguments: " + " ".join(argv[1:]))

    mode = FLAGS.mode
    if mode == 'train':
        trainer.train(FLAGS.config)
    elif mode == 'uncond_gen':
        evaler.uncond_gen(FLAGS.config)
    elif mode == 'cond_gen':
        evaler.cond_gen(FLAGS.config)
    elif mode == 'eval_gen':
        eval_gen.eval_gen(FLAGS.config)
    else:
        # Unreachable while --mode is a DEFINE_enum over exactly the values
        # above; guards against the enum and this dispatch drifting apart,
        # which would otherwise make the script exit silently doing nothing.
        raise app.UsageError(f"Unsupported --mode: {mode!r}")


if __name__ == "__main__":
    app.run(main)