
Leaky behaviour when running repeated simulations #13

@ianosd

Description

First: great package! It has been a big help in my work on my master's project.
I found what looks like a memory leak when calling solve_many repeatedly. Here's a minimal example:

import jax
import jax.numpy as jnp
import numpy.random
import psutil
import pychastic

def main():
    # 10000 initial conditions in 3-D
    initial_samples = numpy.random.random((10000, 3))

    def drift_fn(X):
        return -X

    def noise_fn(_):
        return jnp.eye(3)

    solver = pychastic.sde_solver.SDESolver(dt=0.01)
    problem = pychastic.sde_problem.SDEProblem(
        a=drift_fn, b=noise_fn,
        x0=initial_samples,
        tmax=10)

    for _ in range(10):
        solver.solve_many(problem, n_trajectories=None, progress_bar=True)
        # resident set size (RSS) of this process, in MB
        mem_usage_now = psutil.Process().memory_info().rss / 1024 ** 2
        # jax.clear_caches()
        print(f'Mem usage: {mem_usage_now} MB')

if __name__ == "__main__":
    main()
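To separate retained arrays from other memory (JAX's compilation caches, for example, hold compiled executables rather than arrays), a helper along the following lines could replace the bare RSS reading in the loop. This is a sketch that assumes a recent JAX providing jax.live_arrays() and that psutil is installed, as in the example. memory_snapshot and run_once are hypothetical names; run_once only stands in for the solve_many call.

```python
import gc

import jax
import jax.numpy as jnp
import psutil


def memory_snapshot():
    """Return (RSS in MB, number of live JAX arrays) after a garbage-collection pass."""
    gc.collect()  # collect unreachable reference cycles before measuring
    rss_mb = psutil.Process().memory_info().rss / 1024 ** 2
    # Counts arrays that are still referenced; compiled executables are not arrays.
    n_live = len(jax.live_arrays())
    return rss_mb, n_live


def run_once():
    """Hypothetical stand-in for solver.solve_many(...): allocates and discards an array."""
    x = jnp.ones((10000, 3))
    return float(jnp.sum(-x))


for i in range(3):
    run_once()
    rss_mb, n_live = memory_snapshot()
    print(f"iteration {i}: {rss_mb:.1f} MB RSS, {n_live} live JAX arrays")
```

If RSS keeps climbing while the live-array count stays flat, the growth is probably not in arrays the solver keeps alive. That is a heuristic, not a proof.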

The output is:

  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 545.68359375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 234.484375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 233.8984375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 252.140625 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 262.68359375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 262.11328125 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 267.51171875 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 275.12109375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 280.82421875 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 277.37890625 MB
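One hypothesis worth checking is that each solve_many call builds a new jitted function, so that JAX traces and compiles it again and its caches grow with every iteration. This is an assumption about pychastic's internals, not something verified here. The plain-JAX sketch below shows how to detect retracing: a Python side effect inside a jitted function runs only while JAX traces it. make_step is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def make_step():
    """Hypothetical helper: builds a brand-new jitted function object on every call."""
    def step(x):
        global trace_count
        trace_count += 1  # Python side effect: runs only while JAX traces `step`
        return -x
    return jax.jit(step)


x = jnp.ones(3)

# Case 1: a fresh jitted closure per call, so every call is traced (and compiled) again.
for _ in range(5):
    make_step()(x)
fresh_traces = trace_count

# Case 2: one jitted function reused, so only the first call traces.
trace_count = 0
step = make_step()
for _ in range(5):
    step(x)
reused_traces = trace_count

print(fresh_traces, reused_traces)
```

This should print 5 1. Adding the same kind of counter to drift_fn in the example above would show whether it is re-traced on every solve_many call.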

If you uncomment the jax.clear_caches() call, the output is:

  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 544.79296875 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 548.65234375 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 551.78515625 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 555.6796875 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 557.37109375 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 557.46875 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 557.22265625 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 558.96484375 MB
  0%|                                                                                           | 0/1000 [00:01<?, ?it/s]
Mem usage: 558.96875 MB
  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]
Mem usage: 559.33984375 MB
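For context on this second run: jax.clear_caches() drops JAX's compilation and tracing caches, so the next call to any jitted function is traced and compiled from scratch. The small sketch below, independent of pychastic, shows this with a trace counter; step is a hypothetical function.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # runs only while JAX traces `step`
    return -x


x = jnp.ones(3)
step(x)
step(x)
before_clear = trace_count  # the second call reused the cached trace

jax.clear_caches()
step(x)
after_clear = trace_count  # caches were dropped, so `step` was traced again

print(before_clear, after_clear)
```

So with clear_caches() enabled, each loop iteration pays the full compilation cost again.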

I should mention that I am running this on a CPU, with pychastic version 0.2.2.
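Because JIT cache behaviour can change between JAX releases, the JAX version and backend are also worth including in a report like this. A minimal sketch using the standard library's importlib.metadata and public JAX attributes; environment_report is a hypothetical helper name.

```python
import importlib.metadata

import jax


def environment_report():
    """Hypothetical helper: collect the package versions and backend relevant to this issue."""
    lines = []
    try:
        lines.append(f"pychastic: {importlib.metadata.version('pychastic')}")
    except importlib.metadata.PackageNotFoundError:
        lines.append("pychastic: not installed")
    lines.append(f"jax: {jax.__version__}")
    lines.append(f"default backend: {jax.default_backend()}")  # e.g. 'cpu'
    return "\n".join(lines)


print(environment_report())
```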

I might look into this in the coming days, but perhaps someone here can immediately see what the issue is, or what I am doing wrong.

Metadata

Assignees

No one assigned

Labels

No labels

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
