From a640b1ecb644e8c85b1baa12da1bf2996de9ef45 Mon Sep 17 00:00:00 2001 From: Raunak Ramakrishnan Date: Mon, 8 Jan 2024 11:36:46 +0530 Subject: [PATCH] WIP --- wasp/src/main/java/rrampage/wasp/Runner.java | 18 ++-- .../java/rrampage/wasp/data/FunctionType.java | 43 ++++++++ .../rrampage/wasp/utils/FunctionUtils.java | 97 +++++++++++++++++++ 3 files changed, 152 insertions(+), 6 deletions(-) create mode 100644 wasp/src/main/java/rrampage/wasp/utils/FunctionUtils.java diff --git a/wasp/src/main/java/rrampage/wasp/Runner.java b/wasp/src/main/java/rrampage/wasp/Runner.java index a27562e..823fd3f 100644 --- a/wasp/src/main/java/rrampage/wasp/Runner.java +++ b/wasp/src/main/java/rrampage/wasp/Runner.java @@ -1,21 +1,27 @@ package rrampage.wasp; -import rrampage.wasp.data.FunctionType; -import rrampage.wasp.data.ValueType; import rrampage.wasp.parser.WasmParser; -import rrampage.wasp.utils.ImportUtils; +import rrampage.wasp.utils.FunctionUtils; import java.nio.file.Paths; import java.util.Map; +import java.util.function.IntUnaryOperator; public class Runner { - public static void main(String[] args) throws Exception { + public static void main(String[] args) throws Throwable { // TODO Load and run wasm file. How to handle imports? // (import "host" "print" (func $hprint (param i32) (result i32))) - String fileName = "./examples/add_two.wasm"; + String fileName = "./examples/call_indirect_example.wasm"; String path = Paths.get(fileName).toAbsolutePath().normalize().toString(); var parser = WasmParser.fromFile(path); var module = parser.parseModule(); - var machine = module.instantiate(Map.of("host", Map.of("print", ImportUtils.generateLoggerHandle(new FunctionType(new ValueType.NumType[]{ValueType.NumType.I32}, new ValueType.NumType[]{ValueType.NumType.I32}))))); + var lambda = (IntUnaryOperator) (i) -> i*3; + var mh = FunctionUtils.generateMethodHandle(lambda); + var mh2 = FunctionUtils.generateMethodHandleForLambda(lambda); + System.out.println(mh.type()); + System.out.println(mh2.type()); + System.out.println(mh2.invoke(9)); + var machine = module.instantiate(Map.of("env", Map.of("jstimes3", mh))); + machine.start(); } } diff --git a/wasp/src/main/java/rrampage/wasp/data/FunctionType.java b/wasp/src/main/java/rrampage/wasp/data/FunctionType.java index f72a827..13ab779 100644 --- a/wasp/src/main/java/rrampage/wasp/data/FunctionType.java +++ b/wasp/src/main/java/rrampage/wasp/data/FunctionType.java @@ -87,6 +87,49 @@ public static FunctionType getBlockType(int blockType) { }; } + public static ValueType getDataTypeFromClass(Class clazz) { + if (clazz == void.class || clazz == Void.class) { + return null; + } + if (clazz == int.class || clazz == Integer.class || + clazz == byte.class || clazz == Byte.class || + clazz == short.class || clazz == Short.class || + clazz == char.class || clazz == Character.class || + clazz == boolean.class || clazz == Boolean.class + ) { + return I32; + } + if (clazz == float.class || clazz == Float.class) { + return F32; + } + if (clazz == long.class || clazz == Long.class) { + return I64; + } + if (clazz == double.class || clazz == Double.class) { + return F64; + } + // TODO : Check for Wrapper classes as well + if (clazz.isArray()) { + // we can't translate an Object[] or primitive [] to correct number of elements + var c = clazz.arrayType(); + } + // TODO : RefType , VecType + return null; + } + + public static FunctionType getFunctionTypeFromMethodType(MethodType mt) { + var rc = mt.returnType(); + ValueType rt = getDataTypeFromClass(rc); + ValueType[] params = new ValueType[mt.parameterCount()]; + var rps = mt.parameterArray(); + for (int i = 0; i < params.length; i++) { + params[i] = getDataTypeFromClass(rps[i]); + } + return new FunctionType(params, new ValueType[]{rt}); + } + + // Constants for less boilerplate when creating function types + public static final FunctionType VOID = new FunctionType(null, null); public static final FunctionType I32_RETURN = new FunctionType(null, new NumType[]{I32}); public static final FunctionType I64_RETURN = new FunctionType(null, new NumType[]{I64}); diff --git a/wasp/src/main/java/rrampage/wasp/utils/FunctionUtils.java b/wasp/src/main/java/rrampage/wasp/utils/FunctionUtils.java new file mode 100644 index 0000000..1e714eb --- /dev/null +++ b/wasp/src/main/java/rrampage/wasp/utils/FunctionUtils.java @@ -0,0 +1,97 @@ +package rrampage.wasp.utils; + +import java.lang.invoke.*; +import java.lang.reflect.Method; +import java.util.function.*; + +public class FunctionUtils { + private BiConsumer createVoidHandlerLambda(Object bean, Method method) throws Throwable { + MethodHandles.Lookup caller = MethodHandles.lookup(); + CallSite site = LambdaMetafactory.metafactory(caller, + "accept", + MethodType.methodType(BiConsumer.class), + MethodType.methodType(void.class, Object.class, Object.class), + caller.findVirtual(bean.getClass(), method.getName(), + MethodType.methodType(void.class, method.getParameterTypes()[0])), + MethodType.methodType(void.class, bean.getClass(), method.getParameterTypes()[0])); + MethodHandle factory = site.getTarget(); + BiConsumer listenerMethod = (BiConsumer) factory.invoke(); + return listenerMethod; + } + + public static F createWrapper(final MethodHandles.Lookup lookup, + final MethodHandle original, + final String lambdaName, + final Class interfaceType, + final Method wrapperMethod) throws Exception{ + final CallSite site = LambdaMetafactory.metafactory(lookup, // MethodHandles.Lookup + lambdaName, // Name of lambda method from interface + MethodType.methodType(interfaceType), // MethodType of interface + MethodType.methodType(wrapperMethod.getReturnType(), wrapperMethod.getParameterTypes()), //Signature of wrapper lambda + original, // MethodHandle + original.type()); //Actual signature of method + try { + return (F) site.getTarget().invoke(); + } catch (final Exception e) { + throw e; + } catch (final Throwable e) { + throw new Error(e); + } + } + + public static void lambdaMetaFactoryExample() throws Throwable { + // from https://wttech.blog/blog/2020/method-handles-and-lambda-metafactory/ + String toBeTrimmed = " text with spaces "; + Method reflectionMethod = String.class.getMethod("trim"); + MethodHandles.Lookup lookup = MethodHandles.lookup(); + MethodHandle handle = lookup.unreflect(reflectionMethod); + CallSite callSite = LambdaMetafactory.metafactory( + // method handle lookup + lookup, + // name of the method defined in the target functional interface + "get", + // type to be implemented and captured objects + // in this case the String instance to be trimmed is captured + MethodType.methodType(Supplier.class, String.class), + // type erasure, Supplier will return an Object + MethodType.methodType(Object.class), + // method handle to transform + handle, + // Supplier method real signature (reified) + // trim accepts no parameters and returns String + MethodType.methodType(String.class)); + Supplier lambda = (Supplier) callSite.getTarget().bindTo(toBeTrimmed).invoke(); + } + + public static MethodHandle generateMethodHandle(IntUnaryOperator op) { + try { + var method = op.getClass().getMethod("applyAsInt", int.class); + return getLookup(op.getClass()).unreflect(method).bindTo(op); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static MethodHandle generateMethodHandleForLambda(F lambda) { + try { + Class clazz = lambda.getClass(); + Method[] methods = clazz.getDeclaredMethods(); + // As lambdas should only have one declared method + if (methods.length != 1) { + throw new RuntimeException("Invalid lambda passed"); + } + return getLookup(lambda.getClass()).unreflect(methods[0]).bindTo(lambda); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static MethodHandles.Lookup getLookup(Class targetClass) { + MethodHandles.Lookup lookupMe = MethodHandles.lookup(); + try { + return MethodHandles.privateLookupIn(targetClass, lookupMe); + } catch (IllegalAccessException e) { + return lookupMe; + } + } +}