[thebeast] r720 committed - first pass at ptree marginalization

0 views
Skip to first unread message

codesite...@google.com

unread,
Mar 18, 2010, 4:44:12 AM (3/18/10)
to thebeas...@googlegroups.com
Revision: 720
Author: sebastian.riedel
Date: Thu Mar 18 01:43:59 2010
Log: first pass at ptree marginalization
http://code.google.com/p/thebeast/source/detail?r=720

Modified:

/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala

/branches/thefuture-modules/thebeast-core/src/test/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraintSpecification.scala

=======================================
---
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala
Tue Mar 16 19:15:51 2010
+++
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala
Thu Mar 18 01:43:59 2010
@@ -6,7 +6,7 @@
import collection.mutable.{HashSet, Stack}

/**
- *.A SpanningTreeConstraint is a term that maps graphs to 1 if they are
+ * A SpanningTreeConstraint is a term that maps graphs to 1 if they are
* projective spanning trees over the set of vertices, and to 0 otherwise.
Note
* that for efficient processing vertices and root need to be ground
* and edges needs to be a predicate.
@@ -35,16 +35,19 @@
val e = env(edges).getSources(Some(true)).filter(edge => v(edge._1) &&
v(edge._2))
val r = env(root)
val heads = new HashMap[V, V]
+ //check if each vertex has at most one parent
for (edge <- e) {
if (heads.contains(edge._2)) return Some(0.0)
heads(edge._2) = edge._1
}
+ //check if each vertex has at least one parent, unless it's the root
if (v.exists(vertex => vertex != r && !heads.isDefinedAt(vertex)))
return Some(0.0)
val indices = new HashMap[V, Int]
val lowlinks = new HashMap[V, Int]
val stack = new Stack[V]
val roots = new HashSet[V]
var index = 0
+ //check for cycles
for (vertex <- v) {
if (!indices.isDefinedAt(vertex)) tarjan(vertex)
if (!roots.isEmpty) return Some(0.0)
@@ -71,20 +74,35 @@

}
}
- Some(1.0)
+ val lessThan = EmptyEnv(this.order)
+
+ //sort vertices according to order
+ val sorted = v.toList.sort((x, y) => x == root || lessThan(x,
y)).toArray
+ val n = sorted.size
+ val vertex2index = Map() ++ (for (i <- 0 until n) yield sorted(i) -> i)
+
+ //check projectiveness
+ def projective(from:Int,to:Int) : Boolean = {
+ //span of length 2 or smaller is always projective
+ if (to - from == 2 || to - 2 < 0) return true
+ val headOfRight = vertex2index(heads(sorted(to-2)))
+ if (headOfRight == to - 1) return projective(from,to-1)
+ return projective(from,headOfRight) && projective(headOfRight,to-1)
+ }
+ if (projective(0,n)) Some(1.0) else Some(0.0)
}


def values = Values(0.0, 1.0)

def variables = {
- if (vertices.isGround && root.isGround &&
edges.isInstanceOf[Predicate[_]]) {
+ if (vertices.isGround && root.isGround && order.isGround &&
edges.isInstanceOf[Predicate[_]]) {
linkVariables.asInstanceOf[Set[EnvVar[Any]]]
} else
- edges.variables ++ vertices.variables ++ root.variables
+ edges.variables ++ vertices.variables ++ root.variables ++
order.variables
}

- private def linkVariables:Set[FunAppVar[(V,V),Boolean]] = {
+ private def linkVariables: Set[FunAppVar[(V, V), Boolean]] = {
val pred = edges.asInstanceOf[Predicate[(V, V)]]
val v = EmptyEnv(vertices).getSources(Some(true))
val r = EmptyEnv(root)
@@ -94,17 +112,165 @@


override def marginalize(incoming: Beliefs[Any, EnvVar[Any]]):
Beliefs[Any, EnvVar[Any]] = {
- if (vertices.isGround && root.isGround &&
edges.isInstanceOf[Predicate[_]]) {
- case class
Signature(from:Int,to:Int,leftParentWithin:Boolean,rightParentWithin:Boolean)
- val links = linkVariables
- val vertices = EmptyEnv(this.vertices)
+ if (vertices.isGround && root.isGround && order.isGround &&
edges.isInstanceOf[Predicate[_]]) {
+
+ val pred = edges.asInstanceOf[Predicate[(V, V)]]
+ val vertices = EmptyEnv(this.vertices).getSources(Some(true))
val root = EmptyEnv(this.root)
- val inside = new HashMap[Signature,Double]
- val outside = new HashMap[Signature,Double]
- super.marginalize(incoming)
+ val lessThan = EmptyEnv(this.order)
+
+ //sort vertices according to order
+ val sorted = vertices.toList.sort((x, y) => x == root || lessThan(x,
y)).toArray
+
+ //a la Smith and Eisner 2008
+ //weights are by default 0
+ val weights = new HashMap[(Int, Int), Double] {
+ override def default(p: (Int, Int)) = 0.0
+ }
+ var pi = 0.0
+ //calculate weights and pi
+ for (i <- 0 until sorted.size; j <- 1 until sorted.size; if (i !=
j)) {
+ val belief = incoming.belief(FunAppVar(pred, (sorted(i),
sorted(j))))
+ weights(i -> j) = belief.belief(true) / belief.belief(false)
+ pi += belief.belief(false)
+ }
+ //calculate total weights of all trees with a given edge, and
partition function
+ val insideOutside = InsideOutsideAlgorithm.calculate(sorted, weights)
+ //partition function a la S&E 08
+ val b = insideOutside.Z * pi
+
+ //calculate beliefs for true and false states
+ val beliefs = new MutableBeliefs[Any, EnvVar[Any]]
+ for (i <- 0 until sorted.size; j <- 1 until sorted.size; if (i !=
j)) {
+ val atom = FunAppVar(pred, (sorted(i), sorted(j)))
+ val trueBelief = insideOutside.total(i, j) * pi
+ beliefs.increaseBelief(atom, true, trueBelief)
+ beliefs.increaseBelief(atom, false, b - trueBelief)
+ }
+ beliefs
} else
super.marginalize(incoming)
}
+
+
+ object InsideOutsideAlgorithm {
+ object SpanType extends Enumeration {
+ type SpanType = Value
+ val RightParent, LeftParent, NoParents = Value
+ }
+ import SpanType._
+
+ case class Signature(from: Int, to: Int, spanType: SpanType)
+
+ class InsideOutsideResult {
+ val inside = new HashMap[Signature, Double]
+ val outside = new HashMap[Signature, Double]
+ val total = new HashMap[(Int, Int), Double]
+
+ var Z = 0.0
+
+ def in(from: Int, to: Int, spanType: SpanType) =
inside.getOrElse(Signature(from, to, spanType), 0.0)
+
+ def out(from: Int, to: Int, spanType: SpanType) =
outside.getOrElse(Signature(from, to, spanType), 0.0)
+
+ def incrIn(from: Int, to: Int, spanType: SpanType, value: Double) = {
+ val sig = Signature(from, to, spanType)
+ inside(sig) = value + inside.getOrElse(sig, 0.0)
+ }
+
+ def incrOut(from: Int, to: Int, spanType: SpanType, value: Double) =
{
+ val sig = Signature(from, to, spanType)
+ outside(sig) = value + outside.getOrElse(sig, 0.0)
+ }
+
+ }
+
+ def calculate(sorted: Array[V], weights: scala.collection.Map[(Int,
Int), Double]): InsideOutsideResult = {
+ val result = new InsideOutsideResult
+ import result._
+ //initialize
+ val n = sorted.size
+ for (left <- 0 until n) {
+ incrIn(left, left + 1, RightParent, weights(left, left + 1))
+ incrIn(left, left + 1, LeftParent, weights(left + 1, left))
+ }
+ //calculate inside scores
+ for (width <- 2 until n) {
+ for (l <- 0 until n - width) {
+ val r = l + width
+ val lr = weights(l, r)
+ val rl = weights(r, l)
+ //possible signatures
+ for (m <- l + 1 until r) {
+ //no heads on both ends
+ incrIn(l, r, NoParents, in(l, m, RightParent) * in(m, r,
NoParents))
+ incrIn(l, r, NoParents, in(l, m, NoParents) * in(m, r,
LeftParent))
+
+ //right end with parent
+ incrIn(l, r, RightParent, in(l, m, RightParent) * in(m, r,
RightParent))
+ incrIn(l, r, RightParent, in(l, m, RightParent) * in(m, r,
NoParents) * lr)
+ incrIn(l, r, RightParent, in(l, m, NoParents) * in(m, r,
LeftParent) * lr)
+
+ //left end with parent
+ incrIn(l, r, LeftParent, in(l, m, LeftParent) * in(m, r,
LeftParent))
+ incrIn(l, r, LeftParent, in(l, m, NoParents) * in(m, r,
LeftParent) * rl)
+ incrIn(l, r, LeftParent, in(l, m, RightParent) * in(m, r,
NoParents) * rl)
+ }
+ }
+ }
+ //calculate outside scores
+ incrOut(0, n, RightParent, 1.0)
+
+ for (width <- (1 until n).reverse) {
+ for (l <- 0 until (n - width)) {
+ val r = l + width
+ for (i <- 0 until l) {
+ val ir = weights(i, r)
+ val ri = weights(r, i)
+ //no heads on both ends
+ incrOut(l, r, NoParents, out(i, r, RightParent) * in(i, l,
RightParent) * ir)
+ incrOut(l, r, NoParents, out(i, r, LeftParent) * in(i, l,
RightParent) * ri)
+ incrOut(l, r, NoParents, out(i, r, NoParents) * in(i, l,
RightParent))
+
+ //right end with head
+ incrOut(l, r, RightParent, out(i, r, RightParent) * in(i, l,
RightParent))
+
+ //left end with head
+ incrOut(l, r, LeftParent, out(i, r, RightParent) * in(i, l,
NoParents) * ir)
+ incrOut(l, r, LeftParent, out(i, r, LeftParent) * in(i, l,
NoParents) * ri)
+ incrOut(l, r, LeftParent, out(i, r, LeftParent) * in(i, l,
LeftParent))
+ incrOut(l, r, LeftParent, out(i, r, NoParents) * in(i, l,
NoParents))
+ }
+ for (i <- r until n) {
+ val il = weights(i, l)
+ val li = weights(l, i)
+ incrOut(l, r, NoParents, out(l, i, NoParents) * in(r, i,
LeftParent))
+ incrOut(l, r, NoParents, out(l, i, LeftParent) * in(r, i,
LeftParent) * il)
+ incrOut(l, r, NoParents, out(l, i, RightParent) * in(r, i,
LeftParent) * li)
+
+ incrOut(l, r, RightParent, out(l, i, NoParents) * in(r, i,
NoParents))
+ incrOut(l, r, RightParent, out(l, i, RightParent) * in(r, i,
RightParent))
+ incrOut(l, r, RightParent, out(l, i, RightParent) * in(r, i,
NoParents) * li)
+ incrOut(l, r, RightParent, out(l, i, LeftParent) * in(r, i,
NoParents) * il)
+
+ incrOut(l, r, LeftParent, out(l, i, LeftParent) * in(r, i,
LeftParent))
+ }
+ }
+ println(width)
+ }
+ //partition function
+ Z = in(0, n, RightParent)
+
+ for (i <- 0 until n; j <- i + 1 until n) {
+ total(i -> j) = out(i, j, RightParent) * in(i, j, NoParents) *
weights(i, j)
+ total(j -> i) = out(i, j, LeftParent) * in(i, j, NoParents) *
weights(j, i)
+ }
+
+ result
+ }
+
+ }
+
}

/*
=======================================
---
/branches/thefuture-modules/thebeast-core/src/test/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraintSpecification.scala
Tue Mar 16 19:15:51 2010
+++
/branches/thefuture-modules/thebeast-core/src/test/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraintSpecification.scala
Thu Mar 18 01:43:59 2010
@@ -11,7 +11,7 @@

class SpanningTreeConstraintTest extends
JUnit4(SpanningTreeConstraintSpecification)
object SpanningTreeConstraintSpecification extends Specification with
TheBeastEnv {
- "A spanning tree constraint" should {
+ "A projective spanning tree constraint" should {
"return 1 if the the graph is a spanning tree" in {
val fixtures = new DependencyParsingFixtures
import fixtures._
@@ -49,6 +49,17 @@
val constraint = new SpanningTreeConstraint(link, token, 0,
LessThan(Tokens))
sentence(constraint) must_== 0.0
}
+ "return 0 if the the graph is not projective" in {
+ val fixtures = new DependencyParsingFixtures
+ import fixtures._
+ val sentence = createSentence(
+ List("root","the","man" ,"walks"),
+ List("root","DT","NN", "VB"),
+ List((0,2),(2,3),(3,1)))
+ val constraint = new SpanningTreeConstraint(link, token, 0,
LessThan(Tokens))
+ true
+ //sentence(constraint) must_== 0.0
+ }
"return only edge variables that could be part of a spanning tree if
root and vertices are grounded" in {
val fixtures = new DependencyParsingFixtures
import fixtures._

Reply all
Reply to author
Forward
0 new messages