Implement the bridge to call Tatoo from java. Very slow at the moment.
[tatoo.git] / src / bindings / java / TatooEngine.cc
diff --git a/src/bindings/java/TatooEngine.cc b/src/bindings/java/TatooEngine.cc
new file mode 100644 (file)
index 0000000..624c2df
--- /dev/null
@@ -0,0 +1,696 @@
+#include "fxslt_memory_TatooEngine.h"
+#include "fxslt_memory_TatooEngine_Automaton.h"
+#include "fxslt_memory_TatooEngine_Tree.h"
+#include "fxslt_memory_TatooEngine_CustomBlock.h"
+#include "tatoo.h"
+
+#include <cassert>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include <caml/alloc.h>
+#include <caml/callback.h>
+#include <caml/memory.h>
+extern "C" {
+#include <caml/threads.h>
+}
+
+namespace jni {
+namespace priv {
+
+enum { JNI_VERSION = JNI_VERSION_1_2 };
+
+JNIEnv *current_env;
+
+template<typename T>
+class Integer;
+
+template<typename T>
+struct Traits;
+
+#define MAKE_TRAIT(T, N)                                                                    \
+template<>                                                                                  \
+struct Traits<T> {                                                                          \
+    static T (JNIEnv::*call)(jobject, jmethodID, va_list);                                  \
+};                                                                                          \
+T (JNIEnv::* Traits<T>::call)(jobject, jmethodID, va_list) = &JNIEnv::Call ## N ## MethodV
+
+MAKE_TRAIT(jobject, Object);
+MAKE_TRAIT(jint, Int);
+MAKE_TRAIT(jshort, Short);
+MAKE_TRAIT(jboolean, Boolean);
+
+}
+
+JNIEnv &env() throw() {
+    assert(priv::current_env);
+    return *priv::current_env;
+}
+
+class scoped_env {
+public:
+    scoped_env(JNIEnv *env) throw()
+    {
+        assert(not priv::current_env);
+        priv::current_env = env;
+    }
+
+    scoped_env(JavaVM *vm) throw(jint)
+    {
+        assert(not priv::current_env);
+        if(vm->GetEnv(reinterpret_cast<void **>(&priv::current_env), priv::JNI_VERSION) != JNI_OK)
+            throw jint(-1);
+    }
+
+    ~scoped_env() throw()
+    {
+        assert(priv::current_env);
+        priv::current_env = NULL;
+    }
+};
+
+
+
+struct MemberDesc {
+    const char *name;
+    const char *signature;
+};
+
+struct ClassDesc {
+    const char *name;
+    const std::vector<MemberDesc> methods;
+    const std::vector<MemberDesc> fields;
+};
+
+class Class {
+public:
+
+private:
+    typedef std::vector<jmethodID> Methods;
+    typedef std::vector<jfieldID> Fields;
+
+    static jclass get_class(const char *name) throw(jint)
+    {
+        jclass c;
+        if((c = env().FindClass(name)) == NULL)
+            throw jint(-1);
+        if((c = static_cast<jclass>(env().NewGlobalRef(c))) == NULL)
+            throw jint(-1);
+        return c;
+    }
+
+    static Methods
+    get_methods(const jclass class_, const std::vector<MemberDesc> &methods) throw()
+    {
+        Methods ret(methods.size());
+        size_t i = 0;
+        for(auto it = methods.begin(); it != methods.end(); ++it, ++i)
+            ret[i] = env().GetMethodID(class_, it->name, it->signature);
+        return ret;
+    }
+
+    static Fields
+    get_fields(const jclass class_, const std::vector<MemberDesc> &fields) throw()
+    {
+        Fields ret(fields.size());
+        size_t i = 0;
+        for(auto it = fields.begin(); it != fields.end(); ++it, ++i)
+            ret[i] = env().GetFieldID(class_, it->name, it->signature);
+        return ret;
+    }
+
+public:
+    const jclass class_;
+    const Methods methods;
+    const Fields fields;
+
+    Class(const ClassDesc &desc) throw(jint)
+        : class_(get_class(desc.name)), methods(get_methods(class_, desc.methods)),
+          fields(get_fields(class_, desc.fields))
+    { }
+
+    ~Class() throw() { env().DeleteGlobalRef(class_); }
+
+    jboolean IsInstanceOf(jobject obj) const throw()
+    { return env().IsInstanceOf(obj, class_); }
+};
+
+template<typename C>
+class Object {
+protected:
+    typedef Object Base;
+    static ClassDesc desc;
+    static Class *class_;
+
+       Object(jobject this_) throw() : this_(this_) {
+               assert(class_->IsInstanceOf(this_));
+#if 0
+               if (!class_->IsInstanceOf(this_)) {
+
+                       jclass object = env().FindClass("java/lang/Object");
+                       jmethodID getClass_id = env().GetMethodID(object, "getClass", "()Ljava/lang/Class;");
+                       jobject oclass = env().CallObjectMethod(this_, getClass_id);
+                       jclass cls = env().FindClass("java/lang/Class");
+                       jmethodID getName_id = env().GetMethodID(cls, "getName", "()Ljava/lang/String;");
+                       jstring name = static_cast<jstring>(env().CallObjectMethod(oclass, getName_id));
+                       fprintf(stderr, "ERROR: class: %s is not an instance of %s\n", desc.name, jni::env().GetStringUTFChars(name, NULL));
+
+
+                       assert(class_->IsInstanceOf(this_));
+               };
+#endif
+       }
+    template<typename T>
+    T call(int method_id, ...) const
+    {
+        va_list vl;
+        va_start(vl, method_id);
+        T ret = (env().*priv::Traits<T>::call)(this_, class_->methods[method_id], vl);
+        va_end(vl);
+        return ret;
+    }
+    template<typename T>
+    static inline T static_call(jobject j, int method_id, ...) throw ()
+    {
+           va_list vl;
+           va_start(vl, method_id);
+           T ret = (env().*priv::Traits<T>::call)(j, class_->methods[method_id], vl);
+           va_end(vl);
+           return ret;
+    }
+public:
+    const jobject this_;
+
+    static void initialize() throw(jint) { class_ = new Class(desc); }
+    static void finalize() throw() { delete class_; class_ = NULL; }
+
+    static const Class& get_class() { return *class_; }
+};
+template<typename C>
+jni::Class *jni::Object<C>::class_ = NULL;
+
+typedef priv::Integer<jint> Integer;
+typedef priv::Integer<jshort> Short;
+
+template<>
+jni::ClassDesc jni::Object<Integer>::desc = {
+    "java/lang/Integer", {
+        { "intValue", "()I" },
+        { "<init>", "(I)V" }
+    }, { }
+};
+template<>
+jni::ClassDesc jni::Object<Short>::desc = {
+    "java/lang/Short", {
+        { "shortValue", "()S" },
+    }, { }
+};
+
+template<typename T>
+class priv::Integer: public Object<Integer<T>> {
+    enum Methods { valueID, initID };
+    typedef Object<Integer<T>> Base;
+public:
+    Integer(jobject this_) throw() : Base(this_) { }
+    Integer(jint i) throw()
+           : Base(jni::env().NewObject(Base::class_->class_, Base::class_->methods[initID], i))
+    { }
+    T operator*() const throw() { return Base::template call<T>(valueID); }
+};
+
+class String {
+private:
+    String(const String &) = delete;
+    String& operator=(const String &) = delete;
+
+    mutable const char *c_str_;
+public:
+    const jstring this_;
+
+    String(jstring this_) throw() : c_str_(NULL), this_(this_) { }
+    String(String &&rhs) throw() : c_str_(rhs.c_str_), this_(rhs.this_) { rhs.c_str_ = NULL; }
+    ~String() throw() {
+        if(c_str_)
+            env().ReleaseStringUTFChars(this_, c_str_);
+    }
+
+    const char* c_str() const throw()
+    {
+        if(c_str_)
+            return c_str_;
+
+        return c_str_ = env().GetStringUTFChars(this_, NULL);
+    }
+};
+
+} // namespace jni
+
+class Node;
+class Attr;
+class NodeList;
+class NamedNodeMap;
+class MutableNodeList;
+class CustomBlock;
+
+template<>
+jni::ClassDesc jni::Object<Node>::desc = {
+    "org/w3c/dom/Node", {
+        { "getFirstChild", "()Lorg/w3c/dom/Node;" },
+        { "getNextSibling", "()Lorg/w3c/dom/Node;" },
+        { "getNodeName", "()Ljava/lang/String;" },
+        { "getNodeValue", "()Ljava/lang/String;" },
+        { "getUserData", "(Ljava/lang/String;)Ljava/lang/Object;" },
+        { "setUserData", "(Ljava/lang/String;Ljava/lang/Object;Lorg/w3c/dom/UserDataHandler;)Ljava/lang/Object;" },
+        { "getNodeType", "()S" },
+        { "getAttributes", "()Lorg/w3c/dom/NamedNodeMap;" }
+    }, { }
+};
+
+class Node: public jni::Object<Node> {
+    enum Methods {
+           getFirstChildID, getNextSiblingID, getNodeNameID, getNodeValueID, getUserDataID, setUserDataID, getNodeTypeID,
+        getAttributesID
+    };
+
+    static jni::String *empty_key;
+public:
+    static void initialize() throw(jint)
+    {
+        Base::initialize();
+        empty_key = new jni::String(static_cast<jstring>(
+                    jni::env().NewGlobalRef(jni::env().NewStringUTF("")) ));
+    }
+    static void finalize() throw()
+    {
+        jni::env().DeleteGlobalRef(empty_key->this_);
+        delete empty_key;
+        empty_key = NULL;
+        Base::finalize();
+    }
+
+    Node(jobject this_) throw() : Base(this_) { }
+
+    Node getFirstChild() const throw() { return  Node(call<jobject>(getFirstChildID)); }
+    static inline jobject getFirstChildO(jobject obj) throw () {
+               return static_call<jobject>(obj, getFirstChildID);
+       }
+    static inline jobject getNextSiblingO(jobject obj) throw () {
+               return static_call<jobject>(obj, getNextSiblingID);
+       }
+
+    Node getNextSibling() const throw() { return Node(call<jobject>(getNextSiblingID)); }
+    jshort getNodeType() const throw() { return call<jshort>(getNodeTypeID); }
+
+    jni::String getNodeName() const throw()
+    { return jni::String(static_cast<jstring>(call<jobject>(getNodeNameID))); }
+
+    jni::String getNodeValue() const throw()
+    { return jni::String(static_cast<jstring>(call<jobject>(getNodeValueID))); }
+
+
+    jint getPreorder() const throw()
+    {
+           jobject data = call<jobject>(getUserDataID, empty_key->this_);
+           return *jni::Integer(data); }
+    static inline jobject getPreorderO(jobject obj) throw () {
+               return static_call<jobject>(obj, getNextSiblingID);
+       }
+    void setPreorder(jint i) const throw()
+    {
+      call<jobject>(setUserDataID, empty_key->this_, jni::Integer(i), NULL);
+    }
+
+    NamedNodeMap getAttributes() const throw();
+};
+jni::String *Node::empty_key = NULL;
+
+/********** Attr *************/
+template<>
+jni::ClassDesc jni::Object<Attr>::desc = {
+    "org/w3c/dom/Attr", {
+        { "getOwnerElement", "()Lorg/w3c/dom/Element;" }
+    }, { }
+};
+
+class Attr: public jni::Object<Attr> {
+    enum Methods {
+           getOwnerElementID
+    };
+
+public:
+
+    Attr(jobject this_) throw() : Base(this_) { }
+
+    Node getOwnerElement() const throw() { return Node(call<jobject>(getOwnerElementID)); }
+};
+
+/********** NodeList **********/
+template<>
+jni::ClassDesc jni::Object<NodeList>::desc = {
+    "org/w3c/dom/NodeList", {
+        { "getLength", "()I" },
+        { "item", "(I)Lorg/w3c/dom/Node;" }
+    }, { }
+};
+
+class NodeList: public jni::Object<NodeList> {
+    enum Methods { getLengthID, itemID };
+
+public:
+    NodeList(jobject this_) throw() : Base(this_) { }
+
+    jint getLength() const throw() { return call<jint>(getLengthID); }
+    Node item(jint i) const throw() { return Node(call<jobject>(itemID, i)); }
+};
+
+
+/********** NamedNodeMap **********/
+template<>
+jni::ClassDesc jni::Object<NamedNodeMap>::desc = {
+    "org/w3c/dom/NamedNodeMap", {
+        { "getLength", "()I" },
+        { "item", "(I)Lorg/w3c/dom/Node;" }
+    }, { }
+};
+
+class NamedNodeMap: public jni::Object<NamedNodeMap> {
+    enum Methods { getLengthID, itemID };
+
+public:
+    NamedNodeMap(jobject this_) throw() : Base(this_) { }
+
+    jint getLength() const throw() { return call<jint>(getLengthID); }
+    Node item(jint i) const throw() { return Node(call<jobject>(itemID, i)); }
+};
+
+
+template<>
+jni::ClassDesc jni::Object<MutableNodeList>::desc = {
+    "fxslt/memory/MutableNodeList", {
+        { "add", "(Lorg/w3c/dom/Node;)V" },
+        { "<init>", "()V" }
+    }, { }
+};
+
+class MutableNodeList: public jni::Object<MutableNodeList> {
+    enum Methods { addID, initID };
+
+public:
+    MutableNodeList(jobject this_) throw() : Base(this_) { }
+    MutableNodeList() throw()
+        : Base(jni::env().NewObject(class_->class_, class_->methods[initID]))
+    { }
+
+   void add(Node n) throw() { call<jobject>(addID, n.this_); }
+};
+
+
+NamedNodeMap Node::getAttributes() const throw()
+{
+       return NamedNodeMap(call<jobject>(getAttributesID));
+}
+
+
+
+template<>
+jni::ClassDesc jni::Object<CustomBlock>::desc = {
+    "fxslt/memory/TatooEngine$CustomBlock", {
+        { "<init>", "(J)V" }
+    }, {
+        { "value_ptr", "J" }
+    }
+};
+
+class CustomBlock: public jni::Object<CustomBlock> {
+    enum Methods { initID };
+    enum Fields { valueID };
+
+    static_assert(sizeof(jlong) <= sizeof(value *), "We use jlong to store pointers.");
+
+    value* get() const throw()
+    { return reinterpret_cast<value *>(jni::env().GetLongField(this_, class_->fields[valueID])); }
+public:
+    CustomBlock(jobject this_) throw() : Base(this_) {   }
+
+    CustomBlock(value val) throw()
+        : Base(jni::env().NewObject(class_->class_, class_->methods[initID], new value(val)))
+    { caml_register_generational_global_root(get()); }
+
+    value getValue() const throw() { return *get(); }
+};
+
+JNIEXPORT void JNICALL
+Java_fxslt_memory_TatooEngine_unregister  (JNIEnv *env, jclass cls, jlong value_ptr)
+{
+       value * vptr = reinterpret_cast<value *>(value_ptr);
+       caml_remove_generational_global_root(vptr);
+       delete vptr;
+}
+
+static value *init_document;
+static value *xpath_compile;
+static value *auto_evaluate;
+
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *)
+{
+    try {
+        jni::scoped_env se(vm);
+        jni::Integer::initialize();
+       Attr::initialize();
+        Node::initialize();
+        NodeList::initialize();
+        NamedNodeMap::initialize();
+       MutableNodeList::initialize();
+        CustomBlock::initialize();
+
+        char *argv[] = { NULL };
+        caml_startup(argv);
+
+        init_document = caml_named_value("init_document"); assert(init_document);
+        xpath_compile = caml_named_value("xpath_compile"); assert(xpath_compile);
+        auto_evaluate = caml_named_value("auto_evaluate"); assert(auto_evaluate);
+        caml_release_runtime_system();
+    }
+    catch(jint e) {
+        return e;
+    }
+
+    return jni::priv::JNI_VERSION;
+}
+
+JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *)
+{
+    try {
+        jni::scoped_env se(vm);
+        MutableNodeList::finalize();
+        NamedNodeMap::finalize();
+        NodeList::finalize();
+        Node::finalize();
+       Attr::finalize();
+        jni::Integer::finalize();
+    }
+    catch(jint e) {
+        fprintf(stderr, "Critical error unloading shared library.\n");
+    }
+}
+
+static jobject extract(const value &val)
+{ return reinterpret_cast<jobject>(val); }
+
+static value pack(jobject obj)
+{
+//    static_assert(sizeof(uintptr_t) <= sizeof(long), "We need long to hold pointers.");
+
+//    uintptr_t p = reinterpret_cast<uintptr_t>(obj);
+//    assert(! (p & 1));
+       return reinterpret_cast<value> (obj);
+}
+
+
+JNIEXPORT jobject JNICALL
+Java_fxslt_memory_TatooEngine_init_1document(JNIEnv *env, jclass TatooEngine, jobject node, jint i)
+{
+       CAMLparam0();
+       CAMLlocal1(val);
+       try {
+               jni::scoped_env se(env);
+               node = jni::env().NewGlobalRef(node);
+               val = caml_callback2(*init_document, pack(node), Val_int(i));
+               auto t = jni::env().NewGlobalRef(CustomBlock(val).this_);
+               CAMLreturnT(auto, t);
+       } catch(jint e) {
+               fprintf(stderr, "Critical error while initializing the document.\n");
+               CAMLreturnT(jobject, NULL);
+    }
+
+}
+
+
+JNIEXPORT jobject JNICALL
+Java_fxslt_memory_TatooEngine_compile(JNIEnv *env, jclass TatooEngine, jstring xpath)
+{
+    CAMLparam0();
+    CAMLlocal1(val);
+
+    try {
+        jni::scoped_env se(env);
+
+        val = caml_callback(*xpath_compile, caml_copy_string(jni::String(xpath).c_str()));
+       auto a = CustomBlock(val).this_;
+       a = jni::env().NewGlobalRef(a);
+        CAMLreturnT(auto, a);
+    }
+    catch(jint e) {
+        fprintf(stderr, "Critical error while compiling.\n");
+        CAMLreturnT(jobject, NULL);
+    }
+}
+
+JNIEXPORT jobject JNICALL
+Java_fxslt_memory_TatooEngine_evaluate(JNIEnv *env, jclass TatooEngine,
+                                      jobject automaton, jobject tree, jobject node_list)
+{
+    CAMLparam0();
+    CAMLlocal4(res, vauto, vtree, vnl);
+
+    try {
+        jni::scoped_env se(env);
+       vauto = CustomBlock(automaton).getValue();
+       vtree = CustomBlock(tree).getValue();
+       vnl = pack(node_list);
+
+        res = caml_callback3(*auto_evaluate, vauto, vtree, vnl);
+        CAMLreturnT(auto, extract(res));
+    } catch(jint e) {
+        fprintf(stderr, "Critical error while evaluating.\n");
+        CAMLreturnT(jobject, NULL);
+    }
+}
+
+#define GR (node) (jni::env().NewGlobalRef((node)))
+
+#if 0
+#define CHECK_EXCEPTION()  do {                                \
+       if (jni::env().ExceptionCheck() == JNI_TRUE) {  \
+       jni::env().ExceptionDescribe();                 \
+       assert(false);                                  \
+       }                                               \
+       } while (0)
+#else
+#define CHECK_EXCEPTION()
+#endif
+
+extern "C" {
+CAMLprim value node_getFirstChild(value node)
+{ CAMLparam1(node);
+       CHECK_EXCEPTION();
+
+       CAMLreturn(pack(Node::getFirstChildO(extract(node))));
+       //CAMLreturn(pack(Node(extract(node)).getFirstChild().this_));
+}
+
+CAMLprim value node_getNextSibling(value node)
+{       CAMLparam1(node);
+       CHECK_EXCEPTION();
+       CAMLreturn(pack(Node::getNextSiblingO(extract(node))));
+       //CAMLreturn(pack(Node(extract(node)).getNextSibling().this_));
+}
+
+CAMLprim value node_getNodeType(value node)
+{ CAMLparam1(node);
+       CHECK_EXCEPTION();
+       CAMLreturn(Val_int(Node(extract(node)).getNodeType())); }
+
+CAMLprim value node_getNodeName(value node)
+{ CAMLparam1(node);
+       CHECK_EXCEPTION();
+       jstring obj = Node(extract(node)).getNodeName().this_;
+       value cstr = caml_copy_string(jni::env().GetStringUTFChars(obj, NULL));
+       CAMLreturn(cstr);
+}
+
+CAMLprim value node_getPreorder(value node)
+{ CAMLparam1(node);
+       CHECK_EXCEPTION();
+       CAMLreturn(Val_int(Node(extract(node)).getPreorder()));
+ }
+
+CAMLprim value node_setPreorder(value node, value i)
+{
+       CAMLparam1(node);
+       CHECK_EXCEPTION();
+       Node(extract(node)).setPreorder(Int_val(i));
+       CAMLreturn(Val_unit);
+}
+
+CAMLprim value print_runtime_class(value o)
+{
+   CAMLparam1(o);
+   CHECK_EXCEPTION();
+   jclass object = jni::env().FindClass("java/lang/Object");
+   jmethodID getClass_id = jni::env().GetMethodID(object, "getClass", "()Ljava/lang/Class;");
+   jobject oclass = jni::env().CallObjectMethod(extract(o), getClass_id);
+   jclass cls = jni::env().FindClass("java/lang/Class");
+   jmethodID getName_id = jni::env().GetMethodID(cls, "getName", "()Ljava/lang/String;");
+   jstring name = static_cast<jstring>(jni::env().CallObjectMethod(oclass, getName_id));
+   fprintf(stderr, "CLASS OF ATTTRIBUTE IS %s \n", jni::env().GetStringUTFChars(name, NULL));
+   fflush(stderr);
+   CAMLreturn(Val_unit);
+}
+CAMLprim value attr_getOwnerElement(value node)
+{
+       CAMLparam1(node);
+       CHECK_EXCEPTION();
+       auto attr = Attr(extract(node));
+       CAMLreturn(pack(attr.getOwnerElement().this_));
+ }
+
+CAMLprim value node_getAttributes(value node)
+{ CAMLparam1(node);
+       CHECK_EXCEPTION();
+       CAMLreturn(pack(Node(extract(node)).getAttributes().this_)); }
+
+CAMLprim value nodelist_getLength(value list)
+{ CAMLparam1(list);
+       CHECK_EXCEPTION();
+       CAMLreturn(Val_int(NodeList(extract(list)).getLength())); }
+
+CAMLprim value nodelist_item(value list, value idx)
+{
+    CAMLparam2(list, idx);
+    CHECK_EXCEPTION();
+    CAMLreturn(pack(NodeList(extract(list)).item(Long_val(idx)).this_));
+}
+
+CAMLprim value nodelist_new(value list)
+{ CAMLparam1(list);
+ auto l = jni::env().NewGlobalRef(MutableNodeList().this_);
+ CAMLreturn(pack(l));
+}
+
+CAMLprim value nodelist_add(value list, value node)
+{
+    CAMLparam2(list, node);
+    MutableNodeList(extract(list)).add(Node(extract(node)));
+    CAMLreturn(list);
+}
+
+CAMLprim value namednodemap_getLength(value list)
+{ CAMLparam1(list); CAMLreturn(Val_int(NamedNodeMap(extract(list)).getLength())); }
+
+CAMLprim value namednodemap_item(value list, value idx)
+{
+    CAMLparam2(list, idx);
+    CAMLreturn(pack(NamedNodeMap(extract(list)).item(Long_val(idx)).this_));
+}
+
+CAMLprim value getNull(value unit) { CAMLparam1(unit); CAMLreturn((value) NULL); }
+
+CAMLprim value dereference_object (value obj)
+{
+       CAMLparam1(obj);
+       jni::env().DeleteGlobalRef(reinterpret_cast<jobject>(obj));
+       CAMLreturn(Val_unit);
+}
+}