// Imports below assume the BigDL 2.x dllib package layout (matching the stack
// trace); adjust the paths if your build differs.
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.dllib.feature.dataset.{DataSet, Sample, SampleToMiniBatch}
import com.intel.analytics.bigdl.dllib.nn.{Graph, Input, Linear, MSECriterion, ParallelCriterion, ReLU, Sigmoid}
import com.intel.analytics.bigdl.dllib.optim.{Adam, Optimizer, Trigger}
import com.intel.analytics.bigdl.dllib.tensor.Tensor
import org.apache.spark.ml.linalg.DenseVector

def createESMMModel(inputDim: Int, sharedLayerDims: Array[Int], ctrDims: Array[Int], cvrDims: Array[Int]): Module[Float] = {
  val input = Input[Float]()

  // Shared representation layers: a stack of Linear + ReLU blocks.
  var sharedLayer = input
  var sharedInputDim = inputDim
  for (dim <- sharedLayerDims) {
    sharedLayer = ReLU[Float]().inputs(Linear[Float](sharedInputDim, dim).inputs(sharedLayer))
    sharedInputDim = dim
  }

  // CTR prediction branch.
  var ctrLayer = sharedLayer
  var ctrInputDim = sharedInputDim
  for (dim <- ctrDims) {
    ctrLayer = ReLU[Float]().inputs(Linear[Float](ctrInputDim, dim).inputs(ctrLayer))
    ctrInputDim = dim
  }
  val ctrOutput = Sigmoid[Float]().inputs(Linear[Float](ctrInputDim, 1).inputs(ctrLayer))

  // CVR prediction branch.
  var cvrLayer = sharedLayer
  var cvrInputDim = sharedInputDim
  for (dim <- cvrDims) {
    cvrLayer = ReLU[Float]().inputs(Linear[Float](cvrInputDim, dim).inputs(cvrLayer))
    cvrInputDim = dim
  }
  val cvrOutput = Sigmoid[Float]().inputs(Linear[Float](cvrInputDim, 1).inputs(cvrLayer))

  // Two separate outputs for CTR and CVR.
  Graph(input, Array(ctrOutput, cvrOutput))
}
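
Because the Graph is built with two output nodes, forward on this model returns a Table of two tensors, never a single Tensor. A quick local check (made-up sizes) makes that visible:

// Smoke test with made-up sizes: the output of a two-output Graph is a Table.
val demo = createESMMModel(16, Array(8), Array(4), Array(4))
val out = demo.forward(Tensor[Float](2, 16).rand())
println(out.toTable.length()) // 2: element 1 is the CTR head, element 2 the CVR head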
// nodeSize, trainDf, batch and maxEpoch come from the surrounding job config.
val model = createESMMModel(nodeSize, Array(128, 64), Array(64, 32), Array(64, 32))

// Distributed training set: each row becomes a Sample with ONE 2-element label tensor.
val trainData = DataSet.rdd(trainDf.rdd.map { row =>
  val features = row.getAs[DenseVector]("features").toArray.map(_.toFloat)
  val ctr = row.getAs[Double]("ctr").toFloat
  val cvr = row.getAs[Double]("cvr").toFloat
  Sample[Float](Tensor[Float](features, Array(features.length)), Tensor[Float](Array(ctr, cvr), Array(2)))
}).transform(SampleToMiniBatch[Float](batch.toInt))
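
// What reaches the criterion each iteration (a sketch, assuming BigDL's
// default SampleToMiniBatch batching): with a single [2]-shaped label tensor
// per Sample, miniBatch.getTarget() is ONE Tensor of shape [batch, 2], not a
// Table, while forward on the two-output Graph above returns a Table. Any
// criterion used here has to bridge that Table/Tensor pairing.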
// One MSE loss per head. Note that ParallelCriterion is defined on Tables:
// it applies its i-th sub-criterion to the i-th elements of the
// (input, target) Tables.
val criterion = ParallelCriterion[Float]()
criterion.add(MSECriterion[Float]())
criterion.add(MSECriterion[Float]())

val optimizer = Optimizer(
  model = model,
  dataset = trainData,
  criterion = criterion
).setOptimMethod(new Adam[Float]())
  .setEndWhen(Trigger.maxEpoch(maxEpoch.toInt))
val trainedModel = optimizer.optimize()
The job fails inside optimizer.optimize() with the exception below.

Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2259)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2208)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2207)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2207)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1079)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1079)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1079)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2446)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2388)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2377)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:868)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2203)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2298)
at org.apache.spark.rdd.RDD.$anonfun$reduce$1(RDD.scala:1120)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:414)
at org.apache.spark.rdd.RDD.reduce(RDD.scala:1102)
at com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$.optimize(DistriOptimizer.scala:355)
at com.intel.analytics.bigdl.dllib.optim.DistriOptimizer.optimize(DistriOptimizer.scala:923)
... 8 more
Caused by: com.intel.analytics.bigdl.dllib.utils.UnKnownException: java.lang.ClassCastException: com.intel.analytics.bigdl.dllib.utils.Table cannot be cast to com.intel.analytics.bigdl.dllib.tensor.Tensor
at com.intel.analytics.bigdl.dllib.utils.Log4Error$.unKnowExceptionError(Log4Error.scala:60)
at com.intel.analytics.bigdl.dllib.utils.ThreadPool.invokeAndWait2(ThreadPool.scala:175)
at com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$.$anonfun$optimize$4(DistriOptimizer.scala:261)
at org.apache.spark.rdd.ZippedPartitionsRDD2.compute(ZippedPartitionsRDD.scala:89)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:131)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1470)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.util.concurrent.ExecutionException: java.lang.ClassCastException: com.intel.analytics.bigdl.dllib.utils.Table cannot be cast to com.intel.analytics.bigdl.dllib.tensor.Tensor
at java.util.concurrent.FutureTask.report(FutureTask.java:122)
at java.util.concurrent.FutureTask.get(FutureTask.java:192)
at com.intel.analytics.bigdl.dllib.utils.ThreadPool.invokeAndWait2(ThreadPool.scala:172)
... 12 more
Caused by: java.lang.ClassCastException: com.intel.analytics.bigdl.dllib.utils.Table cannot be cast to com.intel.analytics.bigdl.dllib.tensor.Tensor
at com.iqiyi.read.big.util.ESMMCriterion.updateOutput(ESMMCriterion.scala:12)
at com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion.forward(AbstractCriterion.scala:75)
at com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$.$anonfun$optimize$7(DistriOptimizer.scala:272)
at scala.runtime.java8.JFunction0$mcI$sp.apply(JFunction0$mcI$sp.java:23)
at com.intel.analytics.bigdl.dllib.utils.ThreadPool$$anon$4.call(ThreadPool.scala:161)
at java.util.concurrent.FutureTask.run(FutureTask.java:266)
... 3 more
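
The frame right above the cast failure is com.iqiyi.read.big.util.ESMMCriterion.updateOutput (ESMMCriterion.scala:12), a custom criterion rather than the ParallelCriterion shown above, so the posted snippet and the failing run do not use the same loss. The direction of the cast (Table cannot be cast to Tensor) says the criterion received the model output as a Table, which is exactly what a two-output Graph produces, and tried to treat it as a single Tensor. Two ways to make the shapes line up are sketched below; the name TwoHeadMSECriterion is hypothetical, not a BigDL class.

Option A: keep ParallelCriterion, but give each Sample one label tensor per head, so the batched target becomes a Table that lines up with the Table output:

// Hypothetical replacement for the Sample construction in trainDf.rdd.map:
// two [1]-shaped label tensors instead of one [2]-shaped tensor.
Sample[Float](
  Array(Tensor[Float](features, Array(features.length))),
  Array(Tensor[Float](Array(ctr), Array(1)), Tensor[Float](Array(cvr), Array(1)))
)

Option B: keep the single [batch, 2] label tensor and write the custom criterion against (Table, Tensor), unpacking the Table instead of casting it:

import com.intel.analytics.bigdl.dllib.nn.MSECriterion
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion
import com.intel.analytics.bigdl.dllib.tensor.Tensor
import com.intel.analytics.bigdl.dllib.utils.{T, Table}

// Sketch only: input is the Graph's two-element output Table, target is the
// [batch, 2] label Tensor built in the data prep above.
class TwoHeadMSECriterion extends AbstractCriterion[Table, Tensor[Float], Float] {
  private val ctrLoss = MSECriterion[Float]()
  private val cvrLoss = MSECriterion[Float]()

  override def updateOutput(input: Table, target: Tensor[Float]): Float = {
    val ctrOut = input[Tensor[Float]](1)              // CTR head, [batch, 1]
    val cvrOut = input[Tensor[Float]](2)              // CVR head, [batch, 1]
    val ctrTgt = target.narrow(2, 1, 1).contiguous()  // label column 1, [batch, 1]
    val cvrTgt = target.narrow(2, 2, 1).contiguous()  // label column 2, [batch, 1]
    output = ctrLoss.forward(ctrOut, ctrTgt) + cvrLoss.forward(cvrOut, cvrTgt)
    output
  }

  override def updateGradInput(input: Table, target: Tensor[Float]): Table = {
    val gradCtr = ctrLoss.backward(input[Tensor[Float]](1), target.narrow(2, 1, 1).contiguous())
    val gradCvr = cvrLoss.backward(input[Tensor[Float]](2), target.narrow(2, 2, 1).contiguous())
    gradInput = T(gradCtr, gradCvr)
    gradInput
  }
}

With option A the ParallelCriterion setup above can stay as posted; with option B it is replaced by the custom criterion. Either way, the target type the criterion expects has to match what SampleToMiniBatch actually emits.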