From 217f18bc8a28369370fbe28992a2b88b04837803 Mon Sep 17 00:00:00 2001 From: cbora Date: Sun, 7 May 2017 21:01:57 -0400 Subject: [PATCH] data source tests --- src/main/java/org/template/DataSource.java | 82 +++---- .../java/org/template/DataSourceTest.java | 204 ++++++++++++++++++ 2 files changed, 250 insertions(+), 36 deletions(-) create mode 100644 src/test/java/org/template/DataSourceTest.java diff --git a/src/main/java/org/template/DataSource.java b/src/main/java/org/template/DataSource.java index 40e3619..87129df 100755 --- a/src/main/java/org/template/DataSource.java +++ b/src/main/java/org/template/DataSource.java @@ -1,6 +1,6 @@ package org.template; -import grizzled.slf4j.Logger; +import org.slf4j.Logger; import org.apache.predictionio.controller.EmptyParams; import org.apache.predictionio.controller.java.PJavaDataSource; import org.apache.predictionio.core.EventWindow; @@ -27,24 +27,26 @@ public class DataSource extends PJavaDataSource> implements SelfCleaningDataSource { - private final Logger logger = new Logger(LoggerFactory.getLogger(SelfCleaningDataSource.class)); + private final Logger logger = LoggerFactory.getLogger(SelfCleaningDataSource.class); private transient PEvents pEventsDb = Storage.getPEvents(); private transient LEvents lEventsDb = Storage.getLEvents(false); public final DataSourceParams dsp; // Data source param object + /** + * Data Source reads data from an input source and transforms it into a desired format + * @param dsp + */ public DataSource(DataSourceParams dsp) { this.dsp = dsp; + // Draw info - /* - drawInfo("Init DataSource", Seq( - ("===================", "==================="), - ("App name", dsp.getAppName()), - ("Event window", dsp.getEventWindow()), - ("Event names", dsp.getEventNames()) - )) - */ - + List> t = new ArrayList>(); + t.add(new Tuple2("===================", "===================")); + t.add(new Tuple2("App name", dsp.getAppName())); + t.add(new Tuple2("Event window", dsp.getEventWindow())); + t.add(new Tuple2("Event names", dsp.getEventNames())); + Conversions.drawInfo("Init DataSource", t, this.logger); } /* Getter @@ -61,13 +63,37 @@ public EventWindow getEventWindow() { return dsp.getEventWindow(); } + /** + * Separate events by event name + * @return actionRdds + * */ + public List>> separateEvents(JavaRDD eventsRDD) { + ArrayList eventNames = dsp.getEventNames(); // get event names + + // Now separate events by event names + List>> actionRDDs = + eventNames.stream() + .map(eventName -> { + JavaRDD> actionRDD = + eventsRDD.filter(event -> !event.entityId().isEmpty() + && !event.targetEntityId().get().isEmpty() + && eventName.equals(event.event())) + .map(event -> new Tuple2( + event.entityId(), + event.targetEntityId().get())); + return new Tuple2<>(eventName, JavaPairRDD.fromJavaRDD(actionRDD)); + }) + .filter( pair -> !pair._2().isEmpty()) + .collect(Collectors.toList()); + + return actionRDDs; + + } + /* Getter * @retrun new TrainingData object */ public TrainingData readTraining(SparkContext sc) { - - ArrayList eventNames = dsp.getEventNames(); // get event names - // find events associated with the particular app name and eventNames JavaRDD eventsRDD = PJavaEventStore.find( dsp.getAppName(), // app name @@ -82,26 +108,10 @@ public TrainingData readTraining(SparkContext sc) { sc // spark context ).repartition(sc.defaultParallelism()); - // Now separate events by event name - List>> actionRDDs = - eventNames.stream() - .map(eventName -> { - JavaRDD> actionRDD = - eventsRDD.filter(event -> !event.entityId().isEmpty() - && !event.targetEntityId().get().isEmpty() - && eventName.equals(event.event())) - .map(event -> new Tuple2( - event.entityId(), - event.targetEntityId().get())); - return new Tuple2<>(eventName, JavaPairRDD.fromJavaRDD(actionRDD)); - }) - .filter( pair -> !pair._2().isEmpty()) - .collect(Collectors.toList()); - - // String eventNamesLogger = actionRDDs.stream() - // .map(i -> i._1()).collect(Collectors.joining(", ")); - - // logger.debug(String.format("Received actions for events %s", eventNamesLogger)); + List>> actionRDDs = separateEvents(eventsRDD); + String eventNamesLogger = actionRDDs.stream().map(i -> i._1()).collect(Collectors.joining(", ")); + + logger.debug(String.format("Received actions for events %s", eventNamesLogger)); JavaRDD> fieldsRDD = PJavaEventStore.aggregateProperties( dsp.getAppName(), // app name @@ -122,8 +132,8 @@ private boolean isSetEvent(Event e) { return e.event().equals("$set") || e.event().equals("$unset"); } - public Logger logger() { - return this.logger; + public grizzled.slf4j.Logger logger() { + return (grizzled.slf4j.Logger)this.logger; } @Override diff --git a/src/test/java/org/template/DataSourceTest.java b/src/test/java/org/template/DataSourceTest.java new file mode 100644 index 0000000..0619f34 --- /dev/null +++ b/src/test/java/org/template/DataSourceTest.java @@ -0,0 +1,204 @@ +package org.template; + + + +import static org.junit.Assert.*; + + +import org.apache.predictionio.core.EventWindow; +import org.apache.predictionio.data.storage.Event; +import org.apache.predictionio.data.store.java.OptionHelper; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaPairRDD; +import org.joda.time.DateTime; +import org.apache.spark.api.java.JavaRDD; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import scala.Tuple2; +import scala.collection.JavaConversions; +import scala.collection.Seq; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + + +/** + * Created by cbora on 3/24/17. + */ +public class DataSourceTest { + + private DataSource source; + private DataSourceParams params; + private SparkContext sc; + private JavaRDD events; + private String appName; + + /** + * Helper function to set environment variables + * @param key + * @param value + */ + public void setEnv(String key, String value) { + try { + Map env = System.getenv(); + Class cl = env.getClass(); + java.lang.reflect.Field field = cl.getDeclaredField("m"); + field.setAccessible(true); + Map writableEnv = (Map) field.get(env); + writableEnv.put(key, value); + } catch (Exception e) { + throw new IllegalStateException("Failed to set environment variable", e); + } + } + + /** + * Function to set environment variables + */ + public void setEnvironmentVariables() { + + // Function to set different environment variables + setEnv("SPARK_HOME", "$PIO_HOME/vendors/spark-1.5.1-bin-hadoop2.6"); + setEnv("POSTGRES_JDBC_DRIVER", "$PIO_HOME/lib/postgresql-9.4-1204.jdbc41.jar"); + setEnv("MYSQL_JDBC_DRIVER", "$PIO_HOME/lib/mysql-connector-java-5.1.37.jar"); + setEnv("PIO_FS_BASEDIR", "$HOME/.pio_store"); + setEnv("PIO_FS_ENGINESDIR", "$PIO_FS_BASEDIR/engines"); + setEnv("PIO_FS_TMPDIR", "$PIO_FS_BASEDIR/tmp"); + setEnv("PIO_STORAGE_REPOSITORIES_METADATA_NAME", "pio_meta"); + //setEnv("PIO_STORAGE_REPOSITORIES_METADATA_SOURCE", "elasticsearch"); + setEnv("PIO_STORAGE_REPOSITORIES_METADATA_SOURCE", "PGSQL"); + setEnv("PIO_STORAGE_REPOSITORIES_EVENTDATA_NAME", "pio_event"); + setEnv("PIO_STORAGE_REPOSITORIES_EVENTDATA_SOURCE", "PGSQL"); + setEnv("PIO_STORAGE_REPOSITORIES_MODELDATA_NAME", "pio_model"); + setEnv("PIO_STORAGE_REPOSITORIES_MODELDATA_SOURCE", "PGSQL"); + setEnv("PIO_STORAGE_SOURCES_PGSQL_TYPE", "jdbc"); + setEnv("PIO_STORAGE_SOURCES_PGSQL_URL", "jdbc:postgresql://localhost/pio"); + setEnv("PIO_STORAGE_SOURCES_PGSQL_USERNAME", "pio"); + setEnv("PIO_STORAGE_SOURCES_PGSQL_PASSWORD", "pio"); + setEnv("PIO_STORAGE_SOURCES_ELASTICSEARCH_TYPE", "elasticsearch"); + setEnv("PIO_STORAGE_SOURCES_ELASTICSEARCH_CLUSTERNAME", "elasticsearch"); + setEnv("PIO_STORAGE_SOURCES_ELASTICSEARCH_HOSTS", "localhost"); + setEnv("PIO_STORAGE_SOURCES_ELASTICSEARCH_PORTS", "9300"); + setEnv("PIO_STORAGE_SOURCES_ELASTICSEARCH_HOME", "$PIO_HOME/vendors/elasticsearch-1.4.4"); + setEnv("PIO_STORAGE_SOURCES_HBASE_TYPE", "hbase"); + setEnv("PIO_STORAGE_SOURCES_HBASE_HOME", "$PIO_HOME/vendors/hbase-1.0.0"); + } + + /** + * Helper function to create a new Event object + * @param name + * @param targetEntityId + * @param eventTime + * @return + */ + private Event makeEvent(String name, String targetEntityId, DateTime eventTime){ + Event e = new Event( + OptionHelper.none(), + name, + "a", + "a", + OptionHelper.none(), + OptionHelper.some("b"), + null, + eventTime, + null, + OptionHelper.none(), + DateTime.now() + ); + return e; + } + + + @Before + public void setUp() { + setEnvironmentVariables(); + + // spark context object + appName = "myApp"; + SparkConf conf = new SparkConf().setAppName(appName).setMaster("local"); + sc = new SparkContext(conf); + + // create some events + Event e1 = makeEvent("event1", "a", new DateTime(2)); + Event e2 = makeEvent("event2", "a", new DateTime(10)); + Event e3 = makeEvent("event3", "b", new DateTime(30)); + Event e4 = makeEvent("event4", "c", new DateTime(40)); + Event e5 = makeEvent("event5", "d", new DateTime(50)); + + List eventsList = Arrays.asList(e1, e2, e3, e4, e5); + Seq seqEvents = JavaConversions.asScalaBuffer(eventsList).toSeq(); + ClassTag tag = ClassTag$.MODULE$.apply(Event.class); + events = sc.parallelize(seqEvents, sc.defaultParallelism(), tag).toJavaRDD(); + } + + @Test + public void matchAllEvents() throws Exception { + // Construct DataSourceParams object + ArrayList eventNames = new ArrayList(); + eventNames.add("event1"); + eventNames.add("event2"); + eventNames.add("event3"); + eventNames.add("event4"); + eventNames.add("event5"); + EventWindow eventWindow = null; + params = new DataSourceParams(appName, eventNames, eventWindow); + source = new DataSource(params); + List>> results = source.separateEvents(events); + assertTrue(results.size() == eventNames.size()); + } + + @Test + public void matchNoEvent() throws Exception { + ArrayList eventNames = new ArrayList(); + eventNames.add("event100"); + eventNames.add("event101"); + eventNames.add("event102"); + eventNames.add("event103"); + EventWindow eventWindow = null; + params = new DataSourceParams(appName, eventNames, eventWindow); + source = new DataSource(params); + List>> results = source.separateEvents(events); + assertTrue(results.size() == 0); + } + + @Test + public void matchWithEmptyEventNames() throws Exception { + ArrayList eventNames = new ArrayList(); + EventWindow eventWindow = null; + params = new DataSourceParams(appName, eventNames, eventWindow); + source = new DataSource(params); + List>> results = source.separateEvents(events); + assertTrue(results.size() == 0); + //System.out.println() + } + + @Test + public void getAppName() throws Exception { + ArrayList eventNames = new ArrayList(); + EventWindow eventWindow = null; + params = new DataSourceParams(appName, eventNames, eventWindow); + source = new DataSource(params); + String appName = source.getAppName(); + assertTrue(appName.equals(params.getAppName())); + } + + @Test + public void getEventWindow() throws Exception { + ArrayList eventNames = new ArrayList(); + EventWindow eventWindow = null; + params = new DataSourceParams(appName, eventNames, eventWindow); + source = new DataSource(params); + List>> results = source.separateEvents(events); + assertNull(source.getEventWindow()); + } + + @After + public void tearDown() { + sc.stop(); + } +} \ No newline at end of file