From 83f97f18c9212d42541effeb8831bf174c25b6cd Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Thu, 27 Feb 2014 21:03:37 -0800 Subject: [PATCH 01/17] Creating accuracy report class --- .../java/com/nickferraro/bayesian/report/AccuracyReport.java | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/main/java/com/nickferraro/bayesian/report/AccuracyReport.java diff --git a/src/main/java/com/nickferraro/bayesian/report/AccuracyReport.java b/src/main/java/com/nickferraro/bayesian/report/AccuracyReport.java new file mode 100644 index 0000000..c1c5b0d --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/AccuracyReport.java @@ -0,0 +1,5 @@ +package com.nickferraro.bayesian.report; + +public class AccuracyReport { + +} From 77eeec11460d55f13c13f69581c00f6d01bc3965 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Mon, 3 Mar 2014 10:08:34 -0800 Subject: [PATCH 02/17] Created an interface for calculating bayesian system accuracy --- .../report/calc/IAccuracyCalculator.java | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java new file mode 100644 index 0000000..ca4af7b --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java @@ -0,0 +1,46 @@ +package com.nickferraro.bayesian.report.calc; + +import java.util.List; + +import com.nickferraro.bayesian.IDataRow; + +/** + * The IAccuracyCalculator interface is a simple contract for calculating the accuracy of + * a bayesian system in classifying a list of data rows. + * @author Nick Ferraro + * + */ +public interface IAccuracyCalculator { + /** + * Classify the list of data rows and calculate the accuracy of classifying the row category. + * @param dataRows A list of data rows to classify and use in a clean set of calculations. + * @return The accuracy of classifying this data set + */ + public double calculateAccuracy(List> dataRows); + + /** + * Classify the list of data rows and calculate the accuracy of classifying the row category. + * @param dataRows A list of data rows to classify and use in a clean or aggregated set of calculations. + * @param cleanSlate Whether or not to aggregate previous calculations with this new data set or start from a clean slate. Defaults to TRUE. + * @return The accuracy of classifying this data set + */ + public double calculateAccuracy(List> dataRows, boolean cleanSlate); + + /** + * Get the number of correct classifications + * @return The number of correct classifications + */ + public int getCorrectCount(); + + /** + * Get the number of incorrect classifications + * @return The number of incorrect classifications + */ + public int getIncorrectCount(); + + /** + * Get the number of total classifications calculated + * @return The number of total classifications + */ + public int getTotalCount(); +} From f5cc3aa579ab86f2f16f4a09f079a0579fef67d7 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Mon, 3 Mar 2014 10:09:30 -0800 Subject: [PATCH 03/17] Created a thread-safe implementation of IAccuracyCalculator --- .../report/calc/core/AccuracyCalculator.java | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java new file mode 100644 index 0000000..041d76b --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java @@ -0,0 +1,148 @@ +package com.nickferraro.bayesian.report.calc.core; + +import java.security.InvalidParameterException; +import java.util.List; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +import com.nickferraro.bayesian.IBayesianSystem; +import com.nickferraro.bayesian.IClassification; +import com.nickferraro.bayesian.IDataRow; +import com.nickferraro.bayesian.report.calc.IAccuracyCalculator; + +public class AccuracyCalculator implements IAccuracyCalculator { + protected ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); + protected Lock readLock = readWriteLock.readLock(); + protected Lock writeLock = readWriteLock.writeLock(); + private IBayesianSystem bayesianSystem; + private int total = 0; + private int correct = 0; + + public AccuracyCalculator(IBayesianSystem bayesianSystem) throws InvalidParameterException { + setBayesianSystem(bayesianSystem); + resetCounts(); + } + + public void setBayesianSystem(IBayesianSystem bayesianSystem) throws InvalidParameterException { + if( bayesianSystem == null ) { + throw new InvalidParameterException("AccuracyCalculator cannot set a NULL IBayesianSystem"); + } + + writeLock.lock(); + + try { + this.bayesianSystem = bayesianSystem; + } finally { + writeLock.unlock(); + } + } + + @Override + public double calculateAccuracy(List> dataRows) { + return calculateAccuracy(dataRows, true); + } + + @Override + public double calculateAccuracy(List> dataRows, boolean cleanSlate) { + writeLock.lock(); + + try { + return _calculateAccuracy(dataRows, cleanSlate); + } finally { + writeLock.unlock(); + } + } + + public double getAccuracy() { + readLock.lock(); + + try { + return _getAccuracy(); + } finally { + readLock.unlock(); + } + } + + @Override + public int getCorrectCount() { + readLock.lock(); + + try { + return correct; + } finally { + readLock.unlock(); + } + } + + @Override + public int getIncorrectCount() { + readLock.lock(); + + try { + return total - correct; + } finally { + readLock.unlock(); + } + } + + @Override + public int getTotalCount() { + readLock.unlock(); + + try { + return total; + } finally { + readLock.unlock(); + } + } + + public void resetCounts() { + writeLock.lock(); + + try { + _resetCounts(); + } finally { + writeLock.unlock(); + } + } + + private double _calculateAccuracy(List> dataRows, boolean cleanSlate) { + if( cleanSlate ) { + _resetCounts(); + } + + if( dataRows == null ) { + return 0; + } + + for(IDataRow dataRow : dataRows) { + if( dataRow == null ) { + continue; + } + + IClassification classification = helperGetClassification(bayesianSystem.classifyRow(dataRow)); + if( classification != null ) { + ++total; + + if( classification.getCategory().equals(dataRow.getCategory()) ) { + ++correct; + } + } + } + + return _getAccuracy(); + } + + private double _getAccuracy() { + return (double)correct / (double)total; + } + + private void _resetCounts() { + total = 0; + correct = 0; + } + + private IClassification helperGetClassification(List> l) { + return l == null || l.size() < 1 ? null : l.get(0); + } +} From 1a8d47b08e83f7054cc5a79322c652f2441de732 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Mon, 3 Mar 2014 10:42:41 -0800 Subject: [PATCH 04/17] Updated the interface to have a getAccuracy method --- .../bayesian/report/calc/IAccuracyCalculator.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java index ca4af7b..f0456d0 100644 --- a/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java +++ b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java @@ -27,13 +27,19 @@ public interface IAccuracyCalculator { public double calculateAccuracy(List> dataRows, boolean cleanSlate); /** - * Get the number of correct classifications + * Get the last accuracy calculated. + * @return The last accuracy calculated. + */ + public double getAccuracy(); + + /** + * Get the number of correct classifications from the last calculation. * @return The number of correct classifications */ public int getCorrectCount(); /** - * Get the number of incorrect classifications + * Get the number of incorrect classifications from the last calculation. * @return The number of incorrect classifications */ public int getIncorrectCount(); From 1c8be636d5970fe327989e9a74164cba6529e4f8 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Mon, 3 Mar 2014 10:43:05 -0800 Subject: [PATCH 05/17] Updated the implementation to override getAccuracy and added javadocs --- .../report/calc/core/AccuracyCalculator.java | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java index 041d76b..f61198e 100644 --- a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java @@ -10,6 +10,14 @@ import com.nickferraro.bayesian.IDataRow; import com.nickferraro.bayesian.report.calc.IAccuracyCalculator; +/** + * A thread-safe implementation of IAccuracyCalculator. This class will classify data rows + * and return the accuracy rate calculated. Multiple sets of data rows can be classified and aggregated + * into a single accuracy calculation. This implementation does not use the IDataRow id, nor does + * it verify that duplicate ids are only counted once. + * @author Nick Ferraro + * + */ public class AccuracyCalculator implements IAccuracyCalculator { protected ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); protected Lock readLock = readWriteLock.readLock(); @@ -18,130 +26,190 @@ public class AccuracyCalculator implements IAccuracyCalculator { private int total = 0; private int correct = 0; + /** + * Constructor for AccuracyCalculator. Requires a non-null IBayesianSystem parameter. + * @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL. + * @throws InvalidParameterException Thrown when IBayesianSystem is NULL. + */ public AccuracyCalculator(IBayesianSystem bayesianSystem) throws InvalidParameterException { + // DRY: Set the bayesian system and reset counts setBayesianSystem(bayesianSystem); resetCounts(); } + /** + * Set the IBayesianSystem used for classification. + * @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL. + * @throws InvalidParameterException Thrown when IBayesianSystem is NULL. + */ public void setBayesianSystem(IBayesianSystem bayesianSystem) throws InvalidParameterException { + // Validate the bayesianSystem parameter if( bayesianSystem == null ) { throw new InvalidParameterException("AccuracyCalculator cannot set a NULL IBayesianSystem"); } + // Lock writeLock.lock(); try { this.bayesianSystem = bayesianSystem; } finally { + // Unlock writeLock.unlock(); } } @Override public double calculateAccuracy(List> dataRows) { + // DRY: Call calculateAccuracy with cleanSlate = TRUE return calculateAccuracy(dataRows, true); } @Override public double calculateAccuracy(List> dataRows, boolean cleanSlate) { + // Lock writeLock.lock(); try { + // Call thread-unsafe private method return _calculateAccuracy(dataRows, cleanSlate); } finally { + // Unlock writeLock.unlock(); } } + @Override public double getAccuracy() { + // Lock readLock.lock(); try { + // Call thead-unsafe private method return _getAccuracy(); } finally { + // Unlock readLock.unlock(); } } @Override public int getCorrectCount() { + // Lock readLock.lock(); try { return correct; } finally { + // Unlock readLock.unlock(); } } @Override public int getIncorrectCount() { + // Lock readLock.lock(); try { + // Calculate and return incorrect return total - correct; } finally { + // Unlock readLock.unlock(); } } @Override public int getTotalCount() { + // Lock readLock.unlock(); try { return total; } finally { + // Unlock readLock.unlock(); } } + /** + * Reset the current calculations and counts + */ public void resetCounts() { + // Lock writeLock.lock(); try { + // Call thread-unsafe private method _resetCounts(); } finally { writeLock.unlock(); } } + /** + * Private thread-unsafe method for calculating the accuracy of a bayesian system. + * @param dataRows The data rows to classify and use in accuracy calculations + * @param cleanSlate Whether or not to aggregate calculations or start from a clean slate + * @return The current accuracy of the bayesian system + */ private double _calculateAccuracy(List> dataRows, boolean cleanSlate) { + // Reset our counts if cleanSlate is TRUE if( cleanSlate ) { _resetCounts(); } + // Return 0% accuracy for null dataRows if( dataRows == null ) { return 0; } + // Iterate over data rows for(IDataRow dataRow : dataRows) { + // If a row is null, skip it if( dataRow == null ) { continue; } + // Get the first classification and check if it is correct. NULL classifications are ignored. IClassification classification = helperGetClassification(bayesianSystem.classifyRow(dataRow)); if( classification != null ) { + // Increase the total rows calculated ++total; + // If the classification category matches the data row category, increase the correct count if( classification.getCategory().equals(dataRow.getCategory()) ) { ++correct; } } } + // Return the current accuracy using the thread-unsafe private method return _getAccuracy(); } + /** + * Private thread-unsafe method for calculating the accuracy + * @return The last calculated accuracy + */ private double _getAccuracy() { return (double)correct / (double)total; } + /** + * Private thread-unsafe method for reseting the calculations and counts + */ private void _resetCounts() { total = 0; correct = 0; } + /** + * A helper method for handling generic type inference. Gets the first element of the list. + * @param l The generic list to get the first element from. + * @return The first element of the list or NULL if the list is NULL or empty + */ private IClassification helperGetClassification(List> l) { return l == null || l.size() < 1 ? null : l.get(0); } From 5814c47c5b5fa39e4cba3baea9cebfaa2a8f7698 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Mon, 3 Mar 2014 14:15:55 -0800 Subject: [PATCH 06/17] Adding generic type T to accuracy code to remove wildcard headaches --- .../report/calc/IAccuracyCalculator.java | 7 ++-- .../report/calc/core/AccuracyCalculator.java | 35 ++++++++----------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java index f0456d0..c85da6d 100644 --- a/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java +++ b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java @@ -9,14 +9,15 @@ * a bayesian system in classifying a list of data rows. * @author Nick Ferraro * + * @param The category data type the bayesian system will be using */ -public interface IAccuracyCalculator { +public interface IAccuracyCalculator { /** * Classify the list of data rows and calculate the accuracy of classifying the row category. * @param dataRows A list of data rows to classify and use in a clean set of calculations. * @return The accuracy of classifying this data set */ - public double calculateAccuracy(List> dataRows); + public double calculateAccuracy(List> dataRows); /** * Classify the list of data rows and calculate the accuracy of classifying the row category. @@ -24,7 +25,7 @@ public interface IAccuracyCalculator { * @param cleanSlate Whether or not to aggregate previous calculations with this new data set or start from a clean slate. Defaults to TRUE. * @return The accuracy of classifying this data set */ - public double calculateAccuracy(List> dataRows, boolean cleanSlate); + public double calculateAccuracy(List> dataRows, boolean cleanSlate); /** * Get the last accuracy calculated. diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java index f61198e..e587683 100644 --- a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java @@ -16,13 +16,14 @@ * into a single accuracy calculation. This implementation does not use the IDataRow id, nor does * it verify that duplicate ids are only counted once. * @author Nick Ferraro - * + * + * @param The category data type the bayesian system will be using */ -public class AccuracyCalculator implements IAccuracyCalculator { +public class AccuracyCalculator implements IAccuracyCalculator { protected ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); protected Lock readLock = readWriteLock.readLock(); protected Lock writeLock = readWriteLock.writeLock(); - private IBayesianSystem bayesianSystem; + private IBayesianSystem bayesianSystem; private int total = 0; private int correct = 0; @@ -31,7 +32,7 @@ public class AccuracyCalculator implements IAccuracyCalculator { * @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL. * @throws InvalidParameterException Thrown when IBayesianSystem is NULL. */ - public AccuracyCalculator(IBayesianSystem bayesianSystem) throws InvalidParameterException { + public AccuracyCalculator(IBayesianSystem bayesianSystem) throws InvalidParameterException { // DRY: Set the bayesian system and reset counts setBayesianSystem(bayesianSystem); resetCounts(); @@ -42,7 +43,7 @@ public AccuracyCalculator(IBayesianSystem bayesianSystem) throws InvalidParam * @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL. * @throws InvalidParameterException Thrown when IBayesianSystem is NULL. */ - public void setBayesianSystem(IBayesianSystem bayesianSystem) throws InvalidParameterException { + public void setBayesianSystem(IBayesianSystem bayesianSystem) throws InvalidParameterException { // Validate the bayesianSystem parameter if( bayesianSystem == null ) { throw new InvalidParameterException("AccuracyCalculator cannot set a NULL IBayesianSystem"); @@ -60,13 +61,13 @@ public void setBayesianSystem(IBayesianSystem bayesianSystem) throws InvalidP } @Override - public double calculateAccuracy(List> dataRows) { + public double calculateAccuracy(List> dataRows) { // DRY: Call calculateAccuracy with cleanSlate = TRUE return calculateAccuracy(dataRows, true); } @Override - public double calculateAccuracy(List> dataRows, boolean cleanSlate) { + public double calculateAccuracy(List> dataRows, boolean cleanSlate) { // Lock writeLock.lock(); @@ -154,7 +155,7 @@ public void resetCounts() { * @param cleanSlate Whether or not to aggregate calculations or start from a clean slate * @return The current accuracy of the bayesian system */ - private double _calculateAccuracy(List> dataRows, boolean cleanSlate) { + private double _calculateAccuracy(List> dataRows, boolean cleanSlate) { // Reset our counts if cleanSlate is TRUE if( cleanSlate ) { _resetCounts(); @@ -166,20 +167,21 @@ private double _calculateAccuracy(List> dataRows, boolean cleanSlate } // Iterate over data rows - for(IDataRow dataRow : dataRows) { + for(IDataRow dataRow : dataRows) { // If a row is null, skip it if( dataRow == null ) { continue; } // Get the first classification and check if it is correct. NULL classifications are ignored. - IClassification classification = helperGetClassification(bayesianSystem.classifyRow(dataRow)); - if( classification != null ) { + List> classifications = bayesianSystem.classifyRow(dataRow); + if( classifications != null && classifications.size() > 0 ) { + IClassification classification = classifications.get(0); // Increase the total rows calculated ++total; // If the classification category matches the data row category, increase the correct count - if( classification.getCategory().equals(dataRow.getCategory()) ) { + if( classification != null && classification.getCategory().equals(dataRow.getCategory()) ) { ++correct; } } @@ -204,13 +206,4 @@ private void _resetCounts() { total = 0; correct = 0; } - - /** - * A helper method for handling generic type inference. Gets the first element of the list. - * @param l The generic list to get the first element from. - * @return The first element of the list or NULL if the list is NULL or empty - */ - private IClassification helperGetClassification(List> l) { - return l == null || l.size() < 1 ? null : l.get(0); - } } From fa0cc2a4e2edb0704cc1abe7ddb4f8f2714b2d42 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Mon, 3 Mar 2014 15:04:31 -0800 Subject: [PATCH 07/17] Fixed divide by 0 bug --- .../bayesian/report/calc/core/AccuracyCalculator.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java index e587683..65831a7 100644 --- a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java @@ -124,7 +124,7 @@ public int getIncorrectCount() { @Override public int getTotalCount() { // Lock - readLock.unlock(); + readLock.lock(); try { return total; @@ -196,7 +196,7 @@ private double _calculateAccuracy(List> dataRows, boolean cleanSlate * @return The last calculated accuracy */ private double _getAccuracy() { - return (double)correct / (double)total; + return total == 0 ? 0.0 : ((double)correct / (double)total); } /** From fc98e1cf4caf84c393c14f7e39eae5112d2ca324 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Mon, 3 Mar 2014 15:04:51 -0800 Subject: [PATCH 08/17] Created tests for AccuracyCalculator --- .../calc/core/AccuracyCalculatorTest.java | 194 ++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java diff --git a/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java new file mode 100644 index 0000000..0004cb0 --- /dev/null +++ b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java @@ -0,0 +1,194 @@ +package com.nickferraro.bayesian.report.calc.core; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assume.assumeThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.security.InvalidParameterException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import com.nickferraro.bayesian.IBayesianSystem; +import com.nickferraro.bayesian.IClassification; +import com.nickferraro.bayesian.IDataRow; + +public class AccuracyCalculatorTest { + private static final String CATEGORY_1 = "cat1"; + private static final String CATEGORY_2 = "cat2"; + private static final String CATEGORY_3 = "cat3"; + + private AccuracyCalculator calculator; + private IBayesianSystem mockSystem; + private IDataRow mockRow1; + private IDataRow mockRow2; + private IDataRow mockRow3; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + mockSystem = mock(IBayesianSystem.class); + calculator = new AccuracyCalculator(mockSystem); + } + + @Test(expected = InvalidParameterException.class) + public void testConstructor_NullSystem() { + new AccuracyCalculator(null); + } + + @Test + public void testInitialState() { + assertEquals(calculator.getAccuracy(), 0.0, 0.0001); + assertThat(calculator.getCorrectCount(), is(0)); + assertThat(calculator.getIncorrectCount(), is(0)); + assertThat(calculator.getTotalCount(), is(0)); + } + + @Test + public void testCalculateAccuracy() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + double accuracy = calculator.calculateAccuracy(mockDataRows); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + } + + @Test + public void testCalculateAccuracy_NullList() { + double accuracy = calculator.calculateAccuracy(null); + assertEquals(0.0, accuracy, 0.0001); + assertEquals(0.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(0)); + assertThat(calculator.getIncorrectCount(), is(0)); + assertThat(calculator.getTotalCount(), is(0)); + } + + @Test + public void testCalculateAccuracy_EmptyList() { + List> emptyList = Collections.emptyList(); + double accuracy = calculator.calculateAccuracy(emptyList); + assertEquals(0.0, accuracy, 0.0001); + assertEquals(0.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(0)); + assertThat(calculator.getIncorrectCount(), is(0)); + assertThat(calculator.getTotalCount(), is(0)); + } + + @Test + public void testCalculateAccuracy_ListWithNull() { + List> mockDataRows = createMockDataRows(); + mockDataRows.add(1, null); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + double accuracy = calculator.calculateAccuracy(mockDataRows); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + } + + @Test + public void testResetCounts() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + double accuracy = calculator.calculateAccuracy(mockDataRows); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertThat(calculator.getAccuracy(), is(2.0/3.0)); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + + calculator.resetCounts(); + + assertEquals(0.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(0)); + assertThat(calculator.getIncorrectCount(), is(0)); + assertThat(calculator.getTotalCount(), is(0)); + } + + @Test + public void testCalculateAccuracy_AggregateCounts() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + calculator.calculateAccuracy(mockDataRows); + assumeThat(calculator.getCorrectCount(), is(2)); + assumeThat(calculator.getIncorrectCount(), is(1)); + assumeThat(calculator.getTotalCount(), is(3)); + + double accuracy = calculator.calculateAccuracy(mockDataRows, false); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(4)); + assertThat(calculator.getIncorrectCount(), is(2)); + assertThat(calculator.getTotalCount(), is(6)); + } + + @SuppressWarnings("unchecked") + private List> createMockDataRows() { + List> mockDataRows = new ArrayList>(); + mockRow1 = mock(IDataRow.class); + mockRow2 = mock(IDataRow.class); + mockRow3 = mock(IDataRow.class); + + when(mockRow1.getCategory()).thenReturn(CATEGORY_1); + when(mockRow2.getCategory()).thenReturn(CATEGORY_2); + when(mockRow3.getCategory()).thenReturn(CATEGORY_3); + + mockDataRows.add(mockRow1); + mockDataRows.add(mockRow2); + mockDataRows.add(mockRow3); + + return mockDataRows; + } + + private List> createClassificationList(String category) { + List> classifications = new ArrayList>(); + + @SuppressWarnings("unchecked") + IClassification mockClassification1 = mock(IClassification.class); + classifications.add(mockClassification1); + when(mockClassification1.getCategory()).thenReturn(category); + + return classifications; + } +} From 68ddba785519fb780a555818c8dae011d5528c6d Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Tue, 4 Mar 2014 12:25:43 -0800 Subject: [PATCH 09/17] Fixed a bug related to aggregating calculations --- .../report/calc/core/AccuracyCalculator.java | 2 +- .../calc/core/AccuracyCalculatorTest.java | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java index 65831a7..352ea48 100644 --- a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java @@ -163,7 +163,7 @@ private double _calculateAccuracy(List> dataRows, boolean cleanSlate // Return 0% accuracy for null dataRows if( dataRows == null ) { - return 0; + return _getAccuracy(); } // Iterate over data rows diff --git a/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java index 0004cb0..4eb67b4 100644 --- a/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java +++ b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java @@ -163,6 +163,57 @@ public void testCalculateAccuracy_AggregateCounts() { assertThat(calculator.getTotalCount(), is(6)); } + @Test + public void testCalculateAccuracy_AggregateCounts_Null() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + calculator.calculateAccuracy(mockDataRows); + assumeThat(calculator.getCorrectCount(), is(2)); + assumeThat(calculator.getIncorrectCount(), is(1)); + assumeThat(calculator.getTotalCount(), is(3)); + + double accuracy = calculator.calculateAccuracy(null, false); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + } + + @Test + public void testCalculateAccuracy_AggregateCounts_Empty() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + calculator.calculateAccuracy(mockDataRows); + assumeThat(calculator.getCorrectCount(), is(2)); + assumeThat(calculator.getIncorrectCount(), is(1)); + assumeThat(calculator.getTotalCount(), is(3)); + + List> emptyDataRows = Collections.emptyList(); + double accuracy = calculator.calculateAccuracy(emptyDataRows, false); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + } + @SuppressWarnings("unchecked") private List> createMockDataRows() { List> mockDataRows = new ArrayList>(); From bbf4af449b67c8a34ba3f24943a57b998181fcd9 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Tue, 4 Mar 2014 12:32:11 -0800 Subject: [PATCH 10/17] Removing unused, empty class --- .../java/com/nickferraro/bayesian/report/AccuracyReport.java | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 src/main/java/com/nickferraro/bayesian/report/AccuracyReport.java diff --git a/src/main/java/com/nickferraro/bayesian/report/AccuracyReport.java b/src/main/java/com/nickferraro/bayesian/report/AccuracyReport.java deleted file mode 100644 index c1c5b0d..0000000 --- a/src/main/java/com/nickferraro/bayesian/report/AccuracyReport.java +++ /dev/null @@ -1,5 +0,0 @@ -package com.nickferraro.bayesian.report; - -public class AccuracyReport { - -} From a86b4efcfb631e047f774ff48aeb4b77b0e9f03e Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Tue, 4 Mar 2014 12:32:26 -0800 Subject: [PATCH 11/17] Updating .gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index adef5d4..fb448f0 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ *.jar *.war *.ear -.DS_Store \ No newline at end of file +.DS_Store +/target From 6767d30d8b460f1ef1cca668d62e59b93ad7558e Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Tue, 4 Mar 2014 14:07:50 -0800 Subject: [PATCH 12/17] Updating AccuracyCalculator to support subclasses better --- .../report/calc/core/AccuracyCalculator.java | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java index 352ea48..e65b8f1 100644 --- a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java @@ -20,12 +20,12 @@ * @param The category data type the bayesian system will be using */ public class AccuracyCalculator implements IAccuracyCalculator { - protected ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); - protected Lock readLock = readWriteLock.readLock(); - protected Lock writeLock = readWriteLock.writeLock(); - private IBayesianSystem bayesianSystem; - private int total = 0; - private int correct = 0; + protected final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); + protected final Lock readLock = readWriteLock.readLock(); + protected final Lock writeLock = readWriteLock.writeLock(); + protected IBayesianSystem bayesianSystem; + protected int total = 0; + protected int correct = 0; /** * Constructor for AccuracyCalculator. Requires a non-null IBayesianSystem parameter. @@ -155,7 +155,7 @@ public void resetCounts() { * @param cleanSlate Whether or not to aggregate calculations or start from a clean slate * @return The current accuracy of the bayesian system */ - private double _calculateAccuracy(List> dataRows, boolean cleanSlate) { + protected double _calculateAccuracy(List> dataRows, boolean cleanSlate) { // Reset our counts if cleanSlate is TRUE if( cleanSlate ) { _resetCounts(); @@ -195,14 +195,14 @@ private double _calculateAccuracy(List> dataRows, boolean cleanSlate * Private thread-unsafe method for calculating the accuracy * @return The last calculated accuracy */ - private double _getAccuracy() { + protected double _getAccuracy() { return total == 0 ? 0.0 : ((double)correct / (double)total); } /** * Private thread-unsafe method for reseting the calculations and counts */ - private void _resetCounts() { + protected void _resetCounts() { total = 0; correct = 0; } From 05fd0780e709b5209ac79e6c20adbf9332337cd8 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Tue, 4 Mar 2014 15:58:33 -0800 Subject: [PATCH 13/17] Created the interface for a confusion matrix --- .../bayesian/report/calc/IConfusionMatrix.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 src/main/java/com/nickferraro/bayesian/report/calc/IConfusionMatrix.java diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/IConfusionMatrix.java b/src/main/java/com/nickferraro/bayesian/report/calc/IConfusionMatrix.java new file mode 100644 index 0000000..3cf024f --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/calc/IConfusionMatrix.java @@ -0,0 +1,17 @@ +package com.nickferraro.bayesian.report.calc; + +/** + * An interface for calculating a confusion matrix. + * @author Nick Ferraro + * + * @param The category data type of the bayesian model + */ +public interface IConfusionMatrix extends IAccuracyCalculator { + /** + * Get the count value of a matrix cell. + * @param actualCategory The actual category (row) + * @param classifiedCategory The classified category (column) + * @return The count value of the cell. Will never be less than 0. + */ + public int getCellCount(T actualCategory, T classifiedCategory); +} From da41cb9f9f54abdc37332a229fe321caf2fdd7bc Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Tue, 4 Mar 2014 16:00:32 -0800 Subject: [PATCH 14/17] Adding a thread-safe implementation for a confusion matrix --- .../report/calc/core/ConfusionMatrix.java | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java new file mode 100644 index 0000000..356614d --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java @@ -0,0 +1,134 @@ +package com.nickferraro.bayesian.report.calc.core; + +import java.security.InvalidParameterException; +import java.util.HashMap; +import java.util.List; + +import com.nickferraro.bayesian.IBayesianSystem; +import com.nickferraro.bayesian.IClassification; +import com.nickferraro.bayesian.IDataRow; +import com.nickferraro.bayesian.report.calc.IConfusionMatrix; + +/** + * A class for calculating the accuracy of a bayesian system and creating a confusion matrix. + * This class is thread-safe. + * @author Nick Ferraro + * + * @param The category data type used by the bayesian system + */ +public class ConfusionMatrix extends AccuracyCalculator implements IConfusionMatrix { + private final HashMap> matrix = new HashMap>(); + + /** + * Constructor of a ConfusionMatrix. Requires a non-null bayesian system. + * @param bayesianSystem The bayesian system to use for calculating a confusion matrix + * @throws InvalidParameterException Thrown when the bayesian system is null + */ + public ConfusionMatrix(IBayesianSystem bayesianSystem) throws InvalidParameterException { + super(bayesianSystem); + } + + @Override + public int getCellCount(T actualCategory, T classifiedCategory) { + // Lock + readLock.lock(); + + try { + // Call the thread-unsafe private method + return _getCellCount(actualCategory, classifiedCategory); + } finally { + // Unlock + readLock.unlock(); + } + } + + /** + * Get the current count for a matrix cell. Not thread-safe. + * @param actualCategory The actual category (row) + * @param classifiedCategory The classified category (column) + * @return The current count of the cell + */ + protected int _getCellCount(T actualCategory, T classifiedCategory) { + // Get the matrix row + HashMap row = matrix.get(actualCategory); + if( row == null ) { + // If the row doesn't exist, the count is 0 + return 0; + } + + // Get the cell count, if it doesn't exist the count is 0 + Integer cellCount = row.get(classifiedCategory); + return cellCount == null ? 0 : cellCount; + } + + @Override + protected double _calculateAccuracy(List> dataRows, boolean cleanSlate) { + // Reset our counts if cleanSlate is TRUE + if( cleanSlate ) { + _resetCounts(); + } + + // Return 0% accuracy for null dataRows + if( dataRows == null ) { + return _getAccuracy(); + } + + // Iterate over data rows + for(IDataRow dataRow : dataRows) { + // If a row is null, skip it + if( dataRow == null ) { + continue; + } + + // Get the first classification and check if it is correct. NULL classifications are ignored. + List> classifications = bayesianSystem.classifyRow(dataRow); + if( classifications != null && classifications.size() > 0 ) { + IClassification classification = classifications.get(0); + // Increase the total rows calculated + ++total; + + // If the classification category matches the data row category, increase the correct count + if( classification != null && classification.getCategory().equals(dataRow.getCategory()) ) { + ++correct; + } + + // Add classification result to matrix + _updateMatrix(dataRow.getCategory(), classification.getCategory()); + } + } + + // Return the current accuracy using the thread-unsafe private method + return _getAccuracy(); + } + + /** + * Increase a cell's count on the matrix + * @param actualCategory The actual category (row) + * @param classifiedCategory The classified category (column) + */ + protected void _updateMatrix(T actualCategory, T classifiedCategory) { + // Get the row for the actual category + HashMap actualRow = matrix.get(actualCategory); + if( actualRow == null ) { + // If the row doesn't exist yet, create it + actualRow = new HashMap(); + matrix.put(actualCategory, actualRow); + } + + // Get the row cell for the classified category + Integer classifiedCell = actualRow.get(classifiedCategory); + if( classifiedCell == null ) { + // If it doesn't exist, create it (set it to 0) + classifiedCell = 0; + } + + // Increase the count for this cell + actualRow.put(classifiedCategory, classifiedCell + 1); + } + + @Override + protected void _resetCounts() { + super._resetCounts(); + matrix.clear(); + } +} From bff1107ee57babdf5cc953dbf3328c9d68b24d5b Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Wed, 5 Mar 2014 11:33:31 -0800 Subject: [PATCH 15/17] Adding fix for NullPointerException --- .../bayesian/report/calc/core/ConfusionMatrix.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java index 356614d..ee717b6 100644 --- a/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java @@ -129,6 +129,9 @@ protected void _updateMatrix(T actualCategory, T classifiedCategory) { @Override protected void _resetCounts() { super._resetCounts(); - matrix.clear(); + // During construction, matrix is NULL + if( matrix != null ) { + matrix.clear(); + } } } From e26788a8f166b216c37e7a12c5b317241d233ab2 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Wed, 5 Mar 2014 11:33:47 -0800 Subject: [PATCH 16/17] Created tests for ConfusionMatrix Updated tests for AccuracyCalculator to support subclassing --- .../calc/core/AccuracyCalculatorTest.java | 20 +++---- .../report/calc/core/ConfusionMatrixTest.java | 53 +++++++++++++++++++ 2 files changed, 63 insertions(+), 10 deletions(-) create mode 100644 src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java diff --git a/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java index 4eb67b4..0d90d8a 100644 --- a/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java +++ b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java @@ -19,15 +19,15 @@ import com.nickferraro.bayesian.IDataRow; public class AccuracyCalculatorTest { - private static final String CATEGORY_1 = "cat1"; - private static final String CATEGORY_2 = "cat2"; - private static final String CATEGORY_3 = "cat3"; + protected static final String CATEGORY_1 = "cat1"; + protected static final String CATEGORY_2 = "cat2"; + protected static final String CATEGORY_3 = "cat3"; - private AccuracyCalculator calculator; - private IBayesianSystem mockSystem; - private IDataRow mockRow1; - private IDataRow mockRow2; - private IDataRow mockRow3; + protected AccuracyCalculator calculator; + protected IBayesianSystem mockSystem; + protected IDataRow mockRow1; + protected IDataRow mockRow2; + protected IDataRow mockRow3; @SuppressWarnings("unchecked") @Before @@ -215,7 +215,7 @@ public void testCalculateAccuracy_AggregateCounts_Empty() { } @SuppressWarnings("unchecked") - private List> createMockDataRows() { + protected List> createMockDataRows() { List> mockDataRows = new ArrayList>(); mockRow1 = mock(IDataRow.class); mockRow2 = mock(IDataRow.class); @@ -232,7 +232,7 @@ private List> createMockDataRows() { return mockDataRows; } - private List> createClassificationList(String category) { + protected List> createClassificationList(String category) { List> classifications = new ArrayList>(); @SuppressWarnings("unchecked") diff --git a/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java b/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java new file mode 100644 index 0000000..5b0b38e --- /dev/null +++ b/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java @@ -0,0 +1,53 @@ +package com.nickferraro.bayesian.report.calc.core; + +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import com.nickferraro.bayesian.IBayesianSystem; +import com.nickferraro.bayesian.IClassification; +import com.nickferraro.bayesian.IDataRow; + +public class ConfusionMatrixTest extends AccuracyCalculatorTest { + private ConfusionMatrix matrix; + + @Before + public void setup() { + super.setup(); + matrix = new ConfusionMatrix(mockSystem); + calculator = matrix; + } + + @Test + public void testGetCellCount() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + double accuracy = calculator.calculateAccuracy(mockDataRows); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertThat(matrix.getCellCount(CATEGORY_1, CATEGORY_1), is(1)); + assertThat(matrix.getCellCount(CATEGORY_1, CATEGORY_2), is(0)); + assertThat(matrix.getCellCount(CATEGORY_1, CATEGORY_3), is(0)); + assertThat(matrix.getCellCount(CATEGORY_2, CATEGORY_1), is(1)); + assertThat(matrix.getCellCount(CATEGORY_2, CATEGORY_2), is(0)); + assertThat(matrix.getCellCount(CATEGORY_2, CATEGORY_3), is(0)); + assertThat(matrix.getCellCount(CATEGORY_3, CATEGORY_1), is(0)); + assertThat(matrix.getCellCount(CATEGORY_3, CATEGORY_2), is(0)); + assertThat(matrix.getCellCount(CATEGORY_3, CATEGORY_3), is(1)); + } +} From bbd5ff67ab1ba77887a39bf6158e6b16ffb4ee04 Mon Sep 17 00:00:00 2001 From: Nick Ferraro Date: Wed, 5 Mar 2014 13:31:17 -0800 Subject: [PATCH 17/17] Removed unused import --- .../bayesian/report/calc/core/ConfusionMatrixTest.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java b/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java index 5b0b38e..583919a 100644 --- a/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java +++ b/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java @@ -3,16 +3,13 @@ import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import java.util.ArrayList; import java.util.List; import org.junit.Before; import org.junit.Test; -import com.nickferraro.bayesian.IBayesianSystem; import com.nickferraro.bayesian.IClassification; import com.nickferraro.bayesian.IDataRow;