-
Notifications
You must be signed in to change notification settings - Fork 39
Open
Description
Executing the following `main.py`:
from datetime import datetime
from pathlib import Path

import lightning.pytorch
import torch
from lightning.pytorch.cli import LightningCLI

from datamodules.s2geo_dataset import S2GeoDataModule
from loss import SatCLIPLoss
from model import SatCLIP

torch.set_float32_matmul_precision('high')
class SatCLIPLightningModule(lightning.pytorch.LightningModule):
    """LightningModule wrapper around SatCLIP.

    Owns the SatCLIP model and its contrastive loss, defines the shared
    train/validation step, and configures an AdamW optimizer that exempts
    gains, biases and the logit scale from weight decay.
    """

    def __init__(
        self,
        embed_dim=512,
        image_resolution=256,
        vision_layers=12,
        vision_width=768,
        vision_patch_size=32,
        in_channels=4,
        le_type="grid",
        pe_type="siren",
        frequency_num=16,
        max_radius=260,
        min_radius=1,
        legendre_polys=16,
        harmonics_calculation="analytic",
        sh_embedding_dims=32,
        learning_rate=1e-4,
        weight_decay=0.01,
        num_hidden_layers=2,
        capacity=256,
    ) -> None:
        super().__init__()

        self.model = SatCLIP(
            embed_dim=embed_dim,
            image_resolution=image_resolution,
            vision_layers=vision_layers,
            vision_width=vision_width,
            vision_patch_size=vision_patch_size,
            in_channels=in_channels,
            le_type=le_type,
            pe_type=pe_type,
            frequency_num=frequency_num,
            max_radius=max_radius,
            min_radius=min_radius,
            legendre_polys=legendre_polys,
            harmonics_calculation=harmonics_calculation,
            sh_embedding_dims=sh_embedding_dims,
            num_hidden_layers=num_hidden_layers,
            capacity=capacity,
        )

        self.loss_fun = SatCLIPLoss()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        # Record all __init__ args so checkpoints/configs can reproduce the run.
        self.save_hyperparameters()

    def common_step(self, batch, batch_idx):
        """Compute the contrastive loss for one batch (shared by train/val)."""
        images = batch["image"]
        t_points = batch["point"].float()
        logits_per_image, logits_per_coord = self.model(images, t_points)
        return self.loss_fun(logits_per_image, logits_per_coord)

    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Build AdamW with weight decay disabled for gain/bias-like params."""

        # PEP 8 (E731): use a named predicate instead of lambdas bound to names.
        def _exclude_from_decay(name, param):
            # 1-D tensors (biases, norm gains) and the CLIP logit scale
            # should not be weight-decayed.
            return (
                param.ndim < 2
                or "bn" in name
                or "ln" in name
                or "bias" in name
                or "logit_scale" in name
            )

        named_parameters = list(self.model.named_parameters())
        gain_or_bias_params = [
            p
            for n, p in named_parameters
            if _exclude_from_decay(n, p) and p.requires_grad
        ]
        rest_params = [
            p
            for n, p in named_parameters
            if not _exclude_from_decay(n, p) and p.requires_grad
        ]

        optimizer = torch.optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.0},
                {
                    "params": rest_params,
                    "weight_decay": self.weight_decay,
                },  # specify in configs/default.yaml
            ],
            lr=self.learning_rate,  # specify in configs/default.yaml
        )
        return optimizer
class MyLightningCLI(LightningCLI):
    """LightningCLI subclass that registers an extra --watchmodel CLI flag."""

    def add_arguments_to_parser(self, parser):
        # Boolean flag; defaults to False unless passed on the command line.
        # NOTE(review): the flag is parsed but not consumed anywhere in the
        # visible code — confirm where it is read (e.g. for wandb model watching).
        parser.add_argument("--watchmodel", action="store_true")
def cli_main(default_config_filename="./configs/default.yaml"):
    """Configure the SatCLIP LightningCLI and launch training.

    Args:
        default_config_filename: Path to the YAML config providing model and
            trainer defaults (learning rate, weight decay, etc.).
    """
    # The effective config is saved next to the defaults, e.g.
    # configs/default-latest.yaml, overwriting the previous run's copy.
    save_config_fn = default_config_filename.replace(".yaml", "-latest.yaml")

    # modify configs/default.yaml for learning rate etc.
    cli = MyLightningCLI(
        model_class=SatCLIPLightningModule,
        datamodule_class=S2GeoDataModule,
        save_config_kwargs=dict(
            config_filename=save_config_fn,
            overwrite=True,
        ),
        trainer_defaults={
            "accumulate_grad_batches": 16,
            "log_every_n_steps": 10,
        },
        parser_kwargs={"default_config_files": [default_config_filename]},
        seed_everything_default=0,
        run=False,  # we invoke trainer.fit ourselves below
    )

    # Tag the run with a timestamp so experiments are distinguishable.
    ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    run_name = f"SatCLIP_S2_{ts}"
    if cli.trainer.logger is not None:
        cli.trainer.logger.experiment.name = run_name
        # This seems to be necessary to force logging of datamodule hyperparams.
        # Kept inside the guard: calling it with no logger would raise.
        cli.trainer.logger.log_hyperparams(cli.datamodule.hparams)

    # Create folder to log configs.
    # NOTE: Lightning does not handle config paths with subfolders.
    dirname_cfg = Path(default_config_filename).parent
    dir_log_cfg = Path(cli.trainer.log_dir) / dirname_cfg
    dir_log_cfg.mkdir(parents=True, exist_ok=True)

    cli.trainer.fit(
        model=cli.model,
        datamodule=cli.datamodule,
    )
if __name__ == "__main__":
    config_fn = "./configs/default.yaml"
    # TF32 matmuls are already enabled globally at import time via
    # torch.set_float32_matmul_precision("high"), so no per-GPU toggling
    # (the old commented-out A100 block) is needed here.
    cli_main(config_fn)
I get this error:
(.satclip) (base) root@WorldMap:~/VectorEmbeddingsFromGeoCoordinates/satclip/satclip# python main.py
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/main.py", line 162
torch.backends.cpu
IndentationError: unexpected indent
(.satclip) (base) root@WorldMap:~/VectorEmbeddingsFromGeoCoordinates/satclip/satclip# nano main.py
(.satclip) (base) root@WorldMap:~/VectorEmbeddingsFromGeoCoordinates/satclip/satclip# python main.py
2025-05-20 10:56:04.598684: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1747731364.714080 74468 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747731364.756910 74468 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747731364.919538 74468 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747731364.919665 74468 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747731364.919674 74468 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747731364.919680 74468 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
2025-05-20 10:56:04.949589: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Seed set to 0
using vision transformer
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
No dataset found. To download, please follow instructions on: https://github.com/microsoft/satclip
/data/s2/index.csv missing
Traceback (most recent call last):
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/main.py", line 163, in <module>
cli_main(config_fn)
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/main.py", line 146, in cli_main
cli.trainer.fit(
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
call._call_and_handle_interrupt(
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 974, in _run
call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 107, in _call_setup_hook
_call_lightning_datamodule_hook(trainer, "setup", stage=fn)
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/.satclip/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 198, in _call_lightning_datamodule_hook
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/datamodules/s2geo_dataset.py", line 52, in setup
dataset = S2Geo(root=self.data_dir, transform=self.train_transform, mode=self.mode)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/VectorEmbeddingsFromGeoCoordinates/satclip/satclip/datamodules/s2geo_dataset.py", line 111, in __init__
raise RuntimeError("Dataset not found or corrupted.")
RuntimeError: Dataset not found or corrupted.
Metadata
Metadata
Assignees
Labels
No labels