[cspoker] r1390 committed - Added concept drift (beta version)

3 views
Skip to first unread message

codesite...@google.com

unread,
Aug 17, 2010, 8:16:37 PM8/17/10
to cspoker...@googlegroups.com
Revision: 1390
Author: laurent.verbruggen
Date: Tue Aug 17 17:15:33 2010
Log: Added concept drift (beta version)
http://code.google.com/p/cspoker/source/detail?r=1390

Modified:
/trunk/ai/experiments/src/main/java/org/cspoker/ai/bots/BotRunner.java

/trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFFile.java

/trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFPlayer.java

/trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFPropositionalizer.java

/trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ActionTrackingVisitor.java

/trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/WekaLearningModel.java

/trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/WekaOptions.java

=======================================
--- /trunk/ai/experiments/src/main/java/org/cspoker/ai/bots/BotRunner.java	Tue Aug 17 11:28:01 2010
+++ /trunk/ai/experiments/src/main/java/org/cspoker/ai/bots/BotRunner.java	Tue Aug 17 17:15:33 2010
@@ -107,7 +107,7 @@
public static void create(RemoteCSPokerServer cspokerServer) {
kullbackLeibler = new KullbackLeiblerListener(reportInterval);
// new BotRunner(cspokerServer, "bucketSampler0.01VsRulebots");
- new BotRunner(cspokerServer, "NaiveBotsAccuracy2", getBots());
+ new BotRunner(cspokerServer, "test", getBots());
}

public static BotFactory[] getBots() {
@@ -123,8 +123,8 @@
configNoPersist.setUseOnlineLearning(false);

WekaOptions configPersist = new WekaOptions();
- configPersist.setContinuousLearning(false);
- configPersist.setModelCreationTreshold(3000);
+ configPersist.setContinuousLearning(true);
+ configPersist.setModelCreationTreshold(800);
configPersist.setContinueAfterCreation(false);

Sampler s = new BucketSampler(0.01);
@@ -133,9 +133,10 @@
// Sampler s = new RandomSampler(9);

return new BotFactory[] {
- new CallBotFactory("CallBot"), // 62% precision, 71% accuracy
- new CardBotFactory("CardBot"), // 60% precision, 71% accuracy
- new HandBotFactory("HandBot"), // 41% accuracy
+ new AlternatingBotFactory("AlternatingBot"), // 62% precision, 71% accuracy
+// new CallBotFactory("CallBot"), // 62% precision, 71% accuracy
+// new CardBotFactory("CardBot"), // 60% precision, 71% accuracy
+// new HandBotFactory("HandBot"), // 41% accuracy
// new FixedSampleMCTSBotFactory("MCTSBot",
// WekaRegressionModelFactory.createForZip("org/cspoker/ai/opponentmodels/weka/models/model1.zip",
// configNoPersist),
@@ -169,7 +170,7 @@
new MCTSShowdownRollOutNode.Factory(),
new SampleWeightedBackPropStrategy.Factory(),
s,
- 500
+ 250
),
// new SearchBotFactory( // 42,5% accuracy
// WekaRegressionModelFactory.createForZip("org/cspoker/ai/opponentmodels/weka/models/model1.zip", configNoPersist),
=======================================
--- /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFFile.java	Mon Aug 16 10:38:42 2010
+++ /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFFile.java	Tue Aug 17 17:15:33 2010
@@ -1,6 +1,7 @@
package org.cspoker.ai.opponentmodels.weka;

import java.io.*;
+import java.util.ArrayList;

import org.cspoker.ai.opponentmodels.weka.instances.InstancesBuilder;

@@ -23,9 +24,13 @@
private Writer file;
private long count = 0;
private WekaOptions config;
+
+ private Instances instances;
+ private ArrayList<Prediction> predictions;
+ private M5P cl = null;

public ARFFFile(String path, Object player, String name, String attributes,
- WekaOptions config) throws IOException {
+ WekaOptions config) throws Exception {
this.path = path;
this.player = player;
this.name = name;
@@ -35,6 +40,13 @@
file = new BufferedWriter(new FileWriter(path + player + name, false));
file.write(attributes);
file.flush();
+
+ DataSource source = new DataSource(path + player + name);
+ instances = source.getDataSet();
+ // make it clean
+ instances.delete();
+
+ predictions = new ArrayList<Prediction>();
}

// private double countDataLines() {
@@ -79,11 +91,78 @@
count++;
file.write(instance.toString() + nl);
file.flush();
+ instances.add(instance);
+// if (count != instances.numInstances())
+// System.err.println("PROBLEM");
+ adjustWindow();
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
-
+
+ public void addPrediction(Prediction p) {
+ predictions.add(p);
+ }
+
+ public double getWindowSize() {
+ return instances.numInstances();
+ }
+
+ public double getAccuracy() {
+ if (predictions.isEmpty()) return 0.0;
+ double truePositive = 0.0;
+ double trueNegative = 0.0;
+ double falsePositive = 0.0;
+ double falseNegative = 0.0;
+ for (int i = 0; i < predictions.size(); i++) {
+ Prediction p = predictions.get(i);
+ if (p != null) {
+ truePositive += p.getTruePositive();
+ trueNegative += p.getTrueNegative();
+ falsePositive += p.getFalsePositive();
+ falseNegative += p.getFalseNegative();
+ }
+ }
+ return (trueNegative + truePositive) /
+ (trueNegative + truePositive + falseNegative + falsePositive);
+ }
+
+ private double prevAcc = 0.0;
+
+ private boolean decreasingAcc(double accuracy) {
+ double diffAcc = accuracy - prevAcc;
+ prevAcc = accuracy;
+ return (diffAcc < -0.01);
+ }
+
+ private void adjustWindow() {
+ if (cl == null) return;
+ double windowSize = instances.numInstances();
+ double coverage = windowSize / cl.measureNumRules();
+ double accuracy = getAccuracy();
+ double l;
+ if ((coverage < config.getCdLowCoverage()) ||
+ (accuracy < config.getCdAccuracy() && decreasingAcc(accuracy)))
+ l = Math.round(0.2 * windowSize);
+ else if (coverage > 2 * config.getCdHighCoverage() &&
+ accuracy > config.getCdAccuracy())
+ l = 2;
+ else if (coverage > config.getCdHighCoverage() &&
+ accuracy > config.getCdAccuracy())
+ l = 1;
+ else
+ l = 0;
+
+ for (int i = 0; i < l; i++) {
+ instances.delete(0);
+ if (!predictions.isEmpty())
+ predictions.remove(0);
+ }
+
+// windowSize = windowSize - l;
+// System.out.println(name + ", " + windowSize + ", l: " + l + ", acc: " + accuracy + ", coverage: " + coverage);
+ }
+
public boolean isModelReady() {
return count > config.getMinimalLearnExamples();
}
@@ -98,9 +177,13 @@

public Classifier createModel(String fileName, String attribute, String[] rmAttributes) throws Exception {
// System.out.println("Creating model for " + player + name);
- DataSource source = new DataSource(path + player + name);
-// System.out.println(source + " > " + path + player + name);
- Instances data = source.getDataSet();
+ Instances data;
+ if (config.solveConceptDrift())
+ data = instances;
+ else {
+ DataSource source = new DataSource(path + player + name);
+ data = source.getDataSet();
+ }
if (rmAttributes.length > 0) {
String[] optionsDel = new String[2];
optionsDel[0] = "-R";
@@ -118,13 +201,17 @@
data.setClass(data.attribute(attribute));

// train M5P
- M5P cl = new M5P();
+ cl = new M5P();
cl.setBuildRegressionTree(true);
cl.setUnpruned(false);
cl.setUseUnsmoothed(false);
// further options...
cl.buildClassifier(data);

+// System.out.println("Number of instances: " + data.numInstances());
+// System.out.println("Number of measures: " + cl.measureNumRules());
+// System.out.println(cl);
+
// save model + header
if (config.modelPersistency())
SerializationHelper.write(path + "../" + player + fileName + ".model", cl);
=======================================
--- /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFPlayer.java	Tue Aug 17 11:28:01 2010
+++ /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFPlayer.java	Tue Aug 17 17:15:33 2010
@@ -3,6 +3,7 @@
import java.io.*;

import org.apache.log4j.Logger;
+import org.cspoker.common.elements.player.PlayerId;

import weka.core.Instance;

@@ -24,16 +25,19 @@

private WekaRegressionModel model = null;
private WekaOptions config = new WekaOptions();
+ private ActionTrackingVisitor actions = null;

private long writeCounter = 0;

- public ARFFPlayer(Object player, WekaRegressionModel baseModel, WekaOptions config) {
+ public ARFFPlayer(Object player, WekaRegressionModel baseModel, WekaOptions config,
+ ActionTrackingVisitor actions) {
if (!config.useOnlineLearning())
throw new IllegalStateException("ARFFPlayer can only be used with online learning!");

this.player = player;
this.config = config;
this.model = baseModel;
+ this.actions = actions;

try {
String path = (getClass().getProtectionDomain().getCodeSource()
@@ -51,6 +55,8 @@
ARFFPropositionalizer.getShowdownInstance().toString(), config);
} catch (IOException io) {
throw new RuntimeException(io);
+ } catch (Exception e) {
+ throw new RuntimeException("Unable to create set of instances");
}
}

@@ -69,6 +75,11 @@
"(learning examples: " + file.getNrExamples() +
" < " + config.getMinimalLearnExamples() + " required)";
}
+
+ public double getAccuracy() {
+ return actions.getAccuracy((PlayerId) player);
+ }
+
public void learnNewModel() {
// if (!(preCheckBetFile.isModelReady() && postCheckBetFile.isModelReady() && preFoldCallRaiseFile.isModelReady()
// && postFoldCallRaiseFile.isModelReady() && showdownFile.isModelReady())) {
@@ -110,6 +121,10 @@
public boolean modelCreated() {
return modelCreated;
}
+
+ public void addPreCheckBetPrediction(Prediction p) {
+ preCheckBetFile.addPrediction(p);
+ }

public void writePreCheckBet(Instance instance) {
if (writeAllowed()) {
@@ -132,6 +147,10 @@
e.printStackTrace();
}
}
+
+ public void addPostCheckBetPrediction(Prediction p) {
+ postCheckBetFile.addPrediction(p);
+ }

public void writePostCheckBet(Instance instance) {
if (writeAllowed()) {
@@ -154,6 +173,10 @@
e.printStackTrace();
}
}
+
+ public void addPreFoldCallRaisePrediction(Prediction p) {
+ preFoldCallRaiseFile.addPrediction(p);
+ }

public void writePreFoldCallRaise(Instance instance) {
if (writeAllowed()) {
@@ -180,6 +203,10 @@
e.printStackTrace();
}
}
+
+ public void addPostFoldCallRaisePrediction(Prediction p) {
+ postFoldCallRaiseFile.addPrediction(p);
+ }

public void writePostFoldCallRaise(Instance instance) {
if (writeAllowed()) {
@@ -205,7 +232,10 @@
} catch (Exception e) {
e.printStackTrace();
}
-
+ }
+
+ public void addShowdownPrediction(Prediction p) {
+ showdownFile.addPrediction(p);
}

public void writeShowdown(Instance instance) {
@@ -242,6 +272,12 @@

private void incrementWriteCounter() {
writeCounter++;
+ System.out.println(
+ preCheckBetFile.getAccuracy() + "\t" + preCheckBetFile.getWindowSize() + "\t" +
+ postCheckBetFile.getAccuracy() + "\t" + postCheckBetFile.getWindowSize() + "\t" +
+ preFoldCallRaiseFile.getAccuracy() + "\t" + preFoldCallRaiseFile.getWindowSize() + "\t" +
+ postFoldCallRaiseFile.getAccuracy() + "\t" + postFoldCallRaiseFile.getWindowSize() + "\t" +
+ showdownFile.getAccuracy() + "\t" + showdownFile.getWindowSize());
// System.out.println("=" + writeCounter + "=");
}
}
=======================================
--- /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFPropositionalizer.java	Mon Aug 16 10:38:42 2010
+++ /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ARFFPropositionalizer.java	Tue Aug 17 17:15:33 2010
@@ -111,7 +111,15 @@
}
}

-
+ protected void logFoldProb(Object actorId, Prediction p) {
+ if (actorId.equals(bot)) return;
+ if (getRound().equals("preflop")) {
+ getARFF(actorId).addPreFoldCallRaisePrediction(p);
+ } else {
+ getARFF(actorId).addPostFoldCallRaisePrediction(p);
+ }
+ }
+
@Override
protected void logCall(Object actorId) {
if (actorId.equals(bot)) return;
@@ -125,6 +133,15 @@
actorId, new Object[] { 0, 1, 0, "call" }));
}
}
+
+ protected void logCallProb(Object actorId, Prediction p) {
+ if (actorId.equals(bot)) return;
+ if (getRound().equals("preflop")) {
+ getARFF(actorId).addPreFoldCallRaisePrediction(p);
+ } else {
+ getARFF(actorId).addPostFoldCallRaisePrediction(p);
+ }
+ }

@Override
protected void logRaise(Object actorId, double raiseAmount) {
@@ -139,6 +156,15 @@
actorId, new Object[] { 0, 0, 1, "raise" }));
}
}
+
+ protected void logRaiseProb(Object actorId, Prediction p) {
+ if (actorId.equals(bot)) return;
+ if (getRound().equals("preflop")) {
+ getARFF(actorId).addPreFoldCallRaisePrediction(p);
+ } else {
+ getARFF(actorId).addPostFoldCallRaisePrediction(p);
+ }
+ }

@Override
protected void logCheck(Object actorId) {
@@ -153,6 +179,15 @@
new Object[] { 0, "check" }));
}
}
+
+ protected void logCheckProb(Object actorId, Prediction p) {
+ if (actorId.equals(bot)) return;
+ if (getRound().equals("preflop")) {
+ getARFF(actorId).addPreCheckBetPrediction(p);
+ } else {
+ getARFF(actorId).addPostCheckBetPrediction(p);
+ }
+ }

@Override
protected void logBet(Object actorId, double raiseAmount) {
@@ -167,6 +202,15 @@
new Object[] { 1, "bet" }));
}
}
+
+ protected void logBetProb(Object actorId, Prediction p) {
+ if (actorId.equals(bot)) return;
+ if (getRound().equals("preflop")) {
+ getARFF(actorId).addPreCheckBetPrediction(p);
+ } else {
+ getARFF(actorId).addPostCheckBetPrediction(p);
+ }
+ }

@Override
protected void logShowdown(Object actorId, double[] partitionDistr) {
@@ -181,4 +225,9 @@
getARFF(actorId).writeShowdown(
showdownInstance.getClassifiedInstance(this, actorId, targets));
}
-}
+
+ protected void logShowdownProb(Object actorId, Prediction p) {
+ if (actorId.equals(bot)) return;
+ getARFF(actorId).addShowdownPrediction(p);
+ }
+}
=======================================
--- /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ActionTrackingVisitor.java	Tue Aug 17 11:28:01 2010
+++ /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/ActionTrackingVisitor.java	Tue Aug 17 17:15:33 2010
@@ -15,6 +15,9 @@
import org.cspoker.common.util.Util;

import org.cspoker.ai.bots.bot.gametree.action.BetAction;
+import org.cspoker.ai.bots.bot.gametree.action.CallAction;
+import org.cspoker.ai.bots.bot.gametree.action.CheckAction;
+import org.cspoker.ai.bots.bot.gametree.action.FoldAction;
import org.cspoker.ai.bots.bot.gametree.action.RaiseAction;
import org.cspoker.ai.bots.bot.gametree.action.SearchBotAction;
import org.cspoker.ai.bots.bot.gametree.mcts.nodes.INode;
@@ -220,7 +223,7 @@
data.trueNegative += p.getTrueNegative();
data.falsePositive += p.getFalsePositive();
data.falseNegative += p.getFalseNegative();
- printAccuracy();
+// printAccuracy();
}

public double getAccuracy(PlayerId id) {
@@ -238,6 +241,7 @@
if (node != null
&& !callState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(callState);
assimilatePrediction(callState.getNextToAct(), p);
+ getPropz().logCallProb(callState.getNextToAct(), p);
logger.trace(getPlayerName(callState) + " " + p);
} else {
logger.trace(getPlayerName(callState) + " CallState");
@@ -251,6 +255,7 @@
if (node != null
&& !raiseState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(raiseState,raiseState.getLargestBet());
assimilatePrediction(raiseState.getNextToAct(), p);
+ getPropz().logRaiseProb(raiseState.getNextToAct(), p);
logger.trace(getPlayerName(raiseState) +
" Raise " + Util.parseDollars(raiseState.getLargestBet()) +
" - with <" + p + ">");
@@ -266,6 +271,7 @@
if (node != null
&& !foldState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(foldState);
assimilatePrediction(foldState.getNextToAct(), p);
+ getPropz().logFoldProb(foldState.getNextToAct(), p);
logger.trace(getPlayerName(foldState) + " " + p);
} else {
logger.trace(getPlayerName(foldState) + " FoldState");
@@ -279,6 +285,7 @@
if (node != null
&& !checkState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(checkState);
assimilatePrediction(checkState.getNextToAct(), p);
+ getPropz().logCheckProb(checkState.getNextToAct(), p);
logger.trace(getPlayerName(checkState) + " " + p);
} else {
logger.trace(getPlayerName(checkState) + " CheckState");
@@ -292,6 +299,7 @@
if (node != null
&& !betState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(betState,
betState.getEvent().getAmount());
assimilatePrediction(betState.getNextToAct(), p);
+ getPropz().logBetProb(betState.getNextToAct(), p);
logger.trace(getPlayerName(betState) +
" Bet " + Util.parseDollars(betState.getEvent().getAmount()) +
" - with <" + p + ">");
@@ -307,6 +315,18 @@
if (node != null
&& !allInState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(allInState,
allInState.getEvent().getMovedAmount());
assimilatePrediction(allInState.getNextToAct(), p);
+
+ if (p.getAction() instanceof CallAction)
+ getPropz().logCallProb(allInState.getNextToAct(), p);
+ if (p.getAction() instanceof FoldAction)
+ getPropz().logFoldProb(allInState.getNextToAct(), p);
+ if (p.getAction() instanceof RaiseAction)
+ getPropz().logRaiseProb(allInState.getNextToAct(), p);
+ if (p.getAction() instanceof CheckAction)
+ getPropz().logCheckProb(allInState.getNextToAct(), p);
+ if (p.getAction() instanceof BetAction)
+ getPropz().logBetProb(allInState.getNextToAct(), p);
+
logger.trace(getPlayerName(allInState) +
" All-in " + Util.parseDollars(allInState.getEvent().getMovedAmount()) +
" - with <" + p + ">");
=======================================
--- /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/WekaLearningModel.java	Tue Aug 17 11:28:01 2010
+++ /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/WekaLearningModel.java	Tue Aug 17 17:15:33 2010
@@ -107,7 +107,8 @@
model = new WekaRegressionModel(defaultModel);
if (config.useOnlineLearning() && !actor.equals(bot)) {
opponentModels.put(actor, model);
- actionTrackingVisitor.getPropz().addPlayer(actor, new ARFFPlayer(actor, model, config));
+ actionTrackingVisitor.getPropz().addPlayer(actor,
+ new ARFFPlayer(actor, model, config, actionTrackingVisitor));
}
}
return model;
=======================================
--- /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/WekaOptions.java	Wed Aug 11 08:56:53 2010
+++ /trunk/ai/opponentmodels/weka/src/main/java/org/cspoker/ai/opponentmodels/weka/WekaOptions.java	Tue Aug 17 17:15:33 2010
@@ -13,6 +13,9 @@

/** continuousLearning must be true for using solveConceptDrift */
private boolean solveConceptDrift = true;
+ private double cdHighCoverage = 50;
+ private double cdLowCoverage = 1;
+ private double cdAccuracy = 0.5;
/** if solveConceptDrift is false, a new model must be learned at intervals
* based on the number of reported actions */
private long learningInterval = 1;
@@ -98,4 +101,28 @@
public void setModelPersistency(boolean modelPersistency) {
this.modelPersistency = modelPersistency;
}
-}
+
+ public double getCdHighCoverage() {
+ return cdHighCoverage;
+ }
+
+ public void setCdHighCoverage(double cdHighCoverage) {
+ this.cdHighCoverage = cdHighCoverage;
+ }
+
+ public double getCdLowCoverage() {
+ return cdLowCoverage;
+ }
+
+ public void setCdLowCoverage(double cdLowCoverage) {
+ this.cdLowCoverage = cdLowCoverage;
+ }
+
+ public double getCdAccuracy() {
+ return cdAccuracy;
+ }
+
+ public void setCdAccuracy(double cdAccuracy) {
+ this.cdAccuracy = cdAccuracy;
+ }
+}

Reply all
Reply to author
Forward
0 new messages