[thebeast] r722 committed - exhaustive marginalization now passes specs

1 view
Skip to first unread message

codesite...@google.com

unread,
Mar 18, 2010, 6:13:41 PM
to thebeas...@googlegroups.com
Revision: 722
Author: sebastian.riedel
Date: Thu Mar 18 15:12:57 2010
Log: exhaustive marginalization now passes specs
http://code.google.com/p/thebeast/source/detail?r=722

Modified:

/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/Beliefs.scala

/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala

/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/solve/MarginalInference.scala

/branches/thefuture-modules/thebeast-core/src/test/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraintSpecification.scala

=======================================
---
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/Beliefs.scala
Thu Nov 12 23:36:25 2009
+++
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/Beliefs.scala
Thu Mar 18 15:12:57 2010
@@ -74,7 +74,7 @@
result
}

- override def toString = beliefs.toString + (if (expectations.size
>0) "\n" + expectations.toString else "")
+ override def toString = beliefs.mkString("\n") + (if (expectations.size
>0) "\n" + expectations.toString else "")

def increaseExpectation[N](term:NumericTerm[N], value:N, prob:Double) = {
val expectation = expectations.getOrElseUpdate(
=======================================
---
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala
Thu Mar 18 09:34:31 2010
+++
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraint.scala
Thu Mar 18 15:12:57 2010
@@ -2,7 +2,7 @@

import org.riedelcastro.thebeast.env._
import doubles.{DoubleConstant, DoubleTerm}
-import collection.mutable.{HashSet, Stack, HashMap}
+import collection.mutable.{HashSet, Stack, HashMap, MultiMap}

/**
* A SpanningTreeConstraint is a term that maps graphs to 1 if they are
@@ -31,7 +31,7 @@
def eval(env: Env): Option[Double] = {
//get edges map
val v = Set() ++ env(vertices).getSources(Some(true))
- val e = env(edges).getSources(Some(true)).filter(edge => v(edge._1) &&
v(edge._2))
+ val e = env(this.edges).getSources(Some(true)).filter(edge =>
v(edge._1) && v(edge._2))
val r = env(root)
val heads = new HashMap[V, V]
//check if each vertex has at most one parent
@@ -81,28 +81,20 @@
val n = sorted.size
val vertex2index = Map() ++ (for (i <- 0 until n) yield sorted(i) -> i)
//mapping from vertex to children
-
-
- def projective(from: Int, to: Int): Boolean = {
- if (to - from <= 1) return true
- var current = from + 1
- while (current < to) {
- val currentVertex = sorted(current)
- val headVertex = heads(currentVertex)
- val head = vertex2index(headVertex)
- if (head < from || head > to) return false
- if (head > current) {
- if (!projective(current, head)) return false
- current = head
- } else {
- current += 1
- }
- }
- true
- }
-
-
- if (projective(0, n-1)) Some(1.0) else Some(0.0)
+ val edges = for (i <- 1 until n) yield
(vertex2index(heads(sorted(i))),i)
+
+ def cross(e1:(Int,Int), e2:(Int,Int)) : Boolean = {
+ val e1l = Math.min(e1._1,e1._2)
+ val e1r = Math.max(e1._1,e1._2)
+ val e2l = Math.min(e2._1,e2._2)
+ val e2r = Math.max(e2._1,e2._2)
+ !(e1l >= e2l && e1r <= e2r || e2l >= e1l && e2r <= e1r || e2r <= e1l
|| e1r <= e2l)
+ }
+ //todo this should be doable in O(n)
+ for (e1 <- edges; e2 <- edges; if (e1 != e2)) {
+ if (cross(e1,e2)) return Some(0.0)
+ }
+ Some(1.0)
}


=======================================
---
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/solve/MarginalInference.scala
Thu Nov 12 23:36:25 2009
+++
/branches/thefuture-modules/thebeast-core/src/main/scala/org/riedelcastro/thebeast/solve/MarginalInference.scala
Thu Mar 18 15:12:57 2010
@@ -23,8 +23,8 @@
}

def marginalizeQueries[N](term: DoubleTerm,
- incoming: Beliefs[Any, EnvVar[Any]],
- queries: Collection[NumericTerm[N]]):
Beliefs[Any, EnvVar[Any]] = {
+ incoming: Beliefs[Any, EnvVar[Any]],
+ queries: Collection[NumericTerm[N]]):
Beliefs[Any, EnvVar[Any]] = {
val multiplied = term * Multiplication(term.variables.map(v =>
BeliefTerm(incoming.belief(v), v)).toSeq)
ExhaustiveMarginalInference.inferQueries(multiplied, queries)
}
@@ -46,21 +46,24 @@
}

private def inferQueriesExhaustively[N](term: DoubleTerm,
- queries:
Collection[NumericTerm[N]]): Beliefs[Any, EnvVar[Any]] = {
+ queries:
Collection[NumericTerm[N]]): Beliefs[Any, EnvVar[Any]] = {
val domain = term.variables.toSeq
gatherBeliefs(term, domain, queries)
}

private def gatherBeliefs[N](term: DoubleTerm,
- variables: Collection[EnvVar[Any]],
- queries: Collection[NumericTerm[N]]):
Beliefs[Any, EnvVar[Any]] = {
+ variables: Collection[EnvVar[Any]],
+ queries: Collection[NumericTerm[N]]):
Beliefs[Any, EnvVar[Any]] = {
|**("Exhaustive marginal inference for " + term)

- val beliefs = new MutableBeliefs[Any,EnvVar[Any]]
+ val beliefs = new MutableBeliefs[Any, EnvVar[Any]]

Env.forall(variables) {
env => {
val score = env(term);
+ if (score > 0) {
+ println(env)
+ }
for (variable <- variables) {
beliefs.increaseBelief(variable, env(variable), score)
}
=======================================
---
/branches/thefuture-modules/thebeast-core/src/test/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraintSpecification.scala
Thu Mar 18 09:34:31 2010
+++
/branches/thefuture-modules/thebeast-core/src/test/scala/org/riedelcastro/thebeast/env/combinatorics/SpanningTreeConstraintSpecification.scala
Thu Mar 18 15:12:57 2010
@@ -3,7 +3,8 @@
import org.specs.Specification
import org.riedelcastro.thebeast.DependencyParsingFixtures
import org.specs.runner.{JUnit4}
-import org.riedelcastro.thebeast.env.{Constant, LessThan, FunAppVar,
TheBeastEnv}
+import org.riedelcastro.thebeast.env._
+import org.riedelcastro.thebeast.solve.ExhaustiveMarginalInference

/**
* @author sriedel
@@ -11,6 +12,24 @@

class SpanningTreeConstraintTest extends
JUnit4(SpanningTreeConstraintSpecification)
object SpanningTreeConstraintSpecification extends Specification with
TheBeastEnv {
+ val trees = Seq(
+ List((0, 1), (0, 2), (0, 3)),
+ List((0, 1), (1, 2), (0, 3)),
+ List((0, 1), (3, 2), (0, 3)),
+ List((0, 1), (0, 2), (2, 3)),
+ List((0, 1), (1, 2), (2, 3)),
+ List((0, 1), (1, 3), (3, 2)),
+ List((0, 1), (1, 2), (1, 3)),
+ List((0, 2), (0, 3), (2, 1)),
+ List((0, 2), (2, 1), (2, 3)),
+ List((0, 3), (3, 1), (3, 2)),
+ List((0, 3), (3, 1), (1, 2)),
+ List((0, 3), (3, 2), (2, 1)))
+
+ val counts = Map() ++ {for (i <- 0 until 4; j <- 1 until 4; if (j!=i))
yield
+ (i,j)-> trees.filter(_.contains((i,j))).size.toDouble
+ }
+
"A projective spanning tree constraint" should {
"return 1 if the the graph is a spanning tree" in {
val fixtures = new DependencyParsingFixtures
@@ -55,13 +74,17 @@
"return 0 if the the graph is not projective" in {
val fixtures = new DependencyParsingFixtures
import fixtures._
- val sentence = createSentence(
+ def theManWalks(edges: List[(Int, Int)]) = createSentence(
List("root", "the", "man", "walks"),
- List("root", "DT", "NN", "VB"),
- List((0, 2), (2, 3), (3, 1)))
+ List("root", "DT", "NN", "VB"), edges)
+
+ val nonProjective = Seq(
+ List((0, 2), (2, 3), (3, 1)),
+ List((0, 2), (1, 3), (2, 1)))
val constraint = new SpanningTreeConstraint(link, token, 0,
LessThan(Tokens))
- sentence(constraint) must_== 0.0
- }
+ nonProjective.map(theManWalks(_)).forall(sentence =>
sentence(constraint) must_== 0.0)
+ }
+
"return only edge variables that could be part of a spanning tree if
root and vertices are grounded" in {
val fixtures = new DependencyParsingFixtures
import fixtures._
@@ -71,6 +94,23 @@
val result = constraint.ground(sentence.mask(Set(link))).variables
result must_== expected
}
+ "return exact marginals with exhaustive inference" in {
+ val fixtures = new DependencyParsingFixtures
+ import fixtures._
+ val sentence = createSentence(
+ List("root", "the", "man", "walks"),
+ List("root", "DT", "NN", "VB"),
+ List((0, 3), (3, 2), (2, 1)))
+ val constraint = new SpanningTreeConstraint(link, token, 0,
LessThan(Tokens))
+ val grounded = constraint.ground(sentence.mask(Set(link)))
+ val incoming = new CompleteIgnorance[Any, EnvVar[Any]]
+ val exact = ExhaustiveMarginalInference.marginalize(grounded,
incoming)
+ for (edge <- counts.keySet) {
+ exact.belief(FunAppVar(link,edge)).belief(true) must_==
counts(edge)
+ }
+ }
+
+
}

}

Reply all
Reply to author
Forward
0 new messages