Skip to content

vectorizing Tree.sample()#1

Open
chanind wants to merge 2 commits intonoanabeshima:mainfrom
chanind:vectorized-tree-sampling
Open

vectorizing Tree.sample()#1
chanind wants to merge 2 commits intonoanabeshima:mainfrom
chanind:vectorized-tree-sampling

Conversation

@chanind
Copy link
Copy Markdown

@chanind chanind commented Dec 18, 2024

I really like your hierarchical feature setup via the Tree class, but found it ran a bit slow due to using Python loops and lists. This PR vectorizes more of the Tree.sample() function so it runs a lot faster. On my laptop, the original implementation takes 7 hours to go through 100 Million samples in batches of 10k, but this vectorized implementation runs 100 Million samples in 6 minutes (70x speedup). It's likely possible to get this to go faster with more effort, but figured this is already a good starting place.

This also changes the implementation so Tree.sample() always requires a batch_size param, and always returns a tensor, but I figure this is probably what's wanted in practice anyway. I also added some test coverage for the changes in Pytest, asserting that this implementation matches what's specified in the JSON file.

No worries if this doesn't fit into your vision of the project, but just figured I'd share this implementation in case it's helpful to others!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant