Chad's Blog

MathML inside Javadoc Using MathJax and a Custom Taglet

Posted in Java by chadretz on December 19, 2010

I am writing a baseball statistics application called StatMantis. It’s another in a long line of personal projects I’ll probably never finish. It contains many statistics whose equations deserve to be in the javadocs. After researching all the possible ways to accomplish this, I settled on MathML. For example, my Batting Average On Balls In Play (BABIP) statistic has a formula, and I wanted one of those badass formula displays like Wikipedia has. So I had javadoc on the class that looked like this:

/**
 * This calculates Batting Average on Balls In Play (BABIP). See the Wikipedia reference
 * <a href="http://en.wikipedia.org/wiki/Batting_average_on_balls_in_play">here</a>. 
 * This is a pitching and batting statistic. This is calculated as:
 * <p>
 * <math>
 *     <mfrac>
 *         <mrow>
 *             <mi>H</mi>
 *             <mo>-</mo>
 *             <mi>HR</mi>
 *         </mrow>
 *         <mrow>
 *             <mi>AB</mi>
 *             <mo>-</mo>
 *             <mi>K</mi>
 *             <mo>-</mo>
 *             <mi>HR</mi>
 *             <mo>+</mo>
 *             <mi>SF</mi>
 *         </mrow>
 *     </mfrac>
 * </math>
 */
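The formula itself is plain arithmetic, BABIP = (H − HR) / (AB − K − HR + SF). As a sanity check on what the MathML is meant to render, here is a minimal sketch of the same calculation in Java. The `Babip` class and its `babip` method are hypothetical illustrations, not StatMantis’s actual API, and rejecting a non-positive denominator is an added assumption about how to handle a line with no balls in play.

```java
/**
 * Hypothetical helper: BABIP = (H - HR) / (AB - K - HR + SF).
 * Not the StatMantis implementation, just the formula from the javadoc above.
 */
final class Babip {
    private Babip() {
    }

    static double babip(int hits, int homeRuns, int atBats,
            int strikeouts, int sacrificeFlies) {
        // Denominator: at bats that put a ball in play (no K, no HR), plus sac flies
        int ballsInPlay = atBats - strikeouts - homeRuns + sacrificeFlies;
        if (ballsInPlay <= 0) {
            // Assumption: treat "no balls in play" as invalid input
            throw new IllegalArgumentException("No balls in play: " + ballsInPlay);
        }
        return (double) (hits - homeRuns) / ballsInPlay;
    }

    public static void main(String[] args) {
        // Made-up season line: 150 H, 20 HR, 500 AB, 100 K, 5 SF
        System.out.println(babip(150, 20, 500, 100, 5)); // 130 / 385, about 0.338
    }
}
```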

I noticed several projects out there that I could use to put a MathML equation in my javadoc. JEuclid was my first choice. It could render to an AWT image, and it could even give me the information needed to generate an image map linking to the other factors in the equation. But I would have had to hack up the standard Sun (…er…Oracle) doclet and do mass hackery for this simple thing, and since the javadoc tool is difficult to extend, I decided it wasn’t worth the effort. Hopefully one day a project like javadoc-ng will get finished and solve my problems (wink wink, sorry if you’re still waiting on me, Harmony guys). When I couldn’t find anything I liked on the MathML implementations page either, I almost gave up: none of them were both interactive and cross-browser. I specifically wanted the href attribute from the MathML 3 spec so I could link to my other classes.

Then I found MathJax, and it looked like it would solve all my problems. The first thing I did was load it from the <footer> element of the javadoc Ant task. The distribution is extremely large, and yes, you need just about all of it. So I copied the MathJax directory into the root of my javadoc output and edited config/MathJax.js: I changed jax: ["input/TeX","output/HTML-CSS"] to jax: ["input/MathML", "output/HTML-CSS"] and extensions: ["tex2jax.js"] to extensions: ["mml2jax.js"]. This is needed because the input is MathML, not TeX. Once the config was changed, I added the following to my javadoc Ant task:

<footer><![CDATA[
    <script type=\"text/javascript\" src=\"{@docRoot}/MathJax/MathJax.js\"></script>
]]></footer>

I put this in the footer because the header seems to be rendered twice, which isn’t cool. Also, per the javadoc footer documentation, I have to escape the quotes. {@docRoot} makes sure the relative path to MathJax.js is correct from every page, however deep in the package tree it sits. Once I ran this, I was very happy to see my formula appear properly.

Next I wanted to link each part of the equation to its corresponding class. My first approach was to use the inline {@link} tag, which outputs a code tag wrapped in an anchor tag. I wrote at least a dozen different JavaScript functions to pull the href out of the anchor, put it on the <mi> tag, and move the text out from inside the code tag. Everything I tried ended in extreme failure because IE sucks balls. Specifically, I can’t edit my math tags via the DOM because IE doesn’t understand them, and I couldn’t insert a manipulated MathML string as innerHTML on the parent either, because IE trimmed off pieces for no apparent reason.

So I decided I needed another, non-client-side approach to links. Post-processing the HTML was my first idea, but that is not a very elegant solution, so I decided to extend Javadoc with a custom Taglet. I tried to find the existing Taglet that Sun built for {@link} so I could reuse its algorithm for obtaining the URL, but there isn’t one: the standard doclet handles {@link} itself, with the benefit of having everything available, including the RootDoc. So I wrote my own Taglet and tried it in many scenarios. I quickly realized I would not be able to link to all possible methods and fields, because the Taglet interface simply doesn’t provide enough information. For the same reason, my taglet can’t validate the values entered in it either. Without further ado, here is the custom Taglet:

/*
 * Copyright 2010 Chad Retz
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.statmantis.tools.javadoc;

import java.util.Map;

import com.sun.javadoc.ClassDoc;
import com.sun.javadoc.Doc;
import com.sun.javadoc.PackageDoc;
import com.sun.javadoc.ProgramElementDoc;
import com.sun.javadoc.Tag;
import com.sun.tools.doclets.Taglet;

/**
 * Taglet that supports the linkhref inline tag. This tag will return just the
 * HREF to a javadoc class file (not any of the methods/fields)
 * 
 * @author Chad Retz
 */
public class LinkHrefTaglet implements Taglet {

    @SuppressWarnings({ "rawtypes", "unchecked" })
    public static void register(Map tagletMap) {
        LinkHrefTaglet tag = new LinkHrefTaglet();
        Taglet t = (Taglet) tagletMap.get(tag.getName());
        if (t != null) {
            tagletMap.remove(tag.getName());
        }
        tagletMap.put(tag.getName(), tag);
    }

    @Override
    public String getName() {
        return "linkhref";
    }

    @Override
    public boolean inConstructor() {
        return true;
    }

    @Override
    public boolean inField() {
        return true;
    }

    @Override
    public boolean inMethod() {
        return true;
    }

    @Override
    public boolean inOverview() {
        return true;
    }

    @Override
    public boolean inPackage() {
        return true;
    }

    @Override
    public boolean inType() {
        return true;
    }

    @Override
    public boolean isInlineTag() {
        return true;
    }

    private PackageDoc getPackageDoc(Tag tag) {
        Doc holder = tag.holder();
        if (holder instanceof ProgramElementDoc) {
            return ((ProgramElementDoc) holder).containingPackage();
        } else if (holder instanceof PackageDoc) {
            return (PackageDoc) holder;
        } else {
            throw new RuntimeException("Unrecognized holder: " + holder);
        }
    }

    private ClassDoc getTopLevelClassDoc(ClassDoc classDoc) {
        if (classDoc.containingClass() == null) {
            return classDoc;
        } else {
            return getTopLevelClassDoc(classDoc.containingClass());
        }
    }

    private ClassDoc getTopLevelClassDoc(Tag tag) {
        Doc holder = tag.holder();
        if (holder instanceof PackageDoc) {
            return null;
        } else if (holder instanceof ClassDoc) {
            return getTopLevelClassDoc((ClassDoc) holder);
        } else if (holder instanceof ProgramElementDoc) {
            return getTopLevelClassDoc(((ProgramElementDoc) holder)
                    .containingClass());
        } else {
            throw new RuntimeException("Unrecognized holder: " + holder);
        }
    }

    private ClassDoc findClass(String className, ClassDoc[] classImports) {
        for (ClassDoc classDoc : classImports) {
            if (classDoc.name().equals(className)) {
                return classDoc;
            }
        }
        return null;
    }

    private ClassDoc findClass(String className, PackageDoc... packageImports) {
        for (PackageDoc packageDoc : packageImports) {
            for (ClassDoc found : packageDoc.allClasses(true)) {
                if (found.name().equals(className)) {
                    return found;
                }
            }
        }
        return null;
    }
    
    private String error(Tag tag, String error) {
        System.err.println(tag.position() + ": warning - " + error);
        return "javascript: //error";
    }

    @Override
    @SuppressWarnings("deprecation")
    public String toString(Tag tag) {
        PackageDoc packageDoc = getPackageDoc(tag);
        ClassDoc topLevelClassDoc = getTopLevelClassDoc(tag);
        //k, what I'm gonna do is what the main one does...go up to the root
        StringBuilder href = new StringBuilder();
        int dotIndex = packageDoc.name().indexOf('.');
        while (dotIndex != -1) {
            href.append("../");
            dotIndex = packageDoc.name().indexOf('.', dotIndex + 1);
        }
        //package name is empty when it is the root package
        if (!packageDoc.name().isEmpty()) {
            href.append("../");
        }
        //now that we have the root, begin the string parse
        String classInTag = tag.text();
        int poundIndex = classInTag.indexOf('#');
        if (poundIndex != -1) {
            classInTag = classInTag.substring(0, poundIndex);
        }
        //ok, if it's qualified, we just assume it's all good
        if (classInTag.indexOf('.') == -1) {
            ClassDoc classDoc;
            if (topLevelClassDoc == null) {
                //not in a class scope? just try inside this package
                classDoc = findClass(classInTag, packageDoc);
                if (classDoc == null) {
                    //they should qualify it then
                    return error(tag, "Can't locate linkhref class " + classInTag + 
                            ". The name should be qualified.");
                }
            } else {
                //nope? ok, first try my inner classes
                classDoc = findClass(classInTag, topLevelClassDoc.innerClasses(true));
                if (classDoc == null) {
                    //nope? ok, try my single-type-imports
                    classDoc = findClass(classInTag,
                            topLevelClassDoc.importedClasses());
                    if (classDoc == null) {
                        //nope? ok, try my type-import-on-demands
                        classDoc = findClass(classInTag, topLevelClassDoc.importedPackages());
                        if (classDoc == null) {
                            //nope? ok, finally try my own package
                            classDoc = findClass(classInTag, topLevelClassDoc.containingPackage());
                            if (classDoc == null) {
                                //not even now? well, just assume it's there because
                                //  javadoc doesn't populate fairly
                                classInTag = topLevelClassDoc.containingPackage().name() +
                                        '.' + classInTag;
                            }
                        }
                    }
                }
            }
            if (classDoc != null) {
                classInTag = classDoc.qualifiedName();
            }
        }
        if (classInTag.indexOf('.') == -1) {
            return error(tag, "Unable to get linkhref for class " + classInTag +
                    " because it is in the root package");
        }
        // ok, now make the link by replacing the dots w/ slashes
        href.append(classInTag.replace('.', '/'));
        // add .html
        href.append(".html");
        // all good
        return href.toString();
    }

    @Override
    public String toString(Tag[] tags) {
        // not for inline tags...nope
        return null;
    }

}
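Stripped of the javadoc API, the name lookup in toString() is an ordered search with a fallback. The sketch below models that order with plain maps standing in for the ClassDoc lookups (inner classes, single-type imports, on-demand imports, own package). `LinkHrefResolution` and the `org.example` packages are hypothetical, and it only covers a tag that sits inside a class; in the taglet, an own-package hit and the final "assume it’s in my package" fallback produce the same qualified name, so the sketch folds them together.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/** Hypothetical model of the class-name resolution in LinkHrefTaglet. */
final class LinkHrefResolution {
    private LinkHrefResolution() {
    }

    /**
     * @param reference  tag text, e.g. "Hits" or "Hits#value()"
     * @param scopes     simple name to qualified name, checked in order
     * @param ownPackage package of the enclosing top-level class
     */
    static String resolve(String reference, List<Map<String, String>> scopes,
            String ownPackage) {
        String name = reference;
        int pound = name.indexOf('#');
        if (pound != -1) {
            name = name.substring(0, pound); // members are dropped, only classes link
        }
        if (name.indexOf('.') != -1) {
            return name; // already qualified: trusted without validation
        }
        for (Map<String, String> scope : scopes) {
            String qualified = scope.get(name);
            if (qualified != null) {
                return qualified;
            }
        }
        // Nothing matched: assume the class lives in the enclosing package
        return ownPackage + '.' + name;
    }

    public static void main(String[] args) {
        Map<String, String> innerClasses = Collections.emptyMap();
        Map<String, String> singleTypeImports =
                Collections.singletonMap("Hits", "org.example.stats.Hits");
        List<Map<String, String>> scopes =
                Arrays.asList(innerClasses, singleTypeImports);
        System.out.println(resolve("Hits#value()", scopes, "org.example.calc"));
        // org.example.stats.Hits
        System.out.println(resolve("Walks", scopes, "org.example.calc"));
        // org.example.calc.Walks
    }
}
```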

It works only for class and interface references and doesn’t do any real validation. Regardless, it solves my problem perfectly: my mathematical formulas now appear in my javadoc, complete with links to other classes. Check out the build-javadoc target in the Ant script to see how to register it with the javadoc task. Overall, it works well and I am happy with it. Here is what the aforementioned BABIP javadoc looks like now:

/**
 * This calculates Batting Average on Balls In Play (BABIP). See the Wikipedia reference
 * <a href="http://en.wikipedia.org/wiki/Batting_average_on_balls_in_play">here</a>. 
 * This is a pitching and batting statistic. This is calculated as:
 * <p>
 * <math style="font-size: 200%">
 *     <mfrac>
 *         <mrow>
 *             <mi href="{@linkhref Hits}">H</mi>
 *             <mo>-</mo>
 *             <mi href="{@linkhref HomeRuns}">HR</mi>
 *         </mrow>
 *         <mrow>
 *             <mi href="{@linkhref AtBats}">AB</mi>
 *             <mo>-</mo>
 *             <mi href="{@linkhref Strikeouts}">K</mi>
 *             <mo>-</mo>
 *             <mi href="{@linkhref HomeRuns}">HR</mi>
 *             <mo>+</mo>
 *             <mi href="{@linkhref SacrificeFlies}">SF</mi>
 *         </mrow>
 *     </mfrac>
 * </math>
 */
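Each linkhref tag above expands to a relative URL: one "../" per segment of the current page’s package to climb back to the javadoc root, then the qualified class name with dots turned into slashes, then ".html". Here is a minimal sketch of just that path arithmetic; `LinkHrefPath` is hypothetical, and so is the `org.statmantis.stat` package, since the real package of Hits isn’t shown here.

```java
/** Hypothetical sketch of the relative-path part of LinkHrefTaglet.toString(). */
final class LinkHrefPath {
    private LinkHrefPath() {
    }

    /** One "../" per package segment: the way back up to the javadoc root. */
    static String rootPrefix(String currentPackage) {
        if (currentPackage.isEmpty()) {
            return ""; // pages for the unnamed package already sit at the root
        }
        int segments = currentPackage.split("\\.").length;
        StringBuilder prefix = new StringBuilder();
        for (int i = 0; i < segments; i++) {
            prefix.append("../");
        }
        return prefix.toString();
    }

    /** Root prefix, then the qualified class name as a path, then ".html". */
    static String href(String currentPackage, String qualifiedClassName) {
        return rootPrefix(currentPackage)
                + qualifiedClassName.replace('.', '/') + ".html";
    }

    public static void main(String[] args) {
        System.out.println(href("org.statmantis.stat", "org.statmantis.stat.Hits"));
        // ../../../org/statmantis/stat/Hits.html
    }
}
```

One caveat worth checking: replacing every dot also turns a nested class such as pkg.Outer.Inner into pkg/Outer/Inner.html, while the standard doclet writes nested classes as Outer.Inner.html, so this scheme only links top-level classes correctly.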

All of this is licensed under the Apache License 2.0, which makes it commercially friendly.

STL collections with Java and SWIG

Posted in C++, Java, SWIG by chadretz on November 27, 2009

When using SWIG with Java, I quickly realized that there was no support for std::set or std::list and only minimal support for std::map and std::vector (i.e. no proper iteration). Other languages (Python, Ruby, etc.) have std_set.i and std_list.i, whereas Java does not. I am not a C++ expert, and definitely not an expert at writing SWIG typemaps.

My only solution was to write C++ wrappers for these collections, plus iterator wrappers for some of them so they could back their Java counterparts. They are listed below.

ListWrapper.h

#pragma once
#include <list>

template<class T>
class ListWrapper
{
public:
	std::list<T>* _list;
	ListWrapper(std::list<T>* original)
	{
		this->_list = original;
	}

	~ListWrapper()
	{
	}

	int size()
	{
		return this->_list->size();
	}

	bool contains(T item)
	{
		for(typename std::list<T>::iterator iter = this->_list->begin(); 
				iter != this->_list->end(); iter++) {
			if (*iter == item) {
				return true;
			}
		}
		return false;
	}

	bool add(T item)
	{
		this->_list->push_back(item);
		return true;
	}

	void clear()
	{
		this->_list->clear();
	}

	bool remove(T item)
	{
		int size = this->_list->size();
		this->_list->remove(item);
		return size != this->_list->size();
	}
};

template<class T>
class ListIterator
{
private:
	std::list<T>* _list;
	typename std::list<T>::const_iterator _iter;
public:
	ListIterator(std::list<T>* original)
	{
		this->_list = original;
		this->_iter = this->_list->begin();
	}

	bool hasNext()
	{
		return this->_iter != this->_list->end();
	}

	T next()
	{
		T ret = (T) *this->_iter;
		this->_iter++;
		return ret;
	}
};

MapWrapper.h

#pragma once
#include <map>

template<class K, class V>
class MapWrapper
{
private:
	std::map<K, V>* _map;
public:
	MapWrapper(std::map<K, V>* original)
	{
		this->_map = original;
	}

	~MapWrapper()
	{
	}

	int size()
	{
		return this->_map->size();
	}

	bool add(K key, V value)
	{
		bool present = this->_map->find(key) != this->_map->end();
		(*this->_map)[key] = value;
		return present;
	}

	void clear()
	{
		this->_map->clear();
	}

	bool remove(K key)
	{
		typename std::map<K, V>::iterator iter = this->_map->find(key);
		if (iter != this->_map->end()) {
			this->_map->erase(iter);
			return true;
		} else {
			return false;
		}
	}
};

template<class K, class V>
class MapIterator
{
private:
	std::map<K, V>* _map;
	typename std::map<K, V>::iterator _iter;
	std::pair<K, V> _current;
public:
	MapIterator(std::map<K, V>* original)
	{
		this->_map = original;
		this->_iter = this->_map->begin();
	}

	bool hasNext()
	{
		return this->_iter != this->_map->end();
	}

	void next()
	{
		this->_current = *this->_iter;
		this->_iter++;
	}

	K getKey()
	{
		return this->_current.first;
	}

	V getValue()
	{
		return this->_current.second;
	}
};

SetWrapper.h

#pragma once
#include <set>

template<class T>
class SetWrapper
{
private:
	std::set<T>* _set;
public:
	SetWrapper(std::set<T>* original)
	{
		this->_set = original;
	}

	~SetWrapper()
	{
	}

	int size()
	{
		return this->_set->size();
	}

	bool contains(T item)
	{
		typename std::set<T>::const_iterator iter = this->_set->find(item);
		return iter != this->_set->end();
	}

	bool add(T item)
	{
		return this->_set->insert(item).second;
	}

	void clear()
	{
		this->_set->clear();
	}

	bool remove(T item)
	{
		typename std::set<T>::iterator iter = this->_set->find(item);
		if (iter != this->_set->end()) {
			this->_set->erase(iter);
			return true;
		} else {
			return false;
		}
	}
};

template<class T>
class SetIterator
{
private:
	std::set<T>* _set;
	typename std::set<T>::iterator _iter;
public:
	SetIterator(std::set<T>* original)
	{
		this->_set = original;
		this->_iter = this->_set->begin();
	}

	bool hasNext()
	{
		return this->_iter != this->_set->end();
	}

	T next()
	{
		T ret = (T) *this->_iter;
		this->_iter++;
		return ret;
	}
};

VectorWrapper.h

#pragma once
#include <vector>

template<class T>
class VectorWrapper
{
private:
	std::vector<T>* _vector;
public:
	VectorWrapper(std::vector<T>* original)
	{
		this->_vector = original;
	}

	~VectorWrapper()
	{
	}

	int size()
	{
		return this->_vector->size();
	}

	void add(int index, T item)
	{
		this->_vector->insert(this->_vector->begin() + index, item);
	}

	void clear()
	{
		this->_vector->clear();
	}

	T set(int index, T item)
	{
		T ret = this->get(index);
		(*this->_vector)[index] = item;
		return ret;
	}

	T remove(int index)
	{
		T item = this->_vector->at(index);
		this->_vector->erase(this->_vector->begin() + index);
		return item;
	}

	T get(int index)
	{
		return this->_vector->at(index);
	}
};

Basically, all these do is expose SWIG-readable methods for the underlying collections. Again, I am not a C++ expert, and I only have minimal test cases proving to myself that they work. These headers should be “%import”ed in your SWIG interface.

Now I need to wrap these in normal Java collections. I expose std::map, std::set, and std::vector in Java as a Map, Set, and List respectively. std::list has no random access, so I map it to a Set as well. Note that this violates the Java Set contract, which assumes all elements are unique. I suppose I could have used a Queue or an AbstractSequentialList, but I’ll stick with Set for now (iteration is still in order).
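To see concretely what the contract violation means, here is a small sketch using a hypothetical `ListBackedSet`: a Set view over an ordered list that may contain duplicates, which is roughly the situation NativeSet is in over a std::list. AbstractSet.equals compares sizes before contents, so the view is not equal to a HashSet holding the same distinct elements, even though each contains all of the other’s elements.

```java
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/** Hypothetical Set view over a list that may hold duplicates. */
final class ListBackedSet<T> extends AbstractSet<T> {
    private final List<T> backing;

    ListBackedSet(Collection<T> items) {
        this.backing = new ArrayList<T>(items); // keeps order and duplicates
    }

    @Override
    public Iterator<T> iterator() {
        return backing.iterator();
    }

    @Override
    public int size() {
        return backing.size(); // counts duplicates, unlike a real Set
    }

    public static void main(String[] args) {
        List<String> items = Arrays.asList("a", "a", "b");
        Set<String> fromList = new ListBackedSet<String>(items);
        Set<String> real = new HashSet<String>(items);
        System.out.println(fromList.size());       // 3
        System.out.println(real.size());           // 2
        System.out.println(fromList.equals(real)); // false
        System.out.println(fromList.containsAll(real) && real.containsAll(fromList)); // true
    }
}
```

Iteration still follows the list’s order, which is why the trade-off is tolerable for read-mostly use.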

This means I needed three Java implementations: NativeList (for vectors), NativeMap (for maps), and NativeSet (for sets and lists). They are listed below.

NativeList.java

package org.cretz.swig.collection;

import java.lang.reflect.Method;
import java.util.AbstractList;
import java.util.List;

/**
 * Wrapper for std::vector from SWIG
 * 
 * @author Chad Retz
 *
 * @param <T>
 */
public class NativeList<T> extends AbstractList<T> implements List<T> {

	private final Object listWrapper;
	private final Method sizeMethod;
	private final Method addMethod;
	private final Method clearMethod;
	private final Method setMethod;
	private final Method removeMethod;
	private final Method getMethod;
	
	/**
	 * Construct native list from std::vector wrappers
	 * 
	 * @param nativeClass The native class
	 * @param nativeList The SWIG vector
	 * @param listWrapperClass The SWIG vector class wrapper
	 */
	public NativeList(Class<T> nativeClass, Object nativeList, Class<?> listWrapperClass) {
		try {
			listWrapper = listWrapperClass.getConstructor(nativeList.getClass()).
					newInstance(nativeList);
			sizeMethod = listWrapperClass.getDeclaredMethod("size");
			addMethod = listWrapperClass.getDeclaredMethod("add", 
					Integer.TYPE, nativeClass);
			clearMethod = listWrapperClass.getDeclaredMethod("clear");
			setMethod = listWrapperClass.getDeclaredMethod("set", 
					Integer.TYPE, nativeClass);
			removeMethod = listWrapperClass.getDeclaredMethod("remove", Integer.TYPE);
			getMethod = listWrapperClass.getDeclaredMethod("get", Integer.TYPE);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}
	
	@Override
	public void add(int index, T item) {
		try {
			addMethod.invoke(listWrapper, index, item);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}
	
	@Override
	public void clear() {
		try {
			clearMethod.invoke(listWrapper);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}
	
	@Override
	@SuppressWarnings("unchecked")
	public T set(int index, T item) {
		try {
			return (T) setMethod.invoke(listWrapper, index, item);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}		
	}

	@Override
	@SuppressWarnings("unchecked")
	public T remove(int index) {
		try {
			return (T) removeMethod.invoke(listWrapper, index);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}
	
	@Override
	public boolean remove(Object item) {
		// removeMethod is the wrapper's remove(int), so find the index first
		int index = indexOf(item);
		if (index == -1) {
			return false;
		}
		remove(index);
		return true;
	}

	@Override
	@SuppressWarnings("unchecked")
	public T get(int index) {
		try {
			return (T) getMethod.invoke(listWrapper, index);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public int size() {
		try {
			return (Integer) sizeMethod.invoke(listWrapper);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}
}
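The constructor above finds SWIG’s generated methods by name and exact parameter types. One detail that is easy to trip over: SWIG maps a C int parameter to a primitive Java int, so the lookup has to use Integer.TYPE, not Integer.class. This sketch shows both lookups plus the wrap-and-invoke steps, run against hypothetical hand-written stand-ins (`FakeStringVector`, `FakeStringVectorWrapper`) rather than real SWIG output.

```java
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the SWIG proxy of std::vector<std::string>.
class FakeStringVector {
    final List<String> items = new ArrayList<String>();
}

// Hypothetical stand-in for the proxy SWIG would generate from
// VectorWrapper<std::string>: a constructor taking the vector proxy and
// methods with a primitive int index.
class FakeStringVectorWrapper {
    private final FakeStringVector vector;

    public FakeStringVectorWrapper(FakeStringVector vector) {
        this.vector = vector;
    }

    public void add(int index, String item) {
        vector.items.add(index, item);
    }
}

final class ReflectionLookupDemo {
    private ReflectionLookupDemo() {
    }

    /** True if wrapperClass declares add(indexType, String). */
    static boolean hasAdd(Class<?> wrapperClass, Class<?> indexType) {
        try {
            wrapperClass.getDeclaredMethod("add", indexType, String.class);
            return true;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(hasAdd(FakeStringVectorWrapper.class, Integer.TYPE));  // true
        System.out.println(hasAdd(FakeStringVectorWrapper.class, Integer.class)); // false

        // The same steps NativeList's constructor and add() perform
        FakeStringVector nativeList = new FakeStringVector();
        Object wrapper = FakeStringVectorWrapper.class
                .getConstructor(nativeList.getClass()).newInstance(nativeList);
        Method add = FakeStringVectorWrapper.class
                .getDeclaredMethod("add", Integer.TYPE, String.class);
        add.invoke(wrapper, 0, "b"); // 0 is boxed here and unboxed by invoke
        add.invoke(wrapper, 0, "a");
        System.out.println(nativeList.items); // [a, b]
    }
}
```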

NativeMap.java

package org.cretz.swig.collection;

import java.lang.reflect.Method;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Wrapper for std::map from SWIG. This is basically
 * the same as std::set and impl w/ {@link org.cretz.swig.collection.NativeSet}
 * 
 * @author Chad Retz
 *
 * @param <K>
 * @param <V>
 */
public class NativeMap<K, V> extends AbstractMap<K, V> implements Map<K, V> {

	private final Class<?> mapIteratorClass;
	private final Object nativeMap;
	private final Object mapWrapper;
	private final Method sizeMethod;
	private final Method addMethod;
	private final Method clearMethod;
	private final Method removeMethod;
	private final NativeMapSet nativeMapSet; 
	private final Method hasNextMethod;
	private final Method nextMethod;
	private final Method keyMethod;
	private final Method valueMethod;
	
	/**
	 * Construct native map from std::map wrappers
	 * 
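	 * @param nativeKeyClass The native key class
	 * @param nativeValueClass The native value class
	 * @param mapIteratorClass The SWIG map iterator class wrapper
	 * @param nativeMap The SWIG map
	 * @param mapWrapperClass The SWIG map class wrapper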
	 */
	public NativeMap(Class<K> nativeKeyClass, Class<V> nativeValueClass, 
			Class<?> mapIteratorClass, Object nativeMap,
			Class<?> mapWrapperClass) {
		this.mapIteratorClass = mapIteratorClass;
		this.nativeMap = nativeMap;
		try {
			mapWrapper = mapWrapperClass.getConstructor(nativeMap.getClass()).
					newInstance(nativeMap);
			sizeMethod = mapWrapperClass.getDeclaredMethod("size");
			addMethod = mapWrapperClass.getDeclaredMethod("add", 
					nativeKeyClass, nativeValueClass);
			clearMethod = mapWrapperClass.getDeclaredMethod("clear");
			removeMethod = mapWrapperClass.getDeclaredMethod("remove", 
					nativeKeyClass);
			hasNextMethod = mapIteratorClass.getDeclaredMethod("hasNext");
			nextMethod = mapIteratorClass.getDeclaredMethod("next");
			keyMethod = mapIteratorClass.getDeclaredMethod("getKey");
			valueMethod = mapIteratorClass.getDeclaredMethod("getValue");
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
		nativeMapSet = new NativeMapSet();
	}
	
	@Override
	public Set<Entry<K, V>> entrySet() {
		return nativeMapSet;
	}
	
	@Override
	public V put(K key, V value) {
		V old = get(key);
		entrySet().add(new NativeEntry(key, value));
		return old;
	}

	@Override
	@SuppressWarnings("unchecked")
	public V remove(Object key) {
		V old = get(key);
		if (old != null) {
			entrySet().remove(new NativeEntry((K) key, old));
		}
		return old;
	}

	@Override
	public int size() {
		return entrySet().size();
	}
	
	protected class NativeEntry implements Entry<K, V> {
		private final K key;
		private V value;
		
		protected NativeEntry(K key, V value) {
			this.key = key;
			this.value = value;
		}
		
		@Override
		public K getKey() {
			return key;
		}
		
		@Override
		public V getValue() {
			return value;
		}
		
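		@Override
		public V setValue(V value) {
			V old = this.value;
			this.value = value;
			put(key, value);
			return old;
		}
		
		// Map.Entry requires equals/hashCode based on key and value
		@Override
		public boolean equals(Object o) {
			if (!(o instanceof Entry<?, ?>)) {
				return false;
			}
			Entry<?, ?> other = (Entry<?, ?>) o;
			return (key == null ? other.getKey() == null : key.equals(other.getKey())) &&
					(value == null ? other.getValue() == null : value.equals(other.getValue()));
		}
		
		@Override
		public int hashCode() {
			return (key == null ? 0 : key.hashCode()) ^
					(value == null ? 0 : value.hashCode());
		}
	}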
	
	protected class NativeMapSet extends AbstractSet<Entry<K, V>> 
			implements Set<Entry<K, V>> {
		
		@Override
		public boolean add(Entry<K, V> item) {
			try {
				return (Boolean) addMethod.invoke(mapWrapper, 
						item.getKey(), item.getValue());
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}

		@Override
		public void clear() {
			try {
				clearMethod.invoke(mapWrapper);
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}

		@Override
		public Iterator<Entry<K, V>> iterator() {
			return new NativeMapSetIterator();
		}

		@Override
		public boolean remove(Object item) {
			try {
				if (item instanceof Entry<?, ?>) {
					return (Boolean) removeMethod.invoke(mapWrapper, 
							((Entry<?, ?>)item).getKey());
				}
				return false;
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}

		@Override
		public boolean removeAll(Collection<?> collection) {
			boolean modified = false;
			for (Object item : collection) {
				modified |= remove(item);
			}
			return modified;
		}
		
		@Override
		public boolean retainAll(Collection<?> collection) {
			//best way?
			List<Entry<K, V>> toRemove = new ArrayList<Entry<K, V>>(this.size());
			for (Entry<K, V> item : this) {
				if (!collection.contains(item)) {
					toRemove.add(item);
				}
			}
			return removeAll(toRemove);
		}

		@Override
		public int size() {
			try {
				return (Integer) sizeMethod.invoke(mapWrapper);
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}
	}
	
	protected class NativeMapSetIterator implements Iterator<Entry<K, V>> {

		private final Object setIterator;
		
		private NativeMapSetIterator() {
			try {
				setIterator = mapIteratorClass.getConstructor(
						nativeMap.getClass()).newInstance(nativeMap);
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}
		
		@Override
		public boolean hasNext() {
			try {
				return (Boolean) hasNextMethod.invoke(setIterator);
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}

		@Override
		@SuppressWarnings("unchecked")
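		public NativeEntry next() {
			// advancing past end() is undefined in C++, so stop here instead
			if (!hasNext()) {
				throw new java.util.NoSuchElementException();
			}
			try {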
				nextMethod.invoke(setIterator);
				return new NativeEntry((K) keyMethod.invoke(setIterator), 
						(V) valueMethod.invoke(setIterator));
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}

		/**
		 * {@inheritDoc}
		 * <p>
		 * Unsupported
		 */
		@Override
		public void remove() {
			throw new UnsupportedOperationException();
		}
		
	}
}

NativeSet.java

package org.cretz.swig.collection;

import java.lang.reflect.Method;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
 * Wrapper for std::set and std::list from SWIG. This does
 * not support removal during iteration.
 * 
 * @author Chad Retz
 * 
 * @param <T>
 */
public class NativeSet<T> extends AbstractSet<T> implements Set<T> {

	private final Class<T> nativeClass;
	private final Class<?> setIteratorClass;
	private final Object nativeSet;
	private final Object setWrapper;
	private final Method sizeMethod;
	private final Method containsMethod;
	private final Method addMethod;
	private final Method clearMethod;
	private final Method removeMethod;
	private final Method hasNextMethod;
	private final Method nextMethod;
	
	/**
	 * Instantiate the native set
	 * 
	 * @param nativeClass The native class
	 * @param setIteratorClass The class for the SetIterator
	 * @param nativeSet The native set object
	 * @param setWrapperClass The class for the Set
	 */
	public NativeSet(Class<T> nativeClass, 
			Class<?> setIteratorClass, Object nativeSet,
			Class<?> setWrapperClass) {
		this.nativeClass = nativeClass;
		this.setIteratorClass = setIteratorClass;
		this.nativeSet = nativeSet;
		try {
			setWrapper = setWrapperClass.getConstructor(nativeSet.getClass()).
					newInstance(nativeSet);
			sizeMethod = setWrapperClass.getDeclaredMethod("size");
			containsMethod = setWrapperClass.getDeclaredMethod("contains", 
					nativeClass);
			addMethod = setWrapperClass.getDeclaredMethod("add", 
					nativeClass);
			clearMethod = setWrapperClass.getDeclaredMethod("clear");
			removeMethod = setWrapperClass.getDeclaredMethod("remove", 
					nativeClass);
			hasNextMethod = setIteratorClass.getDeclaredMethod("hasNext");
			nextMethod = setIteratorClass.getDeclaredMethod("next");
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}
	
	@Override
	public boolean add(T item) {
		try {
			return (Boolean) addMethod.invoke(setWrapper, item);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public void clear() {
		try {
			clearMethod.invoke(setWrapper);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public boolean contains(Object item) {
		try {
			// isInstance() is false for null, so contains(null) returns false
			return nativeClass.isInstance(item) &&
					(Boolean) containsMethod.invoke(setWrapper, item);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public Iterator<T> iterator() {
		return new NativeSetIterator();
	}

	@Override
	public boolean remove(Object item) {
		try {
			// skip the native call for null or foreign types
			return nativeClass.isInstance(item) &&
					(Boolean) removeMethod.invoke(setWrapper, item);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public boolean removeAll(Collection<?> collection) {
		boolean modified = false;
		for (Object item : collection) {
			modified |= remove(item);
		}
		return modified;
	}
	
	@Override
	public boolean retainAll(Collection<?> collection) {
		//best way?
		List<T> toRemove = new ArrayList<T>(this.size());
		for (T item : this) {
			if (!collection.contains(item)) {
				toRemove.add(item);
			}
		}
		return removeAll(toRemove);
	}

	@Override
	public int size() {
		try {
			return (Integer) sizeMethod.invoke(setWrapper);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	protected class NativeSetIterator implements Iterator<T> {

		private final Object setIterator;
		
		private NativeSetIterator() {
			try {
				setIterator = setIteratorClass.getConstructor(
						nativeSet.getClass()).newInstance(nativeSet);
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}
		
		@Override
		public boolean hasNext() {
			try {
				return (Boolean) hasNextMethod.invoke(setIterator);
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}

		@Override
		@SuppressWarnings("unchecked")
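		public T next() {
			// advancing past end() is undefined in C++, so stop here instead
			if (!hasNext()) {
				throw new java.util.NoSuchElementException();
			}
			try {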
				return (T) nextMethod.invoke(setIterator);
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}

		/**
		 * {@inheritDoc}
		 * <p>
		 * Unsupported
		 */
		@Override
		public void remove() {
			throw new UnsupportedOperationException();
		}
		
	}
}
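Both NativeSet and NativeMap’s entry set override removeAll and retainAll instead of inheriting them. The reason: AbstractCollection.retainAll (and AbstractSet.removeAll, whenever the argument is at least as large as the set) removes elements through Iterator.remove(), which these iterators don’t support. The sketch below reproduces that with hypothetical stand-in classes backed by a plain ArrayList instead of native code.

```java
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

// Hypothetical stand-in for NativeSet: elements live in an ArrayList instead
// of a std::set, and the iterator refuses remove(), like NativeSetIterator.
class NoIteratorRemoveSet<T> extends AbstractSet<T> {
    private final List<T> backing;

    NoIteratorRemoveSet(Collection<T> distinctItems) {
        this.backing = new ArrayList<T>(distinctItems);
    }

    @Override
    public boolean remove(Object item) {
        return backing.remove(item);
    }

    @Override
    public int size() {
        return backing.size();
    }

    @Override
    public Iterator<T> iterator() {
        final Iterator<T> it = backing.iterator();
        return new Iterator<T>() {
            @Override
            public boolean hasNext() {
                return it.hasNext();
            }

            @Override
            public T next() {
                return it.next();
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }
}

// Same set, with bulk removal routed through remove(Object) as NativeSet does.
class RemoveSafeSet<T> extends NoIteratorRemoveSet<T> {
    RemoveSafeSet(Collection<T> distinctItems) {
        super(distinctItems);
    }

    @Override
    public boolean removeAll(Collection<?> collection) {
        boolean modified = false;
        for (Object item : collection) {
            modified |= remove(item);
        }
        return modified;
    }

    @Override
    public boolean retainAll(Collection<?> collection) {
        // Collect first, then remove, so nothing is removed mid-iteration
        List<T> toRemove = new ArrayList<T>();
        for (T item : this) {
            if (!collection.contains(item)) {
                toRemove.add(item);
            }
        }
        return removeAll(toRemove);
    }
}

final class RetainAllDemo {
    public static void main(String[] args) {
        List<String> keep = Arrays.asList("b");
        try {
            new NoIteratorRemoveSet<String>(Arrays.asList("a", "b", "c")).retainAll(keep);
        } catch (UnsupportedOperationException e) {
            System.out.println("inherited retainAll needs Iterator.remove()");
        }
        RemoveSafeSet<String> safe = new RemoveSafeSet<String>(Arrays.asList("a", "b", "c"));
        safe.retainAll(keep);
        System.out.println(safe); // [b]
    }
}
```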

Note that only NativeList supports remove() through its iterator; all other operations should be supported. Let’s review how to use each of these with its STL counterpart.

std::list – As mentioned earlier, this must use NativeSet because std::list is not indexed. Suppose we have the following C++ reference:

std::list<MyObject*>

You would then need a %template declaration for both the ListWrapper class and the ListIterator class (both in ListWrapper.h above), like so:

%template (MyObjectListWrapper) ListWrapper<MyObject*>;
%template (MyObjectListIterator) ListIterator<MyObject*>;

We’ll assume MyObject was generated by SWIG as the class name MyObject and your std::list was generated by SWIG as the class name SWIGTYPE_p_std__listT_MyObject_p_t. We’ll also assume your original std::list can be obtained via mymodule.getMyObjectList(). With the %template declarations above, SWIG should generate two more classes: MyObjectListWrapper and MyObjectListIterator. Now you can either instantiate the NativeSet directly like so:

new NativeSet<MyObject>(MyObject.class, MyObjectListIterator.class, 
        mymodule.getMyObjectList(), MyObjectListWrapper.class);

or you can extend NativeSet and call the super constructor with these values.
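Because SWIG generates a different wrapper and iterator class for every %template, these classes are passed in as Class objects and driven reflectively, as the NativeSetIterator code above does with getConstructor(...).newInstance(...) and Method.invoke(...). Below is a minimal sketch of that reflection pattern. FakeNativeList and FakeListWrapper are hypothetical stand-ins for the SWIGTYPE_p_... handle and a generated MyObjectListWrapper, and the assumption that the wrapper exposes a public no-arg size() method is illustrative; the real generated names may differ.

```java
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for an opaque SWIGTYPE_p_std__listT_MyObject_p_t handle
class FakeNativeList {
    final List<String> items = new ArrayList<String>();
}

// Hypothetical stand-in for a SWIG-generated MyObjectListWrapper
class FakeListWrapper {
    private final FakeNativeList list;

    public FakeListWrapper(FakeNativeList list) {
        this.list = list;
    }

    public int size() {
        return list.items.size();
    }
}

// Drives any wrapper class reflectively, given only its Class object and the native handle
class ReflectiveSizer {
    private final Object wrapper;
    private final Method sizeMethod;

    ReflectiveSizer(Class<?> wrapperClass, Object nativeHandle) {
        try {
            // Find the public constructor that accepts the native handle's type
            Constructor<?> ctor = wrapperClass.getConstructor(nativeHandle.getClass());
            wrapper = ctor.newInstance(nativeHandle);
            // Assumption: the wrapper exposes a public no-arg size() method
            sizeMethod = wrapperClass.getMethod("size");
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    int size() {
        try {
            return (Integer) sizeMethod.invoke(wrapper);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
```

The sizer holds no copy of the data, so it always reports the current size of the underlying handle.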

std::map – Suppose we have the following C++ reference (this assumes you have “stl.i” as an %import in your SWIG interface):

std::map<std::string, MyObject*>

You will then need %template declarations for both the MapWrapper class and the MapIterator class (both in MapWrapper.h above), like so:

%template (MyObjectMapWrapper) MapWrapper<MyObject*>;
%template (MyObjectMapIterator) MapIterator<MyObject*>;

We’ll assume MyObject was generated by SWIG as the class name MyObject and your std::map was generated by SWIG as the class name SWIGTYPE_p_std__mapT_std__string_MyObject_p_t. We’ll also assume your original std::map can be obtained via mymodule.getMyObjectMap(). With the %template declarations above, SWIG should generate two more classes: MyObjectMapWrapper and MyObjectMapIterator. Now you can either instantiate the NativeMap directly like so:

new NativeMap<String, MyObject>(String.class, MyObject.class,
        MyObjectMapIterator.class, mymodule.getMyObjectMap(),
        MyObjectMapWrapper.class);

or you can extend NativeMap and call the super constructor with these values.

std::set – This is very similar to std::list. Suppose we have the following C++ reference:

std::set<MyObject*>

You will then need %template declarations for both the SetWrapper class and the SetIterator class (both in SetWrapper.h above), like so:

%template (MyObjectSetWrapper) SetWrapper<MyObject*>;
%template (MyObjectSetIterator) SetIterator<MyObject*>;

We’ll assume MyObject was generated by SWIG as the class name MyObject and your std::set was generated by SWIG as the class name SWIGTYPE_p_std__setT_MyObject_p_t. We’ll also assume your original std::set can be obtained via mymodule.getMyObjectSet(). With the %template declarations above, SWIG should generate two more classes: MyObjectSetWrapper and MyObjectSetIterator. Now you can either instantiate the NativeSet directly like so:

new NativeSet<MyObject>(MyObject.class, MyObjectSetIterator.class, 
        mymodule.getMyObjectSet(), MyObjectSetWrapper.class);

or you can extend NativeSet and call the super constructor with these values.

std::vector – Since a vector is indexed, an iterator is not needed here. Suppose we have the following C++ reference:

std::vector<MyObject*>

You will then need a %template declaration for the VectorWrapper class (in VectorWrapper.h above), like so:

%template (MyObjectVectorWrapper) VectorWrapper<MyObject*>;

We’ll assume MyObject was generated by SWIG as the class name MyObject and your std::vector was generated by SWIG as the class name SWIGTYPE_p_std__vectorT_MyObject_p_t. We’ll also assume your original std::vector can be obtained via mymodule.getMyObjectVector(). With the %template declaration above, SWIG should generate one more class: MyObjectVectorWrapper. Now you can either instantiate the NativeList directly like so:

new NativeList<MyObject>(MyObject.class, mymodule.getMyObjectVector(),
        MyObjectVectorWrapper.class);

or you can extend NativeList and call the super constructor with these values.
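Because std::vector is indexed, a java.util.AbstractList only needs get(int) and size() (plus set(int, E) for writes); AbstractList derives iterator(), contains(), indexOf() and the rest from those. Below is a minimal sketch of that idea. It assumes a hypothetical IndexedAccessor interface standing in for the reflective calls into MyObjectVectorWrapper, and it is not NativeList's actual code.

```java
import java.util.AbstractList;
import java.util.RandomAccess;

// Hypothetical abstraction over a SWIG vector wrapper's indexed methods
interface IndexedAccessor<T> {
    T get(int index);
    void set(int index, T value);
    int size();
}

// View over an indexed native container; iteration comes for free from AbstractList
class IndexedListView<T> extends AbstractList<T> implements RandomAccess {
    private final IndexedAccessor<T> accessor;

    IndexedListView(IndexedAccessor<T> accessor) {
        this.accessor = accessor;
    }

    @Override
    public T get(int index) {
        // Bounds-check on the Java side rather than trusting native code
        if (index < 0 || index >= accessor.size()) {
            throw new IndexOutOfBoundsException("Index: " + index);
        }
        return accessor.get(index);
    }

    @Override
    public T set(int index, T value) {
        T old = get(index); // also performs the bounds check
        accessor.set(index, value);
        return old;
    }

    @Override
    public int size() {
        return accessor.size();
    }
}
```

Since the sketch does not override add(int, E) or remove(int), add() and remove() throw UnsupportedOperationException, while reads and set() go straight through to the accessor.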

Now you have Java collections representing your STL collections directly. There are several things to note:

  1. Pointers – These collections use the underlying pointers, so altering a collection here alters the source collection. Also, since my examples store MyObject as a pointer, the objects themselves can be altered by anyone using this collection. If you don’t want this, instantiate your favorite collection implementation and pass this collection to its constructor; this makes a copy. The copy is shallow, though: the MyObject pointers still refer to the same native objects.
  2. Garbage Collection – Using mutable collections like this can cause problems with SWIG memory ownership, so make sure you read up on SWIG memory management. If you add an object created in Java to this collection, SWIG assumes the Java proxy owns the underlying C++ object. Once no Java references to the proxy remain, it may be garbage collected, and its finalizer can delete the C++ object while the STL container still points to it. I usually call Collections.unmodifiable* on these collections because I rarely need them changed. Otherwise, you probably instantiated the native collection object in Java too, in which case it should be OK.
  3. Performance – These collections extend AbstractList, AbstractMap, and AbstractSet. With AbstractList, operations like remove(Object) and contains(Object) iterate over the entire list. The other two base classes have similar fallbacks for other operations. For all collections, equals() iterates over every element. Please read the base classes’ documentation to understand what they do. If I weren’t lazy right now, I’d have bridged all the STL methods and implemented better RandomAccess support.
  4. Errors – So far I have only tested a few pieces. I haven’t tested with null objects, or things like calling iterator.next() when iterator.hasNext() is false. Outside of NativeList, I purposefully didn’t implement iterator.remove(). Good luck 🙂
  5. Thread safety – I have no clue! Be safe and use Collections.synchronized*.
  6. Transformation – Often you want a cleaner object on the other side of your collections. You can use Commons Collections’ CollectionUtils.transformed* (or the more specific ListUtils, MapUtils, and SetUtils). If you’re cool like me, you’d use larvalabs’ collections with generics or Google Collections.
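The copy, read-only and synchronized options from the list above look like this in plain JDK code. In this sketch an ordinary ArrayList stands in for a NativeList view so it runs without SWIG (an assumption); with real pointers, the copy would still share the same MyObject instances.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class CollectionWrapping {

    // Independent shallow copy: later changes to the source are not seen
    static <T> List<T> snapshot(List<T> source) {
        return new ArrayList<T>(source);
    }

    // Live read-only view: reflects source changes but rejects mutation
    static <T> List<T> readOnly(List<T> source) {
        return Collections.unmodifiableList(source);
    }

    // Live view whose individual method calls are synchronized on one lock
    static <T> List<T> threadSafe(List<T> source) {
        return Collections.synchronizedList(source);
    }
}
```

Per the Collections.synchronizedList documentation, iterating over the synchronized view still requires a manual synchronized (list) { ... } block around the loop.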

Remember, none of this is tested that well, and I’m not the strongest C++ dev around, so use it at your own risk. Once I finish and open source the library that uses this, I will link to the full code and test cases. I hope this code helps someone. License: WTFPL.


Registering native methods when calling Java from SWIG

Posted in C++, Java, SWIG by chadretz on November 25, 2009

I am working on a project where my DLL is loaded as a dynamic plugin in a third party application. I wanted to use Java to accomplish my task so I chose to use SWIG with the Invocation API. Luckily, I can guarantee a single execution platform: Windoze. I am no C++ expert, but here’s how I accomplished this…

I put all my JARs in the same directory as the DLL, so I grabbed the DLL’s directory from the HMODULE passed to DllMain, like so:

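//stack buffer, so nothing leaks (the original new[] was never deleted)
wchar_t filename[MAX_PATH];
//full path of this DLL
GetModuleFileName((HMODULE) hModule, filename, MAX_PATH);
_dllDirectory = filename;
//keep everything up to and including the last backslash
_dllDirectory = _dllDirectory.substr(0, _dllDirectory.find_last_of(L'\\') + 1);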

I also need every JAR in this directory on the classpath, so I wrote a function that joins them:

//str_to_wstr and wstr_to_str are small string conversion helpers (not shown)
string get_jars(wstring dir)
{
    WIN32_FIND_DATA findFileData;
    HANDLE hFind;
    hFind = FindFirstFile((dir + str_to_wstr("*.jar")).c_str(), &findFileData);
    if (hFind == INVALID_HANDLE_VALUE) {
        return "";
    } else {
        string str_dir = wstr_to_str(dir);
        string ret = str_dir + wstr_to_str(findFileData.cFileName);
        while(FindNextFile(hFind, &findFileData)) {
            //since we're on windows, this is the separator
            ret += ';';
            ret += str_dir + wstr_to_str(findFileData.cFileName);
        }
        FindClose(hFind);
        return ret;
    }
}
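For comparison, the same classpath string can be built on the Java side, which is handy for checking it against System.getProperty("java.class.path") inside the embedded JVM. Below is a minimal sketch that assumes the JARs sit directly in one directory. JarClasspath is a hypothetical name, and File.pathSeparatorChar (';' on Windows, ':' elsewhere) replaces the hard-coded ';'.

```java
import java.io.File;
import java.io.FilenameFilter;
import java.util.Arrays;
import java.util.Locale;

class JarClasspath {

    // Joins every *.jar in dir into one classpath string, like get_jars() does
    static String forDirectory(File dir) {
        File[] jars = dir.listFiles(new FilenameFilter() {
            @Override
            public boolean accept(File parent, String name) {
                // Windows "*.jar" matching is case-insensitive, so mirror that here
                return name.toLowerCase(Locale.ROOT).endsWith(".jar");
            }
        });
        if (jars == null || jars.length == 0) {
            return "";
        }
        // Sort for a deterministic order; FindFirstFile/FindNextFile order is file-system dependent
        Arrays.sort(jars);
        StringBuilder ret = new StringBuilder();
        for (File jar : jars) {
            if (ret.length() > 0) {
                ret.append(File.pathSeparatorChar);
            }
            ret.append(jar.getAbsolutePath());
        }
        return ret.toString();
    }
}
```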

Now that I can get all JARs, I instantiate the JVM in a constructor of my “management” class. I have to first load the JVM DLL because it won’t be loaded by default. I use the JAVA_HOME environment variable to locate the DLL. The code is similar to the following:

//signature of JNI_CreateJavaVM, used to cast the GetProcAddress result
typedef jint (JNICALL CreateJavaVM_t)(JavaVM** pvm, void** penv, void* args);

MyNS::MyClass::MyClass(wstring dllDirectory)
{
    this->_libInst = NULL;
    this->_jvm = NULL;
    this->_env = NULL;
    this->_delegate = NULL;
    //create vm
    //get dll location (getenv returns NULL if the variable is unset)
    const char* javaHome = getenv("JAVA_HOME");
    if (javaHome == NULL) {
        MyLog::MyLog->printf("JVM Dll not found; Is JAVA_HOME set?");
        return;
    }
    string env = javaHome;
    if (file_exists((env + "\\bin\\client\\jvm.dll").c_str())) {
        env += "\\bin\\client\\jvm.dll";
    } else if (file_exists((env + "\\jre\\bin\\client\\jvm.dll").c_str())) {
        //prolly a JDK
        env += "\\jre\\bin\\client\\jvm.dll";
    } else {
        MyLog::MyLog->printf("JVM Dll not found; Is JAVA_HOME set?");
        return;
    }
    //load it
    if ( (this->_libInst = LoadLibrary(str_to_wstr(env).c_str())) == NULL) {
        MyLog::MyLog->printf("Can't load JVM DLL");
        return;
    }
    //grab vm creation method
    CreateJavaVM_t* createFn = (CreateJavaVM_t *)GetProcAddress(this->_libInst, "JNI_CreateJavaVM");
    if (createFn == NULL) {
        MyLog::MyLog->printf("Can't locate JNI_CreateJavaVM");
        return;
    }
    //build options; we want all jars in our directory
    string jars = get_jars(dllDirectory);
    if (jars == "") {
        return;
    }
    //classpath must outlive the createFn call, since optionString points into it
    string classpath = "-Djava.class.path=" + jars;
    //stack array, so no delete (or delete[]) is needed on any path
    JavaVMOption options[1];
    options[0].optionString = (char *) classpath.c_str();
    JavaVMInitArgs initArgs;
    //assuming 1.6 here
    initArgs.version = JNI_VERSION_1_6;
    initArgs.nOptions = 1;
    initArgs.options = options;
    initArgs.ignoreUnrecognized = JNI_FALSE;
    //create vm
    if (createFn(&this->_jvm, (void **)&this->_env, &initArgs) != JNI_OK) {
        this->_jvm = NULL;
        MyLog::MyLog->printf("Can't create VM");
        return;
    }
    //grab delegate class
    this->_delegate = this->_env->FindClass("my/java/package/MyDelegationClass");
    if (this->_delegate == NULL) {
        this->_jvm->DestroyJavaVM();
        this->_jvm = NULL;
        MyLog::MyLog->printf("Can't find delegate class; Is bridge JAR present?");
        return;
    }
}

Also, we must make sure we free resources in the destructor:

MyNS::MyClass::~MyClass()
{
    if (this->_jvm != NULL) {
        this->_jvm->DestroyJavaVM();
    }
    if (this->_libInst != NULL) {
        FreeLibrary(this->_libInst);
    }
}

After compiling this into a DLL along with a test method to call back into Java, I ran SWIG on the headers of the application I was writing the plugin for. Unfortunately, none of the Java native methods were mapped. It turns out that the JVM only resolves native methods against libraries loaded through System.loadLibrary (or System.load). My DLL had been loaded by the host application instead, so none of its JNI functions were “registered”. RegisterNatives to the rescue. The C++ API I was wrapping produced several thousand native Java methods in SWIG, making it far too tedious to hand-write the RegisterNatives code. So, I made an ANT task to do it for me. Here’s the code:

package org.cretz.swig.ant;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;

import org.apache.tools.ant.BuildException;
import org.apache.tools.ant.Project;
import org.apache.tools.ant.Task;

/**
 * Ant task to add "RegisterNative" code to SWIG gen'd files
 * 
 * @author Chad Retz
 */
public class RegisterNativesTask extends Task {
    
    /**
     * Get a type signature from a class
     * 
     * @param cls
     * @return
     */
    private static String getJNITypeSignature(Class<?> cls) {
        String ret = "";
        if (cls.isArray()) {
            ret += '[';
        }
        if ("boolean".equals(cls.getName())) {
            ret += 'Z';
        } else if ("byte".equals(cls.getName())) {
            ret += 'B';
        } else if ("char".equals(cls.getName())) {
            ret += 'C';
        } else if ("short".equals(cls.getName())) {
            ret += 'S';
        } else if ("int".equals(cls.getName())) {
            ret += 'I';
        } else if ("long".equals(cls.getName())) {
            ret += 'J';
        } else if ("float".equals(cls.getName())) {
            ret += 'F';
        } else if ("double".equals(cls.getName())) {
            ret += 'D';
        } else if ("void".equals(cls.getName())) {
            ret += 'V';
        } else {
            ret += 'L' + cls.getName().replace('.', '/') + ';';
        }
        return ret;
    }
    
    /**
     * Get a JNI signature for the given method
     * 
     * @param method
     * @return
     */
    private static String getJNISignature(Method method) {
        StringBuilder builder = new StringBuilder();
        builder.append('(');
        for (Class<?> parameter : method.getParameterTypes()) {
            builder.append(getJNITypeSignature(parameter));
        }
        builder.append(')');
        builder.append(getJNITypeSignature(method.getReturnType()));
        return builder.toString();
    }
    
    /**
     * Load an entire file into a string
     * 
     * @param file
     * @return
     * @throws IOException
     */
    private static String loadFileToString(File file) throws IOException {
        BufferedReader reader = new BufferedReader(
                new FileReader(file));
        StringBuilder builder = new StringBuilder();
        try {
            String line = reader.readLine();
            while (line != null) {
                builder.append(line + "\r\n");
                line = reader.readLine();
            }
            return builder.toString();
        } finally {
            reader.close();
        }
    }
    
    /**
     * Write an entire string to a file
     * 
     * @param file
     * @param string
     * @throws IOException
     */
    private static void writeStringToFile(File file, String string) throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(file));
        try {
            writer.write(string);
        } finally {
            writer.close();
        }
    }

    private String source;
    private String module;
    private String _package;
    
    @Override
    public void execute() throws BuildException {
        if (source == null) {
            throw new BuildException("source is required");
        }
        if (module == null) {
            throw new BuildException("module is required");
        }
        if (_package == null) {
            throw new BuildException("package is required");
        }
        //get cpp and header file
        File cppFile = new File(source);
        File headerFile = new File(source.replace(".cpp", ".h"));
        if (!cppFile.exists()) {
            throw new BuildException("Can't find source file");
        }
        if (!headerFile.exists()) {
            throw new BuildException("Can't find header file");
        }
        String cppContents;
        String headerContents;
        try {
            cppContents = loadFileToString(cppFile);
            headerContents = loadFileToString(headerFile);
        } catch (IOException e) {
            throw new BuildException("Unable to load cpp or header file", e);
        }
        //find class
        Class<?> cls;
        try {
            cls = Class.forName(_package + "." + module + "JNI");
        } catch (ClassNotFoundException e) {
            throw new BuildException("Can't find JNI class", e);
        }
        //load up the methods
        List<JNIMethod> methods = new ArrayList<JNIMethod>();
        //JNI name mangling: '_' becomes "_1" and '.' becomes '_' (underscores first)
        String prefix = "Java_" + cls.getName().replace("_", "_1").replace('.', '_');
        for (Method method : cls.getDeclaredMethods()) {
            if (Modifier.isNative(method.getModifiers())) {
                //mangle the method name the same way
                String methodName = method.getName();
                String jniMethod;
                jniMethod = prefix + '_' + methodName.replace("_", "_1");
                if (!cppContents.contains("JNICALL " + jniMethod + "(")) {
                    log("Can't find JNI method, skipping: " + jniMethod, 
                            Project.MSG_WARN);
                } else {
                    methods.add(new JNIMethod(methodName, getJNISignature(method), 
                            jniMethod));
                }
            }
        }
        //write pieces to header and cpp
        headerContents = "#pragma once\r\n#include <jni.h>\r\n\r\n" +
                headerContents + "\r\n\r\nclass SwigUtils {\r\npublic:\r\n\t" +
                "static int registerNatives(JNIEnv* env);\r\n};";
        StringBuilder cppMethods = new StringBuilder();
        cppMethods.append("int SwigUtils::registerNatives(JNIEnv* env)\r\n{\r\n");
        cppMethods.append("\tJNINativeMethod methods[" + methods.size() + "];\r\n");
        cppMethods.append("\tjclass cls = env->FindClass(\"");
        cppMethods.append(cls.getName().replace('.', '/') + "\");\r\n");
        for (int i = 0; i < methods.size(); i++) {
            JNIMethod method = methods.get(i);
            cppMethods.append("\r\n\tmethods[" + i + "].name = \"" +
                    method.name + "\";\r\n");
            cppMethods.append("\tmethods[" + i + "].signature = \"" +
                    method.signature + "\";\r\n");
            cppMethods.append("\tmethods[" + i + "].fnPtr = (void*)&" +
                    method.cppSignature + ";\r\n");
        }
        cppMethods.append("\r\n\treturn (int) env->RegisterNatives(cls, methods, " +
                methods.size() + ");\r\n}");
        try {
            writeStringToFile(headerFile, headerContents);
            writeStringToFile(cppFile, cppContents + "\r\n\r\n" + 
                    cppMethods.toString());
        } catch (IOException e) {
            throw new BuildException("Unable to write cpp or header file", e);
        }
    }
    
    /**
     * The source cpp file to alter (also assumes the
     * header file is there)
     *  
     * @return
     */
    public String getSource() {
        return source;
    }
    
    public void setSource(String source) {
        this.source = source;
    }
    
    /**
     * The name of the SWIG module
     * 
     * @return
     */
    public String getModule() {
        return module;
    }

    public void setModule(String module) {
        this.module = module;
    }

    /**
     * The package the source is in
     * 
     * @return
     */
    public String getPackage() {
        return _package;
    }

    public void setPackage(String _package) {
        this._package = _package;
    }

    /**
     * Simple POJO for holding method information
     * 
     * @author Chad Retz
     */
    private static class JNIMethod {
        private final String name;
        private final String signature;
        private final String cppSignature;
        
        private JNIMethod(String name, String signature, String cppSignature) {
            this.name = name;
            this.signature = signature;
            this.cppSignature = cppSignature;
        }
    }
}

NOTE: I wouldn’t reuse the JNI signature generator above in other situations; it doesn’t handle arrays correctly.
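If you do need arrays, the fix is to recurse on the component type, because Class.getName() already returns a JVM descriptor such as "[I" for an array class. Below is a sketch of an array-aware replacement for getJNITypeSignature and getJNISignature. The class name JniSignatures is just for illustration.

```java
import java.lang.reflect.Method;

class JniSignatures {

    // JNI type signature for one type, e.g. int -> "I", String[] -> "[Ljava/lang/String;"
    static String typeSignature(Class<?> cls) {
        if (cls.isArray()) {
            // One '[' per dimension, then the element type's signature
            return "[" + typeSignature(cls.getComponentType());
        }
        if (cls == boolean.class) return "Z";
        if (cls == byte.class) return "B";
        if (cls == char.class) return "C";
        if (cls == short.class) return "S";
        if (cls == int.class) return "I";
        if (cls == long.class) return "J";
        if (cls == float.class) return "F";
        if (cls == double.class) return "D";
        if (cls == void.class) return "V";
        // Reference types: binary name with '/' separators ('$' stays for nested classes)
        return "L" + cls.getName().replace('.', '/') + ";";
    }

    // Method signature: "(" + parameter signatures + ")" + return type signature
    static String methodSignature(Method method) {
        StringBuilder builder = new StringBuilder("(");
        for (Class<?> parameter : method.getParameterTypes()) {
            builder.append(typeSignature(parameter));
        }
        builder.append(')').append(typeSignature(method.getReturnType()));
        return builder.toString();
    }
}
```

For example, String.getChars(int, int, char[], int) maps to "(II[CI)V".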

The task accepts three attributes. The “source” attribute is the path to the .cpp file (the -o argument passed to SWIG). The “module” attribute is the SWIG module name from your .i file. The “package” attribute is the package you told SWIG about (the -package argument for SWIG).
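From the package and module values, the task derives the JNI class name (package + "." + module + "JNI") and, from that, the exported function names it searches for in the cpp file. The JNI specification defines this “short name” mangling as follows. The name starts with "Java_", followed by the fully qualified class name with '.' turned into '_', then '_' and the method name. Within the class and method names, a literal '_' becomes "_1", and any other character that is not an ASCII letter or digit becomes "_0" plus four lowercase hex digits. Below is a sketch of that rule; JniNames is a hypothetical helper, not part of the task.

```java
class JniNames {

    // Mangles one (possibly qualified) name per the JNI short-name rules
    static String mangle(String name) {
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < name.length(); i++) {
            char c = name.charAt(i);
            if (c == '.' || c == '/') {
                out.append('_');              // package separator
            } else if (c == '_') {
                out.append("_1");             // literal underscore
            } else if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) {
                out.append(c);                // ASCII alphanumerics pass through
            } else {
                // Everything else, e.g. '$' in nested class names, becomes _0xxxx
                out.append("_0").append(String.format("%04x", (int) c));
            }
        }
        return out.toString();
    }

    // Exported C symbol the JVM looks up for a non-overloaded native method
    static String jniShortName(String className, String methodName) {
        return "Java_" + mangle(className) + "_" + mangle(methodName);
    }
}
```

So a module named my_module in package org.example yields symbols starting with Java_org_example_my_1moduleJNI_.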

It reflectively obtains all native methods of the ‘module + JNI’ class that SWIG generates. This means the compiled SWIG-generated Java classes must be on the classpath when running this task. It then adds a SwigUtils class to the header file with one method, registerNatives, which accepts a JNIEnv:

class SwigUtils {
public:
    static int registerNatives(JNIEnv* env);
};

In the cpp file, it appends the implementation of this method after all the other SWIG code. The method returns the value returned by RegisterNatives, which is 0 on success and negative on failure. Now all you have to do is call it once you are done creating your JVM:

int regRes = SwigUtils::registerNatives(this->_env);
if (regRes != 0) {
    MyLog::MyLog->printf("Couldn't register natives");
}
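A method the task skipped (its “Can’t find JNI method” warning) is simply not registered, and that only shows up later as an UnsatisfiedLinkError the first time Java calls it. Below is a sketch of a small Java-side check that can surface this right after startup. It uses a hypothetical NativeProbe class with a hypothetical probe() native method; in practice you would point it at one cheap, side-effect-free method of your module’s JNI class.

```java
class NativeProbe {

    // Hypothetical native method; with no registered or exported implementation,
    // calling it throws UnsatisfiedLinkError
    static native int probe();

    // Returns true if a native implementation of probe() is bound
    static boolean isBound() {
        try {
            probe();
            return true;
        } catch (UnsatisfiedLinkError e) {
            return false;
        }
    }
}
```

Run standalone, with no native code registered, isBound() returns false.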

I hope this code helps someone. License: WTFPL.
