How should the training data of ParallelCriterion be combined?

14 views
Skip to first unread message

clare

unread,
Oct 26, 2024, 10:56:04 AM10/26/24
to User Group for BigDL
def createESMMModel(inputDim: Int, sharedLayerDims: Array[Int], ctrDims: Array[Int], cvrDims: Array[Int]): Module[Float] = {
val input = Input[Float]()

// Shared Representation Layers
var sharedLayer = input
var sharedInputDim = inputDim
for (dim <- sharedLayerDims) {
sharedLayer = ReLU[Float]().inputs(Linear(sharedInputDim, dim).inputs(sharedLayer))
sharedInputDim = dim
}

// CTR Prediction Branch
var ctrLayer = sharedLayer
var ctrInputDim = sharedInputDim
for (dim <- ctrDims) {
ctrLayer = ReLU[Float]().inputs(Linear(ctrInputDim, dim).inputs(ctrLayer))
ctrInputDim = dim
}
val ctrOutput = Sigmoid[Float]().inputs(Linear(ctrInputDim, 1).inputs(ctrLayer))

// CVR Prediction Branch
var cvrLayer = sharedLayer
var cvrInputDim = sharedInputDim
for (dim <- cvrDims) {
cvrLayer = ReLU[Float]().inputs(Linear(cvrInputDim, dim).inputs(cvrLayer))
cvrInputDim = dim
}
val cvrOutput = Sigmoid[Float]().inputs(Linear(cvrInputDim, 1).inputs(cvrLayer))

// Two separate outputs for CTR and CVR
Graph(input, Array(ctrOutput, cvrOutput))
}
val model = createESMMModel(nodeSize, Array(128, 64), Array(64, 32), Array(64, 32))
val trainData = DataSet.rdd(trainDf.rdd.map(row => {
val features = row.getAs[DenseVector]("features").toArray.map(_.toFloat)
val ctr = row.getAs[Double]("ctr").toFloat
val cvr = row.getAs[Double]("cvr").toFloat
Sample[Float](Tensor(features, Array(features.length)), Tensor(Array(ctr, cvr), Array(2)))
})).transform(SampleToMiniBatch[Float](batch.toInt))
val criterion = ParallelCriterion[Float]()
criterion.add(MSECriterion[Float]())
criterion.add(MSECriterion[Float]())
val optimizer = Optimizer(
model = model,
dataset = trainData,
criterion = criterion
).setOptimMethod(new Adam[Float]())
.setEndWhen(Trigger.maxEpoch(maxEpoch.toInt))
val trainedModel = optimizer.optimize()

Driver stacktrace: at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2259) at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2208) at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2207) at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2207) at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1079) at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1079) at scala.Option.foreach(Option.scala:407) at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1079) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2446) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2388) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2377) at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49) at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:868) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2203) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2298) at org.apache.spark.rdd.RDD.$anonfun$reduce$1(RDD.scala:1120) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112) at org.apache.spark.rdd.RDD.withScope(RDD.scala:414) at org.apache.spark.rdd.RDD.reduce(RDD.scala:1102) at com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$.optimize(DistriOptimizer.scala:355) at com.intel.analytics.bigdl.dllib.optim.DistriOptimizer.optimize(DistriOptimizer.scala:923) ... 8 more Caused by: com.intel.analytics.bigdl.dllib.utils.UnKnownException: java.lang.ClassCastException: com.intel.analytics.bigdl.dllib.utils.Table cannot be cast to com.intel.analytics.bigdl.dllib.tensor.Tensor at com.intel.analytics.bigdl.dllib.utils.Log4Error$.unKnowExceptionError(Log4Error.scala:60) at com.intel.analytics.bigdl.dllib.utils.ThreadPool.invokeAndWait2(ThreadPool.scala:175) at com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$.$anonfun$optimize$4(DistriOptimizer.scala:261) at org.apache.spark.rdd.ZippedPartitionsRDD2.compute(ZippedPartitionsRDD.scala:89) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373) at org.apache.spark.rdd.RDD.iterator(RDD.scala:337) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) at org.apache.spark.scheduler.Task.run(Task.scala:131) at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1470) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) Caused by: java.util.concurrent.ExecutionException: java.lang.ClassCastException: com.intel.analytics.bigdl.dllib.utils.Table cannot be cast to com.intel.analytics.bigdl.dllib.tensor.Tensor at java.util.concurrent.FutureTask.report(FutureTask.java:122) at java.util.concurrent.FutureTask.get(FutureTask.java:192) at com.intel.analytics.bigdl.dllib.utils.ThreadPool.invokeAndWait2(ThreadPool.scala:172) ... 12 more Caused by: java.lang.ClassCastException: com.intel.analytics.bigdl.dllib.utils.Table cannot be cast to com.intel.analytics.bigdl.dllib.tensor.Tensor at com.iqiyi.read.big.util.ESMMCriterion.updateOutput(ESMMCriterion.scala:12) at com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion.forward(AbstractCriterion.scala:75) at com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$.$anonfun$optimize$7(DistriOptimizer.scala:272) at scala.runtime.java8.JFunction0$mcI$sp.apply(JFunction0$mcI$sp.java:23) at com.intel.analytics.bigdl.dllib.utils.ThreadPool$$anon$4.call(ThreadPool.scala:161) at java.util.concurrent.FutureTask.run(FutureTask.java:266) ... 3 more


Reply all
Reply to author
Forward
0 new messages