Modified:
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala
=======================================
---
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala
Fri Mar 19 00:35:19 2010
+++
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala
Sat Mar 20 12:45:27 2010
@@ -149,7 +149,7 @@
val beliefs = new MutableBeliefs[Any, EnvVar[Any]]
for (i <- 0 until sorted.size; j <- 1 until sorted.size; if (i !=
j)) {
val atom = FunAppVar(pred, (sorted(i), sorted(j)))
- println("%d %d: %f".format(i,j,insideOutside.total(i, j)))
+ println("%d %d: %f".format(i, j, insideOutside.total(i, j)))
val trueBelief = insideOutside.total(i, j) * pi
beliefs.increaseBelief(atom, true, trueBelief)
beliefs.increaseBelief(atom, false, b - trueBelief)
@@ -164,137 +164,104 @@
object SpanType extends Enumeration {
type SpanType = Value
val RightParent, LeftParent, NoParents = Value
+
+ val parents = Seq(LeftParent, RightParent)
+
+ def opposite(value: Value) = value match {
+ case LeftParent => RightParent
+ case RightParent => LeftParent
+ case _ => NoParents
+ }
+
+ def toType(left: Boolean, right: Boolean): Value = {
+ if (left && !right) LeftParent
+ else if (!left && right) RightParent
+ else if (!left && !right) NoParents
+ else null
+ }
+
}
import SpanType._
- case class Signature(from: Int, to: Int, spanType: SpanType, simple:
Boolean) {
- override def toString: String = "(%d,%d,%s,%s)".format(from, to,
spanType match {
- case RightParent => "R"
- case LeftParent => "L"
- case NoParents => "N"
- }, simple)
- }
+
+ case class Signature(from: Int, to: Int, left: Boolean, right:
Boolean, simple: Boolean) {
+ override def toString: String = "(%d,%d,%s,%s,%s)".format(from, to,
left, right, simple)
+ }
+ sealed trait Operation {
+ def eval:Signature
+ }
+ case class Seed(index: Int) extends Operation {
+ val eval = Signature(index,index+1,false,false,true)
+ }
+ case class CloseRight(sig: Signature) extends Operation{
+ val eval = Signature(sig.from, sig.to, true, false, true)
+ }
+ case class CloseLeft(sig: Signature) extends Operation {
+ val eval = Signature(sig.from, sig.to, false, true, true)
+ }
+ case class Join(l: Signature, r: Signature) extends Operation {
+ val eval = Signature(l.from,r.to,l.left,r.right,false)
+ val defined = l.right != r.left && l.simple
+ }
+
+
class InsideOutsideResult {
val inside = new HashMap[Signature, Double]
val outside = new HashMap[Signature, Double]
val total = new HashMap[(Int, Int), Double]
-
var Z = 0.0
- def in(from: Int, to: Int, spanType: SpanType, simple: Boolean):
Double =
- inside.getOrElse(Signature(from, to, spanType, simple), 0.0)
-
- def in(from: Int, to: Int, spanType: SpanType): Double =
- in(from, to, spanType, true) + in(from, to, spanType, false)
-
- def out(from: Int, to: Int, spanType: SpanType) =
- outside.getOrElse(Signature(from, to, spanType, false), 0.0)
-
- def incrIn(from: Int, to: Int, spanType: SpanType, simple: Boolean,
value: Double) = {
- val sig = Signature(from, to, spanType, simple)
- inside(sig) = value + inside.getOrElse(sig, 0.0)
- }
-
- def incrOut(from: Int, to: Int, spanType: SpanType, value: Double) =
{
- val sig = Signature(from, to, spanType, false)
- outside(sig) = value + outside.getOrElse(sig, 0.0)
- }
+ def incrOut(sig:Signature, value:Double) =
+ outside(sig) = outside.getOrElse(sig, 0.0) + value
+ def incrIn(sig:Signature, value:Double) =
+ inside(sig) = inside.getOrElse(sig, 0.0) + value
+
+ def in(sig:Signature) = inside.getOrElse(sig,0.0)
+ def out(sig:Signature) = outside.getOrElse(sig,0.0)
}
def calculate(sorted: Array[V], weights: scala.collection.Map[(Int,
Int), Double]): InsideOutsideResult = {
+
+ val bools = Array(false,true)
val result = new InsideOutsideResult
import result._
- //initialize
- val n = sorted.size
- for (left <- 0 until n - 1) {
- incrIn(left, left + 1, NoParents, true, 1.0)
- incrIn(left, left + 1, RightParent, true, weights(left, left + 1))
- incrIn(left, left + 1, LeftParent, true, weights(left + 1, left))
- }
- //calculate inside scores
- for (width <- 2 until n) {
- for (l <- 0 until n - width) {
- val r = l + width
- val lr = weights(l, r)
- val rl = weights(r, l)
- //possible signatures
- for (m <- l + 1 until r) {
- //no heads on both ends
- incrIn(l, r, NoParents, false, in(l, m, RightParent, true) *
in(m, r, NoParents))
- incrIn(l, r, NoParents, false, in(l, m, NoParents, true) *
in(m, r, LeftParent))
-
- //right end with parent
- incrIn(l, r, RightParent, false, in(l, m, RightParent, true) *
in(m, r, RightParent))
- incrIn(l, r, RightParent, true, in(l, m, RightParent, true) *
in(m, r, NoParents) * lr)
- incrIn(l, r, RightParent, true, in(l, m, NoParents, true) *
in(m, r, LeftParent) * lr)
-
- //left end with parent
- incrIn(l, r, LeftParent, false, in(l, m, LeftParent, true) *
in(m, r, LeftParent))
- incrIn(l, r, LeftParent, true, in(l, m, NoParents, true) *
in(m, r, LeftParent) * rl)
- incrIn(l, r, LeftParent, true, in(l, m, RightParent, true) *
in(m, r, NoParents) * rl)
- }
- }
- }
- //calculate outside scores
- incrOut(0, n - 1, RightParent, 1.0)
- incrOut(0, n - 1, NoParents, 1.0)
- incrOut(0, n - 1, LeftParent, 1.0)
-
- for (width <- (1 until n-1).reverse) {
- println("Width: " + width)
- for (l <- 0 until (n - width)) {
- val r = l + width
- for (i <- 0 until l) {
- val ir = weights(i, r)
- val ri = weights(r, i)
- //no heads on both ends
- println("%d (%d %d)".format(i,l,r))
- incrOut(l, r, NoParents, in(i, l, RightParent, true) * out(i,
r, NoParents))
-
- //right end with head
- incrOut(l, r, RightParent, in(i, l, RightParent, true) *
out(i, r, RightParent))
-
- //left end with head
- incrOut(l, r, LeftParent, in(i, l, NoParents, true) * out(i,
r, NoParents))
- incrOut(l, r, LeftParent, in(i, l, LeftParent, true) * out(i,
r, LeftParent))
-
- println("%d %d N = %f".format(l,r,out(l,r,NoParents)))
- println("%d %d L = %f".format(l,r,out(l,r,LeftParent)))
- println("%d %d R = %f".format(l,r,out(l,r,RightParent)))
- }
- for (i <- r + 1 until n) {
- val il = weights(i, l)
- val li = weights(l, i)
- println("(%d %d) %d".format(l,r,i))
-
- incrOut(l, r, NoParents, in(r, i, LeftParent, true) * out(l,
i, NoParents))
-
- incrOut(l, r, LeftParent, in(r, i, LeftParent, true) * out(l,
i, LeftParent))
-
- incrOut(l, r, RightParent, in(r, i, NoParents, true) * out(l,
i, NoParents))
- incrOut(l, r, RightParent, in(r, i, RightParent, true) *
out(l, i, RightParent))
-
- println("%d %d N = %f".format(l,r,out(l,r,NoParents)))
- println("%d %d L = %f".format(l,r,out(l,r,LeftParent)))
- println("%d %d R = %f".format(l,r,out(l,r,RightParent)))
- }
+ def add(op:Operation) = op match {
+ case Seed(_) => incrIn(op.eval, 1.0)
+ case Join(l,r) => incrIn(op.eval, in(l) * in(r))
+ case CloseLeft(sig @ Signature(i,j,_,_,_)) => incrIn(op.eval,
in(sig) * weights(i,j))
+ case CloseRight(sig @ Signature(i,j,_,_,_)) => incrIn(op.eval,
in(sig) * weights(j,i))
+ }
+ val n = sorted.size
+
+ //initialize inside probs
+ for (i <- 0 until n-1){
+ add(Seed(i))
+ add(CloseLeft(Signature(i,i+1,false,false,true)))
+ add(CloseRight(Signature(i,i+1,false,false,true)))
+ }
+ for (length <- 2 until n){
+ for (i <- 0 until n - length){
+ val j = i + length
+ for (k <- i + 1 until j){
+ println("%d %d %d".format(i,k,j))
+ for (b_L <- bools; b <- bools; b_R <- bools; s <- bools){
+ val sig_L = Signature(i,k,b_L,b,true)
+ val sig_R = Signature(k,j,!b, b_R, s)
+ val join = Join(sig_L, sig_R)
+ if (join.defined) add(join)
+ }
+ }
+ add(CloseLeft(Signature(i,j,false,false,false)))
+ add(CloseRight(Signature(i,j,false,false,false)))
}
}
- //partition function
- Z = in(0, n - 1, RightParent)
-
+ println(inside.mkString("\n"))
+ Z = in(Signature(0,n-1,false,true,false)) +
in(Signature(0,n-1,false,true,true))
println("Z: " + Z)
- for (i <- 0 until n; j <- i + 1 until n) {
- println("IN %d->%d R = %f".format(i, j, in(i, j,
RightParent,true)))
- println("OUT %d->%d R = %f".format(i, j, out(i, j, RightParent)))
- println("IN %d->%d L = %f".format(i, j, in(i, j,
LeftParent,true)))
- println("OUT %d->%d L = %f".format(i, j, out(i, j, LeftParent)))
- total(i -> j) = out(i, j, RightParent) * in(i, j, RightParent,
true)
- total(j -> i) = out(i, j, LeftParent) * in(i, j, LeftParent, true)
- }
+
result
}
@@ -302,6 +269,7 @@
}
}
+
/*