Skip to content

Add the functionality required by this test: `load_collection` & `collect` #633

@github-actions

Description

@github-actions

Add the functionality required by this test: `load_collection` & `collect`.

https://github.com/apache/incubator-wayang/blob/1e0f9e8166225176fe3022de5fbcce3dbcba96b9/python/src/pywy/tests/train_logistic_test.py#L24

#  limitations under the License.
#

import pytest

from pywy.dataquanta import WayangContext
from pywy.platforms.java import JavaPlugin
from pywy.platforms.spark import SparkPlugin

# TODO: add functionality required for this test load_collection & collect
@pytest.mark.skip(reason="no way of currently testing this, since we are missing implementations for load_collection & collect")
def test_train_and_predict():
    """Train a logistic-regression model on a tiny dataset and check its predictions.

    Loads four 2-D feature vectors and their labels via ``load_collection``,
    calls ``train_logistic_regression`` and ``predict`` on the training
    features, then uses ``collect`` to check that one prediction comes back
    per input row and that every prediction is either 0.0 or 1.0.

    Currently skipped: ``load_collection`` and ``collect`` are not implemented yet.
    """
    ctx = WayangContext().register({JavaPlugin, SparkPlugin})

    raw_features = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
    raw_labels = [1.0, 1.0, 0.0, 0.0]
    # A tuple rather than a set: membership must not raise if a prediction
    # comes back as an unhashable value (e.g. a list); it should fail the assert.
    valid_labels = (0.0, 1.0)

    features = ctx.load_collection(raw_features)
    labels = ctx.load_collection(raw_labels)

    model = features.train_logistic_regression(labels)
    predictions = model.predict(features)

    result = predictions.collect()
    print("Predictions:", result)

    # Compare with ``==``: ``is`` tests object identity, only works for small
    # ints by accident of CPython's int cache, and is a SyntaxWarning since 3.8.
    expected_len = len(raw_features)
    assert len(result) == expected_len, f"Expected len(result) to be {expected_len}, but got: {len(result)}"
    for pred in result:
        assert pred in valid_labels, f"Expected prediction to be in [0.0, 1.0], but got: {pred}"

f35294bf9a27cc63281f965c1bf55a42419e4b35

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions