Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/live_calculation.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ Note that _string_ inputs must have quotes when used, but property names and _ex
- *property_criterion* (string): either `'max'` or `'min'` to pick out either the progenitor with
the maximum or minimum value of the target property

* `match_reduce(s, calculation, reduction)`: finds all linked objects in a given target simulation
or timestep, performs a calculation on each of them, then reduces the result in a specified way.
Inputs:
- *s* (string): the name of the simulation or timestep to link to
- *calculation* (expression): the calculation to perform on each matching object
- *reduction* (string): either `'min'`, `'max'`, `'mean'` or `'sum'`. Specifies how to
reduce multiple results to a single per input object.
* Redirection operator `.`: finds a property in the linked object, e.g. `find_progenitor(SFR, 'max').mass` gets `mass`
at the time of maximum `SFR`.

Expand Down
59 changes: 58 additions & 1 deletion tangos/live_calculation/builtin_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import tangos
from tangos.util import consistent_collection

from ... import core
from ... import core, temporary_halolist
from ...core import extraction_patterns
from .. import (
BuiltinFunction,
Calculation,
FixedInput,
FixedNumericInput,
LiveProperty,
Expand All @@ -29,6 +30,62 @@ def match(source_halos, target):
return np.array(results, dtype=object)
match.set_input_options(0, provide_proxy=True, assert_class = FixedInput)


@BuiltinFunction.register
def match_reduce(source_halos: list[core.halo.Halo],
target_name: str,
calculation: LiveProperty,
reduction: str):
"""Get the reduction (sum, mean, min, max) of the specified calculation over all objects linked in the
specified target timestep or simulation"""
if len(source_halos) == 0:
return []

calculation = calculation.name

if not isinstance(calculation, Calculation):
calculation = tangos.live_calculation.parser.parse_property_name(calculation)

reduction_map = {'sum': np.sum,
'mean': np.mean,
'min': lambda input: np.min(input) if len(input)>0 else None,
'max': lambda input: np.max(input) if len(input)>0 else None}
if reduction not in reduction_map.keys():
raise ValueError(f"Unsupported reduction '{reduction}' in match_reduce. Supported reductions are sum, mean, min, max.")

from ... import relation_finding

target = tangos.get_item(target_name, core.Session.object_session(source_halos[0]))
strategy = relation_finding.MultiSourceMultiHopStrategy(source_halos, target, one_match_per_input=False)

# using strategy.temp_table doesn't seem to offer access to the sources of the halos, so we
# take a pass through python. This also offers the opportunity to use a throw-away session
# for the onwards calculation. There may be more efficient ways to do all this.

all_halos = strategy.all()
all_sources = strategy.sources()

with core.Session() as session:
with temporary_halolist.temporary_halolist_table(session, [h.id for h in all_halos]) as tt:
target_halos_supplemented = calculation.supplement_halo_query(
temporary_halolist.halo_query(tt)
)
values, = calculation.values(target_halos_supplemented.all())

values_per_halo = [[] for _ in source_halos]
for source, value in zip(all_sources, values):
values_per_halo[source].append(value)

reduction_func = reduction_map[reduction]
return [reduction_func(vals) for vals in values_per_halo]



match_reduce.set_input_options(2, assert_class=FixedInput, provide_proxy=True)
match_reduce.set_input_options(0, assert_class=FixedInput, provide_proxy=True)
match_reduce.set_input_options(1, assert_class=Calculation, provide_proxy=True)


@BuiltinFunction.register
def later(source_halos, num_steps):
timestep = consistent_collection.ConsistentCollection(source_halos).timestep.get_next(num_steps)
Expand Down
4 changes: 4 additions & 0 deletions tangos/live_calculation/builtin_functions/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def logical_or(halos, vals1, vals2):
def logical_not(halos, vals):
return arithmetic_unary_op(vals, np.logical_not)

@BuiltinFunction.register
def negate(halos, vals):
return arithmetic_unary_op(vals, np.negative)

@BuiltinFunction.register
def power(halos, vals1, vals2):
return arithmetic_binary_op(vals1, vals2, np.power)
Expand Down
4 changes: 2 additions & 2 deletions tangos/live_calculation/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def pack_args(for_function):
(">=", "greater_equal"),
("<=", "less_equal")]

UNARY_OPS = [("!", "logical_not")]
UNARY_OPS = [("!", "logical_not"),
("-", "negate")]

IN_OPS_PYPARSING = []
UNARY_OPS_PYPARSING = []
Expand Down Expand Up @@ -70,7 +71,6 @@ def generate_property_from_inop(opFunctionName, tokens):

redirection = pp.Forward().setParseAction(pack_args(Link))


element_identifier = pp.Literal("[").suppress()+numerical_value+pp.Literal("]").suppress();

multiple_properties = pp.Forward().setParseAction(pack_args(MultiCalculation))
Expand Down
5 changes: 5 additions & 0 deletions tests/test_live_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def test_nested_abs_at_function():
# n.b. for J_dm_enc
assert np.allclose(halo.calculate("abs(at(3.0,dummy_property_2))"), 15.0*np.sqrt(3))

def test_unary_minus_function():
halo = tangos.get_halo("sim/ts1/1")
assert np.allclose(halo.calculate("-dummy_property_3"), 2.5)
assert np.allclose(halo.calculate("--dummy_property_3"), -2.5)

def test_abcissa_passing_function():
"""In this example, the x-coordinates need to be successfully passed "through" the abs function for the
at function to return the correct result."""
Expand Down
85 changes: 85 additions & 0 deletions tests/test_live_calculation_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
from pytest import raises as assert_raises

import tangos as db
import tangos.testing as testing
import tangos.testing.simulation_generator


def setup_module():
testing.init_blank_db_for_testing()

generator = tangos.testing.simulation_generator.SimulationGeneratorForTests()
generator.add_timestep()
ts1_h1, ts1_h2, ts1_h3, ts1_h4, ts1_h5 = generator.add_objects_to_timestep(5)
generator.add_timestep()
ts2_h1, ts2_h2, ts2_h3 = generator.add_objects_to_timestep(3)


generator.link_last_halos_using_mapping({1: 1, 2: 3, 3: 2, 4: 3, 5: 2})

generator.add_timestep()
ts3_h1, = generator.add_objects_to_timestep(1)
generator.link_last_halos()

ts1_h1['val'] = 1.0
ts1_h2['val'] = 2.0
ts1_h3['val'] = 3.0
ts1_h4['val'] = 4.0
ts1_h5['val'] = 5.0

ts2_h1['val'] = 10.0
ts2_h2['val'] = 20.0
ts2_h3['val'] = 30.0

ts3_h1['val'] = 100.0

db.core.get_default_session().commit()

def teardown_module():
tangos.core.close_db()

def test_reduce_function():
results, = db.get_timestep('sim/ts2').calculate_all('match_reduce("sim/ts1", val * 2.0, "min")')
assert results[0] == 2.0
assert results[1] == 6.0
assert results[2] == 4.0

def test_reduce_min():
results, = db.get_timestep('sim/ts2').calculate_all('match_reduce("sim/ts1", val, "min")')
assert results[0] == 1.0
assert results[1] == 3.0
assert results[2] == 2.0

def test_reduce_max():
results, = db.get_timestep('sim/ts2').calculate_all('match_reduce("sim/ts1", val, "max")')
assert results[0] == 1.0
assert results[1] == 5.0
assert results[2] == 4.0

def test_reduce_mean():
results, = db.get_timestep('sim/ts2').calculate_all('match_reduce("sim/ts1", val, "mean")')
assert results[0] == 1.0
assert results[1] == 4.0
assert results[2] == 3.0

def test_reduce_sum():
results, = db.get_timestep('sim/ts2').calculate_all('match_reduce("sim/ts1", val, "sum")')
assert results[0] == 1.0
assert results[1] == 8.0
assert results[2] == 6.0

def test_unsupported_reduce():
with assert_raises(ValueError):
db.get_timestep('sim/ts2').calculate_all('match_reduce("sim/ts1", val, "unsupported")')

def test_reduce_no_matches():
hnum, results, = db.get_timestep('sim/ts2').calculate_all('halo_number()',
'match_reduce("sim/ts3", val, "max")')
assert results == [100.0]
assert hnum == [1]

hnum, results, = db.get_timestep('sim/ts2').calculate_all('halo_number()',
'match_reduce("sim/ts3", val, "sum")')
assert np.allclose(results, [100.0, 0.0, 0.0])
assert (hnum == [1, 2, 3]).all()
7 changes: 5 additions & 2 deletions tests/test_simulation_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ def test_handler_properties():
npt.assert_allclose(prop['approx_resolution_Msol'], 144411.17640)

def test_handler_properties_quicker_flag():
output_manager.quicker = True
prop = output_manager.get_properties()
try:
output_manager.quicker = True
prop = output_manager.get_properties()
finally:
output_manager.quicker = False
npt.assert_allclose(prop['approx_resolution_kpc'], 33.590757, rtol=1e-5)
npt.assert_allclose(prop['approx_resolution_Msol'], 2.412033e+10, rtol=1e-4)

Expand Down