Implement the bridge to call Tatoo from java. Very slow at the moment.
[tatoo.git] / src / bindings / java / TatooEngine.cc
1 #include "fxslt_memory_TatooEngine.h"
2 #include "fxslt_memory_TatooEngine_Automaton.h"
3 #include "fxslt_memory_TatooEngine_Tree.h"
4 #include "fxslt_memory_TatooEngine_CustomBlock.h"
5 #include "tatoo.h"
6
7 #include <cassert>
8 #include <cstdint>
9 #include <memory>
10 #include <vector>
11
12 #include <caml/alloc.h>
13 #include <caml/callback.h>
14 #include <caml/memory.h>
15 extern "C" {
16 #include <caml/threads.h>
17 }
18
19 namespace jni {
20 namespace priv {
21
22 enum { JNI_VERSION = JNI_VERSION_1_2 };
23
24 JNIEnv *current_env;
25
26 template<typename T>
27 class Integer;
28
29 template<typename T>
30 struct Traits;
31
32 #define MAKE_TRAIT(T, N)                                                                    \
33 template<>                                                                                  \
34 struct Traits<T> {                                                                          \
35     static T (JNIEnv::*call)(jobject, jmethodID, va_list);                                  \
36 };                                                                                          \
37 T (JNIEnv::* Traits<T>::call)(jobject, jmethodID, va_list) = &JNIEnv::Call ## N ## MethodV
38
39 MAKE_TRAIT(jobject, Object);
40 MAKE_TRAIT(jint, Int);
41 MAKE_TRAIT(jshort, Short);
42 MAKE_TRAIT(jboolean, Boolean);
43
44 }
45
46 JNIEnv &env() throw() {
47     assert(priv::current_env);
48     return *priv::current_env;
49 }
50
51 class scoped_env {
52 public:
53     scoped_env(JNIEnv *env) throw()
54     {
55         assert(not priv::current_env);
56         priv::current_env = env;
57     }
58
59     scoped_env(JavaVM *vm) throw(jint)
60     {
61         assert(not priv::current_env);
62         if(vm->GetEnv(reinterpret_cast<void **>(&priv::current_env), priv::JNI_VERSION) != JNI_OK)
63             throw jint(-1);
64     }
65
66     ~scoped_env() throw()
67     {
68         assert(priv::current_env);
69         priv::current_env = NULL;
70     }
71 };
72
73
74
75 struct MemberDesc {
76     const char *name;
77     const char *signature;
78 };
79
80 struct ClassDesc {
81     const char *name;
82     const std::vector<MemberDesc> methods;
83     const std::vector<MemberDesc> fields;
84 };
85
86 class Class {
87 public:
88
89 private:
90     typedef std::vector<jmethodID> Methods;
91     typedef std::vector<jfieldID> Fields;
92
93     static jclass get_class(const char *name) throw(jint)
94     {
95         jclass c;
96         if((c = env().FindClass(name)) == NULL)
97             throw jint(-1);
98         if((c = static_cast<jclass>(env().NewGlobalRef(c))) == NULL)
99             throw jint(-1);
100         return c;
101     }
102
103     static Methods
104     get_methods(const jclass class_, const std::vector<MemberDesc> &methods) throw()
105     {
106         Methods ret(methods.size());
107         size_t i = 0;
108         for(auto it = methods.begin(); it != methods.end(); ++it, ++i)
109             ret[i] = env().GetMethodID(class_, it->name, it->signature);
110         return ret;
111     }
112
113     static Fields
114     get_fields(const jclass class_, const std::vector<MemberDesc> &fields) throw()
115     {
116         Fields ret(fields.size());
117         size_t i = 0;
118         for(auto it = fields.begin(); it != fields.end(); ++it, ++i)
119             ret[i] = env().GetFieldID(class_, it->name, it->signature);
120         return ret;
121     }
122
123 public:
124     const jclass class_;
125     const Methods methods;
126     const Fields fields;
127
128     Class(const ClassDesc &desc) throw(jint)
129         : class_(get_class(desc.name)), methods(get_methods(class_, desc.methods)),
130           fields(get_fields(class_, desc.fields))
131     { }
132
133     ~Class() throw() { env().DeleteGlobalRef(class_); }
134
135     jboolean IsInstanceOf(jobject obj) const throw()
136     { return env().IsInstanceOf(obj, class_); }
137 };
138
139 template<typename C>
140 class Object {
141 protected:
142     typedef Object Base;
143     static ClassDesc desc;
144     static Class *class_;
145
146         Object(jobject this_) throw() : this_(this_) {
147                 assert(class_->IsInstanceOf(this_));
148 #if 0
149                 if (!class_->IsInstanceOf(this_)) {
150
151                         jclass object = env().FindClass("java/lang/Object");
152                         jmethodID getClass_id = env().GetMethodID(object, "getClass", "()Ljava/lang/Class;");
153                         jobject oclass = env().CallObjectMethod(this_, getClass_id);
154                         jclass cls = env().FindClass("java/lang/Class");
155                         jmethodID getName_id = env().GetMethodID(cls, "getName", "()Ljava/lang/String;");
156                         jstring name = static_cast<jstring>(env().CallObjectMethod(oclass, getName_id));
157                         fprintf(stderr, "ERROR: class: %s is not an instance of %s\n", desc.name, jni::env().GetStringUTFChars(name, NULL));
158
159
160                         assert(class_->IsInstanceOf(this_));
161                 };
162 #endif
163         }
164     template<typename T>
165     T call(int method_id, ...) const
166     {
167         va_list vl;
168         va_start(vl, method_id);
169         T ret = (env().*priv::Traits<T>::call)(this_, class_->methods[method_id], vl);
170         va_end(vl);
171         return ret;
172     }
173     template<typename T>
174     static inline T static_call(jobject j, int method_id, ...) throw ()
175     {
176             va_list vl;
177             va_start(vl, method_id);
178             T ret = (env().*priv::Traits<T>::call)(j, class_->methods[method_id], vl);
179             va_end(vl);
180             return ret;
181     }
182 public:
183     const jobject this_;
184
185     static void initialize() throw(jint) { class_ = new Class(desc); }
186     static void finalize() throw() { delete class_; class_ = NULL; }
187
188     static const Class& get_class() { return *class_; }
189 };
190 template<typename C>
191 jni::Class *jni::Object<C>::class_ = NULL;
192
193 typedef priv::Integer<jint> Integer;
194 typedef priv::Integer<jshort> Short;
195
196 template<>
197 jni::ClassDesc jni::Object<Integer>::desc = {
198     "java/lang/Integer", {
199         { "intValue", "()I" },
200         { "<init>", "(I)V" }
201     }, { }
202 };
203 template<>
204 jni::ClassDesc jni::Object<Short>::desc = {
205     "java/lang/Short", {
206         { "shortValue", "()S" },
207     }, { }
208 };
209
210 template<typename T>
211 class priv::Integer: public Object<Integer<T>> {
212     enum Methods { valueID, initID };
213     typedef Object<Integer<T>> Base;
214 public:
215     Integer(jobject this_) throw() : Base(this_) { }
216     Integer(jint i) throw()
217             : Base(jni::env().NewObject(Base::class_->class_, Base::class_->methods[initID], i))
218     { }
219     T operator*() const throw() { return Base::template call<T>(valueID); }
220 };
221
222 class String {
223 private:
224     String(const String &) = delete;
225     String& operator=(const String &) = delete;
226
227     mutable const char *c_str_;
228 public:
229     const jstring this_;
230
231     String(jstring this_) throw() : c_str_(NULL), this_(this_) { }
232     String(String &&rhs) throw() : c_str_(rhs.c_str_), this_(rhs.this_) { rhs.c_str_ = NULL; }
233     ~String() throw() {
234         if(c_str_)
235             env().ReleaseStringUTFChars(this_, c_str_);
236     }
237
238     const char* c_str() const throw()
239     {
240         if(c_str_)
241             return c_str_;
242
243         return c_str_ = env().GetStringUTFChars(this_, NULL);
244     }
245 };
246
247 } // namespace jni
248
249 class Node;
250 class Attr;
251 class NodeList;
252 class NamedNodeMap;
253 class MutableNodeList;
254 class CustomBlock;
255
256 template<>
257 jni::ClassDesc jni::Object<Node>::desc = {
258     "org/w3c/dom/Node", {
259         { "getFirstChild", "()Lorg/w3c/dom/Node;" },
260         { "getNextSibling", "()Lorg/w3c/dom/Node;" },
261         { "getNodeName", "()Ljava/lang/String;" },
262         { "getNodeValue", "()Ljava/lang/String;" },
263         { "getUserData", "(Ljava/lang/String;)Ljava/lang/Object;" },
264         { "setUserData", "(Ljava/lang/String;Ljava/lang/Object;Lorg/w3c/dom/UserDataHandler;)Ljava/lang/Object;" },
265         { "getNodeType", "()S" },
266         { "getAttributes", "()Lorg/w3c/dom/NamedNodeMap;" }
267     }, { }
268 };
269
270 class Node: public jni::Object<Node> {
271     enum Methods {
272             getFirstChildID, getNextSiblingID, getNodeNameID, getNodeValueID, getUserDataID, setUserDataID, getNodeTypeID,
273         getAttributesID
274     };
275
276     static jni::String *empty_key;
277 public:
278     static void initialize() throw(jint)
279     {
280         Base::initialize();
281         empty_key = new jni::String(static_cast<jstring>(
282                     jni::env().NewGlobalRef(jni::env().NewStringUTF("")) ));
283     }
284     static void finalize() throw()
285     {
286         jni::env().DeleteGlobalRef(empty_key->this_);
287         delete empty_key;
288         empty_key = NULL;
289         Base::finalize();
290     }
291
292     Node(jobject this_) throw() : Base(this_) { }
293
294     Node getFirstChild() const throw() { return  Node(call<jobject>(getFirstChildID)); }
295     static inline jobject getFirstChildO(jobject obj) throw () {
296                 return static_call<jobject>(obj, getFirstChildID);
297         }
298     static inline jobject getNextSiblingO(jobject obj) throw () {
299                 return static_call<jobject>(obj, getNextSiblingID);
300         }
301
302     Node getNextSibling() const throw() { return Node(call<jobject>(getNextSiblingID)); }
303     jshort getNodeType() const throw() { return call<jshort>(getNodeTypeID); }
304
305     jni::String getNodeName() const throw()
306     { return jni::String(static_cast<jstring>(call<jobject>(getNodeNameID))); }
307
308     jni::String getNodeValue() const throw()
309     { return jni::String(static_cast<jstring>(call<jobject>(getNodeValueID))); }
310
311
312     jint getPreorder() const throw()
313     {
314             jobject data = call<jobject>(getUserDataID, empty_key->this_);
315             return *jni::Integer(data); }
316     static inline jobject getPreorderO(jobject obj) throw () {
317                 return static_call<jobject>(obj, getNextSiblingID);
318         }
319     void setPreorder(jint i) const throw()
320     {
321       call<jobject>(setUserDataID, empty_key->this_, jni::Integer(i), NULL);
322     }
323
324     NamedNodeMap getAttributes() const throw();
325 };
326 jni::String *Node::empty_key = NULL;
327
328 /********** Attr *************/
329 template<>
330 jni::ClassDesc jni::Object<Attr>::desc = {
331     "org/w3c/dom/Attr", {
332         { "getOwnerElement", "()Lorg/w3c/dom/Element;" }
333     }, { }
334 };
335
336 class Attr: public jni::Object<Attr> {
337     enum Methods {
338             getOwnerElementID
339     };
340
341 public:
342
343     Attr(jobject this_) throw() : Base(this_) { }
344
345     Node getOwnerElement() const throw() { return Node(call<jobject>(getOwnerElementID)); }
346 };
347
348 /********** NodeList **********/
349 template<>
350 jni::ClassDesc jni::Object<NodeList>::desc = {
351     "org/w3c/dom/NodeList", {
352         { "getLength", "()I" },
353         { "item", "(I)Lorg/w3c/dom/Node;" }
354     }, { }
355 };
356
357 class NodeList: public jni::Object<NodeList> {
358     enum Methods { getLengthID, itemID };
359
360 public:
361     NodeList(jobject this_) throw() : Base(this_) { }
362
363     jint getLength() const throw() { return call<jint>(getLengthID); }
364     Node item(jint i) const throw() { return Node(call<jobject>(itemID, i)); }
365 };
366
367
368 /********** NamedNodeMap **********/
369 template<>
370 jni::ClassDesc jni::Object<NamedNodeMap>::desc = {
371     "org/w3c/dom/NamedNodeMap", {
372         { "getLength", "()I" },
373         { "item", "(I)Lorg/w3c/dom/Node;" }
374     }, { }
375 };
376
377 class NamedNodeMap: public jni::Object<NamedNodeMap> {
378     enum Methods { getLengthID, itemID };
379
380 public:
381     NamedNodeMap(jobject this_) throw() : Base(this_) { }
382
383     jint getLength() const throw() { return call<jint>(getLengthID); }
384     Node item(jint i) const throw() { return Node(call<jobject>(itemID, i)); }
385 };
386
387
388 template<>
389 jni::ClassDesc jni::Object<MutableNodeList>::desc = {
390     "fxslt/memory/MutableNodeList", {
391         { "add", "(Lorg/w3c/dom/Node;)V" },
392         { "<init>", "()V" }
393     }, { }
394 };
395
396 class MutableNodeList: public jni::Object<MutableNodeList> {
397     enum Methods { addID, initID };
398
399 public:
400     MutableNodeList(jobject this_) throw() : Base(this_) { }
401     MutableNodeList() throw()
402         : Base(jni::env().NewObject(class_->class_, class_->methods[initID]))
403     { }
404
405    void add(Node n) throw() { call<jobject>(addID, n.this_); }
406 };
407
408
409 NamedNodeMap Node::getAttributes() const throw()
410 {
411         return NamedNodeMap(call<jobject>(getAttributesID));
412 }
413
414
415
416 template<>
417 jni::ClassDesc jni::Object<CustomBlock>::desc = {
418     "fxslt/memory/TatooEngine$CustomBlock", {
419         { "<init>", "(J)V" }
420     }, {
421         { "value_ptr", "J" }
422     }
423 };
424
425 class CustomBlock: public jni::Object<CustomBlock> {
426     enum Methods { initID };
427     enum Fields { valueID };
428
429     static_assert(sizeof(jlong) <= sizeof(value *), "We use jlong to store pointers.");
430
431     value* get() const throw()
432     { return reinterpret_cast<value *>(jni::env().GetLongField(this_, class_->fields[valueID])); }
433 public:
434     CustomBlock(jobject this_) throw() : Base(this_) {   }
435
436     CustomBlock(value val) throw()
437         : Base(jni::env().NewObject(class_->class_, class_->methods[initID], new value(val)))
438     { caml_register_generational_global_root(get()); }
439
440     value getValue() const throw() { return *get(); }
441 };
442
443 JNIEXPORT void JNICALL
444 Java_fxslt_memory_TatooEngine_unregister  (JNIEnv *env, jclass cls, jlong value_ptr)
445 {
446         value * vptr = reinterpret_cast<value *>(value_ptr);
447         caml_remove_generational_global_root(vptr);
448         delete vptr;
449 }
450
451 static value *init_document;
452 static value *xpath_compile;
453 static value *auto_evaluate;
454
455 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *)
456 {
457     try {
458         jni::scoped_env se(vm);
459         jni::Integer::initialize();
460         Attr::initialize();
461         Node::initialize();
462         NodeList::initialize();
463         NamedNodeMap::initialize();
464         MutableNodeList::initialize();
465         CustomBlock::initialize();
466
467         char *argv[] = { NULL };
468         caml_startup(argv);
469
470         init_document = caml_named_value("init_document"); assert(init_document);
471         xpath_compile = caml_named_value("xpath_compile"); assert(xpath_compile);
472         auto_evaluate = caml_named_value("auto_evaluate"); assert(auto_evaluate);
473         caml_release_runtime_system();
474     }
475     catch(jint e) {
476         return e;
477     }
478
479     return jni::priv::JNI_VERSION;
480 }
481
482 JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *)
483 {
484     try {
485         jni::scoped_env se(vm);
486         MutableNodeList::finalize();
487         NamedNodeMap::finalize();
488         NodeList::finalize();
489         Node::finalize();
490         Attr::finalize();
491         jni::Integer::finalize();
492     }
493     catch(jint e) {
494         fprintf(stderr, "Critical error unloading shared library.\n");
495     }
496 }
497
498 static jobject extract(const value &val)
499 { return reinterpret_cast<jobject>(val); }
500
501 static value pack(jobject obj)
502 {
503 //    static_assert(sizeof(uintptr_t) <= sizeof(long), "We need long to hold pointers.");
504
505 //    uintptr_t p = reinterpret_cast<uintptr_t>(obj);
506 //    assert(! (p & 1));
507         return reinterpret_cast<value> (obj);
508 }
509
510
511 JNIEXPORT jobject JNICALL
512 Java_fxslt_memory_TatooEngine_init_1document(JNIEnv *env, jclass TatooEngine, jobject node, jint i)
513 {
514         CAMLparam0();
515         CAMLlocal1(val);
516         try {
517                 jni::scoped_env se(env);
518                 node = jni::env().NewGlobalRef(node);
519                 val = caml_callback2(*init_document, pack(node), Val_int(i));
520                 auto t = jni::env().NewGlobalRef(CustomBlock(val).this_);
521                 CAMLreturnT(auto, t);
522         } catch(jint e) {
523                 fprintf(stderr, "Critical error while initializing the document.\n");
524                 CAMLreturnT(jobject, NULL);
525     }
526
527 }
528
529
530 JNIEXPORT jobject JNICALL
531 Java_fxslt_memory_TatooEngine_compile(JNIEnv *env, jclass TatooEngine, jstring xpath)
532 {
533     CAMLparam0();
534     CAMLlocal1(val);
535
536     try {
537         jni::scoped_env se(env);
538
539         val = caml_callback(*xpath_compile, caml_copy_string(jni::String(xpath).c_str()));
540         auto a = CustomBlock(val).this_;
541         a = jni::env().NewGlobalRef(a);
542         CAMLreturnT(auto, a);
543     }
544     catch(jint e) {
545         fprintf(stderr, "Critical error while compiling.\n");
546         CAMLreturnT(jobject, NULL);
547     }
548 }
549
550 JNIEXPORT jobject JNICALL
551 Java_fxslt_memory_TatooEngine_evaluate(JNIEnv *env, jclass TatooEngine,
552                                        jobject automaton, jobject tree, jobject node_list)
553 {
554     CAMLparam0();
555     CAMLlocal4(res, vauto, vtree, vnl);
556
557     try {
558         jni::scoped_env se(env);
559         vauto = CustomBlock(automaton).getValue();
560         vtree = CustomBlock(tree).getValue();
561         vnl = pack(node_list);
562
563         res = caml_callback3(*auto_evaluate, vauto, vtree, vnl);
564         CAMLreturnT(auto, extract(res));
565     } catch(jint e) {
566         fprintf(stderr, "Critical error while evaluating.\n");
567         CAMLreturnT(jobject, NULL);
568     }
569 }
570
571 #define GR (node) (jni::env().NewGlobalRef((node)))
572
573 #if 0
574 #define CHECK_EXCEPTION()  do {                         \
575         if (jni::env().ExceptionCheck() == JNI_TRUE) {  \
576         jni::env().ExceptionDescribe();                 \
577         assert(false);                                  \
578         }                                               \
579         } while (0)
580 #else
581 #define CHECK_EXCEPTION()
582 #endif
583
584 extern "C" {
585 CAMLprim value node_getFirstChild(value node)
586 { CAMLparam1(node);
587         CHECK_EXCEPTION();
588
589         CAMLreturn(pack(Node::getFirstChildO(extract(node))));
590         //CAMLreturn(pack(Node(extract(node)).getFirstChild().this_));
591 }
592
593 CAMLprim value node_getNextSibling(value node)
594 {       CAMLparam1(node);
595         CHECK_EXCEPTION();
596         CAMLreturn(pack(Node::getNextSiblingO(extract(node))));
597         //CAMLreturn(pack(Node(extract(node)).getNextSibling().this_));
598 }
599
600 CAMLprim value node_getNodeType(value node)
601 { CAMLparam1(node);
602         CHECK_EXCEPTION();
603         CAMLreturn(Val_int(Node(extract(node)).getNodeType())); }
604
605 CAMLprim value node_getNodeName(value node)
606 { CAMLparam1(node);
607         CHECK_EXCEPTION();
608         jstring obj = Node(extract(node)).getNodeName().this_;
609         value cstr = caml_copy_string(jni::env().GetStringUTFChars(obj, NULL));
610         CAMLreturn(cstr);
611 }
612
613 CAMLprim value node_getPreorder(value node)
614 { CAMLparam1(node);
615         CHECK_EXCEPTION();
616         CAMLreturn(Val_int(Node(extract(node)).getPreorder()));
617  }
618
619 CAMLprim value node_setPreorder(value node, value i)
620 {
621         CAMLparam1(node);
622         CHECK_EXCEPTION();
623         Node(extract(node)).setPreorder(Int_val(i));
624         CAMLreturn(Val_unit);
625 }
626
627 CAMLprim value print_runtime_class(value o)
628 {
629    CAMLparam1(o);
630    CHECK_EXCEPTION();
631    jclass object = jni::env().FindClass("java/lang/Object");
632    jmethodID getClass_id = jni::env().GetMethodID(object, "getClass", "()Ljava/lang/Class;");
633    jobject oclass = jni::env().CallObjectMethod(extract(o), getClass_id);
634    jclass cls = jni::env().FindClass("java/lang/Class");
635    jmethodID getName_id = jni::env().GetMethodID(cls, "getName", "()Ljava/lang/String;");
636    jstring name = static_cast<jstring>(jni::env().CallObjectMethod(oclass, getName_id));
637    fprintf(stderr, "CLASS OF ATTTRIBUTE IS %s \n", jni::env().GetStringUTFChars(name, NULL));
638    fflush(stderr);
639    CAMLreturn(Val_unit);
640 }
641 CAMLprim value attr_getOwnerElement(value node)
642 {
643         CAMLparam1(node);
644         CHECK_EXCEPTION();
645         auto attr = Attr(extract(node));
646         CAMLreturn(pack(attr.getOwnerElement().this_));
647  }
648
649 CAMLprim value node_getAttributes(value node)
650 { CAMLparam1(node);
651         CHECK_EXCEPTION();
652         CAMLreturn(pack(Node(extract(node)).getAttributes().this_)); }
653
654 CAMLprim value nodelist_getLength(value list)
655 { CAMLparam1(list);
656         CHECK_EXCEPTION();
657         CAMLreturn(Val_int(NodeList(extract(list)).getLength())); }
658
659 CAMLprim value nodelist_item(value list, value idx)
660 {
661     CAMLparam2(list, idx);
662     CHECK_EXCEPTION();
663     CAMLreturn(pack(NodeList(extract(list)).item(Long_val(idx)).this_));
664 }
665
666 CAMLprim value nodelist_new(value list)
667 { CAMLparam1(list);
668  auto l = jni::env().NewGlobalRef(MutableNodeList().this_);
669  CAMLreturn(pack(l));
670 }
671
672 CAMLprim value nodelist_add(value list, value node)
673 {
674     CAMLparam2(list, node);
675     MutableNodeList(extract(list)).add(Node(extract(node)));
676     CAMLreturn(list);
677 }
678
679 CAMLprim value namednodemap_getLength(value list)
680 { CAMLparam1(list); CAMLreturn(Val_int(NamedNodeMap(extract(list)).getLength())); }
681
682 CAMLprim value namednodemap_item(value list, value idx)
683 {
684     CAMLparam2(list, idx);
685     CAMLreturn(pack(NamedNodeMap(extract(list)).item(Long_val(idx)).this_));
686 }
687
688 CAMLprim value getNull(value unit) { CAMLparam1(unit); CAMLreturn((value) NULL); }
689
690 CAMLprim value dereference_object (value obj)
691 {
692         CAMLparam1(obj);
693         jni::env().DeleteGlobalRef(reinterpret_cast<jobject>(obj));
694         CAMLreturn(Val_unit);
695 }
696 }