diff --git a/.gitignore b/.gitignore
index adef5d4..fb448f0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@
 *.jar
 *.war
 *.ear
-.DS_Store
\ No newline at end of file
+.DS_Store
+/target
diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java
new file mode 100644
index 0000000..c85da6d
--- /dev/null
+++ b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java
@@ -0,0 +1,53 @@
+package com.nickferraro.bayesian.report.calc;
+
+import java.util.List;
+
+import com.nickferraro.bayesian.IDataRow;
+
+/**
+ * The IAccuracyCalculator interface is a simple contract for calculating the accuracy of
+ * a bayesian system in classifying a list of data rows.
+ * @author Nick Ferraro
+ *
+ * @param <T> The category data type the bayesian system will be using
+ */
+public interface IAccuracyCalculator<T> {
+    /**
+     * Classify the list of data rows and calculate the accuracy of classifying the row category.
+     * @param dataRows A list of data rows to classify and use in a clean set of calculations.
+     * @return The accuracy of classifying this data set
+     */
+    public double calculateAccuracy(List<IDataRow<T>> dataRows);
+
+    /**
+     * Classify the list of data rows and calculate the accuracy of classifying the row category.
+     * @param dataRows A list of data rows to classify and use in a clean or aggregated set of calculations.
+     * @param cleanSlate TRUE to discard previous counts before calculating; FALSE to aggregate this data set with previous calculations.
+     * @return The accuracy over all rows currently counted (only this data set when cleanSlate is TRUE)
+     */
+    public double calculateAccuracy(List<IDataRow<T>> dataRows, boolean cleanSlate);
+
+    /**
+     * Get the last accuracy calculated.
+     * @return The last accuracy calculated.
+     */
+    public double getAccuracy();
+
+    /**
+     * Get the number of correct classifications from the last calculation.
+     * @return The number of correct classifications
+     */
+    public int getCorrectCount();
+
+    /**
+     * Get the number of incorrect classifications from the last calculation.
+     * @return The number of incorrect classifications
+     */
+    public int getIncorrectCount();
+
+    /**
+     * Get the number of total classifications calculated
+     * @return The number of total classifications
+     */
+    public int getTotalCount();
+}
diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java
new file mode 100644
index 0000000..e65b8f1
--- /dev/null
+++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java
@@ -0,0 +1,210 @@
+package com.nickferraro.bayesian.report.calc.core;
+
+import java.security.InvalidParameterException;
+import java.util.List;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+import com.nickferraro.bayesian.IBayesianSystem;
+import com.nickferraro.bayesian.IClassification;
+import com.nickferraro.bayesian.IDataRow;
+import com.nickferraro.bayesian.report.calc.IAccuracyCalculator;
+
+/**
+ * A thread-safe implementation of IAccuracyCalculator. This class will classify data rows
+ * and return the accuracy rate calculated. Multiple sets of data rows can be classified and aggregated
+ * into a single accuracy calculation. This implementation does not use the IDataRow id, nor does
+ * it verify that duplicate ids are only counted once.
+ * @author Nick Ferraro
+ *
+ * @param <T> The category data type the bayesian system will be using
+ */
+public class AccuracyCalculator<T> implements IAccuracyCalculator<T> {
+    protected final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();
+    protected final Lock readLock = readWriteLock.readLock();
+    protected final Lock writeLock = readWriteLock.writeLock();
+    protected IBayesianSystem<T> bayesianSystem;
+    protected int total = 0;
+    protected int correct = 0;
+
+    /**
+     * Constructor for AccuracyCalculator. Requires a non-null IBayesianSystem parameter.
+     * @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL.
+     * @throws InvalidParameterException Thrown when IBayesianSystem is NULL.
+     */
+    public AccuracyCalculator(IBayesianSystem<T> bayesianSystem) throws InvalidParameterException {
+        // DRY: Set the bayesian system and reset counts
+        setBayesianSystem(bayesianSystem);
+        resetCounts();
+    }
+
+    /**
+     * Set the IBayesianSystem used for classification.
+     * @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL.
+     * @throws InvalidParameterException Thrown when IBayesianSystem is NULL.
+     */
+    public void setBayesianSystem(IBayesianSystem<T> bayesianSystem) throws InvalidParameterException {
+        // Validate the bayesianSystem parameter
+        if( bayesianSystem == null ) {
+            throw new InvalidParameterException("AccuracyCalculator cannot set a NULL IBayesianSystem");
+        }
+
+        // Lock
+        writeLock.lock();
+
+        try {
+            this.bayesianSystem = bayesianSystem;
+        } finally {
+            // Unlock
+            writeLock.unlock();
+        }
+    }
+
+    @Override
+    public double calculateAccuracy(List<IDataRow<T>> dataRows) {
+        // DRY: Call calculateAccuracy with cleanSlate = TRUE
+        return calculateAccuracy(dataRows, true);
+    }
+
+    @Override
+    public double calculateAccuracy(List<IDataRow<T>> dataRows, boolean cleanSlate) {
+        // Lock
+        writeLock.lock();
+
+        try {
+            // Call thread-unsafe helper while holding the write lock
+            return _calculateAccuracy(dataRows, cleanSlate);
+        } finally {
+            // Unlock
+            writeLock.unlock();
+        }
+    }
+
+    @Override
+    public double getAccuracy() {
+        // Lock
+        readLock.lock();
+
+        try {
+            // Call thread-unsafe helper while holding the read lock
+            return _getAccuracy();
+        } finally {
+            // Unlock
+            readLock.unlock();
+        }
+    }
+
+    @Override
+    public int getCorrectCount() {
+        // Lock
+        readLock.lock();
+
+        try {
+            return correct;
+        } finally {
+            // Unlock
+            readLock.unlock();
+        }
+    }
+
+    @Override
+    public int getIncorrectCount() {
+        // Lock
+        readLock.lock();
+
+        try {
+            // Calculate and return incorrect
+            return total - correct;
+        } finally {
+            // Unlock
+            readLock.unlock();
+        }
+    }
+
+    @Override
+    public int getTotalCount() {
+        // Lock
+        readLock.lock();
+
+        try {
+            return total;
+        } finally {
+            // Unlock
+            readLock.unlock();
+        }
+    }
+
+    /**
+     * Reset the current calculations and counts
+     */
+    public void resetCounts() {
+        // Lock
+        writeLock.lock();
+
+        try {
+            // Call thread-unsafe helper while holding the write lock
+            _resetCounts();
+        } finally {
+            writeLock.unlock();
+        }
+    }
+
+    /**
+     * Thread-unsafe helper for calculating the accuracy of a bayesian system. Callers must hold the write lock.
+     * @param dataRows The data rows to classify and use in accuracy calculations. May be NULL, in which case no rows are counted.
+     * @param cleanSlate Whether or not to aggregate calculations or start from a clean slate
+     * @return The current accuracy of the bayesian system
+     */
+    protected double _calculateAccuracy(List<IDataRow<T>> dataRows, boolean cleanSlate) {
+        // Reset our counts if cleanSlate is TRUE
+        if( cleanSlate ) {
+            _resetCounts();
+        }
+
+        // NULL dataRows add nothing: return the current accuracy (0.0 after a reset, the aggregate otherwise)
+        if( dataRows == null ) {
+            return _getAccuracy();
+        }
+
+        // Iterate over data rows
+        for(IDataRow<T> dataRow : dataRows) {
+            // If a row is null, skip it
+            if( dataRow == null ) {
+                continue;
+            }
+
+            // Score only the first classification. Rows with a NULL or empty classification list are not counted at all.
+            List<IClassification<T>> classifications = bayesianSystem.classifyRow(dataRow);
+            if( classifications != null && !classifications.isEmpty() ) {
+                IClassification<T> classification = classifications.get(0);
+                // Increase the total rows calculated
+                ++total;
+
+                // A NULL classification or NULL predicted category counts as incorrect instead of throwing
+                T predicted = classification == null ? null : classification.getCategory();
+                if( predicted != null && predicted.equals(dataRow.getCategory()) ) {
+                    ++correct;
+                }
+            }
+        }
+
+        // Return the current accuracy using the thread-unsafe helper
+        return _getAccuracy();
+    }
+
+    /**
+     * Thread-unsafe helper for calculating the accuracy. Returns 0.0 when nothing has been counted.
+     * @return The last calculated accuracy
+     */
+    protected double _getAccuracy() {
+        return total == 0 ? 0.0 : ((double)correct / (double)total);
+    }
+
+    /**
+     * Thread-unsafe helper for resetting the calculations and counts
+     */
+    protected void _resetCounts() {
+        total = 0;
+        correct = 0;
+    }
+}
diff --git a/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java
new file mode 100644
index 0000000..4eb67b4
--- /dev/null
+++ b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java
@@ -0,0 +1,245 @@
+package com.nickferraro.bayesian.report.calc.core;
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assume.assumeThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.security.InvalidParameterException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import com.nickferraro.bayesian.IBayesianSystem;
+import com.nickferraro.bayesian.IClassification;
+import com.nickferraro.bayesian.IDataRow;
+
+public class AccuracyCalculatorTest {
+    private static final String CATEGORY_1 = "cat1";
+    private static final String CATEGORY_2 = "cat2";
+    private static final String CATEGORY_3 = "cat3";
+
+    private AccuracyCalculator<String> calculator;
+    private IBayesianSystem<String> mockSystem;
+    private IDataRow<String> mockRow1;
+    private IDataRow<String> mockRow2;
+    private IDataRow<String> mockRow3;
+
+    @SuppressWarnings("unchecked")
+    @Before
+    public void setup() {
+        mockSystem = mock(IBayesianSystem.class);
+        calculator = new AccuracyCalculator<String>(mockSystem);
+    }
+
+    @Test(expected = InvalidParameterException.class)
+    public void testConstructor_NullSystem() {
+        new AccuracyCalculator<String>(null);
+    }
+
+    @Test
+    public void testInitialState() {
+        assertEquals(0.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(0));
+        assertThat(calculator.getIncorrectCount(), is(0));
+        assertThat(calculator.getTotalCount(), is(0));
+    }
+
+    @Test
+    public void testCalculateAccuracy() {
+        List<IDataRow<String>> mockDataRows = createMockDataRows();
+
+        List<IClassification<String>> classificationList1 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList2 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList3 = createClassificationList(CATEGORY_3);
+
+        when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1);
+        when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2);
+        when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3);
+
+        double accuracy = calculator.calculateAccuracy(mockDataRows);
+        assertEquals(2.0/3.0, accuracy, 0.0001);
+        assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(2));
+        assertThat(calculator.getIncorrectCount(), is(1));
+        assertThat(calculator.getTotalCount(), is(3));
+    }
+
+    @Test
+    public void testCalculateAccuracy_NullList() {
+        double accuracy = calculator.calculateAccuracy(null);
+        assertEquals(0.0, accuracy, 0.0001);
+        assertEquals(0.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(0));
+        assertThat(calculator.getIncorrectCount(), is(0));
+        assertThat(calculator.getTotalCount(), is(0));
+    }
+
+    @Test
+    public void testCalculateAccuracy_EmptyList() {
+        List<IDataRow<String>> emptyList = Collections.emptyList();
+        double accuracy = calculator.calculateAccuracy(emptyList);
+        assertEquals(0.0, accuracy, 0.0001);
+        assertEquals(0.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(0));
+        assertThat(calculator.getIncorrectCount(), is(0));
+        assertThat(calculator.getTotalCount(), is(0));
+    }
+
+    @Test
+    public void testCalculateAccuracy_ListWithNull() {
+        List<IDataRow<String>> mockDataRows = createMockDataRows();
+        mockDataRows.add(1, null);
+
+        List<IClassification<String>> classificationList1 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList2 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList3 = createClassificationList(CATEGORY_3);
+
+        when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1);
+        when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2);
+        when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3);
+
+        double accuracy = calculator.calculateAccuracy(mockDataRows);
+        assertEquals(2.0/3.0, accuracy, 0.0001);
+        assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(2));
+        assertThat(calculator.getIncorrectCount(), is(1));
+        assertThat(calculator.getTotalCount(), is(3));
+    }
+
+    @Test
+    public void testResetCounts() {
+        List<IDataRow<String>> mockDataRows = createMockDataRows();
+
+        List<IClassification<String>> classificationList1 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList2 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList3 = createClassificationList(CATEGORY_3);
+
+        when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1);
+        when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2);
+        when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3);
+
+        double accuracy = calculator.calculateAccuracy(mockDataRows);
+        assertEquals(2.0/3.0, accuracy, 0.0001);
+        assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(2));
+        assertThat(calculator.getIncorrectCount(), is(1));
+        assertThat(calculator.getTotalCount(), is(3));
+
+        calculator.resetCounts();
+
+        assertEquals(0.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(0));
+        assertThat(calculator.getIncorrectCount(), is(0));
+        assertThat(calculator.getTotalCount(), is(0));
+    }
+
+    @Test
+    public void testCalculateAccuracy_AggregateCounts() {
+        List<IDataRow<String>> mockDataRows = createMockDataRows();
+
+        List<IClassification<String>> classificationList1 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList2 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList3 = createClassificationList(CATEGORY_3);
+
+        when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1);
+        when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2);
+        when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3);
+
+        calculator.calculateAccuracy(mockDataRows);
+        assumeThat(calculator.getCorrectCount(), is(2));
+        assumeThat(calculator.getIncorrectCount(), is(1));
+        assumeThat(calculator.getTotalCount(), is(3));
+
+        double accuracy = calculator.calculateAccuracy(mockDataRows, false);
+        assertEquals(2.0/3.0, accuracy, 0.0001);
+        assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(4));
+        assertThat(calculator.getIncorrectCount(), is(2));
+        assertThat(calculator.getTotalCount(), is(6));
+    }
+
+    @Test
+    public void testCalculateAccuracy_AggregateCounts_Null() {
+        List<IDataRow<String>> mockDataRows = createMockDataRows();
+
+        List<IClassification<String>> classificationList1 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList2 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList3 = createClassificationList(CATEGORY_3);
+
+        when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1);
+        when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2);
+        when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3);
+
+        calculator.calculateAccuracy(mockDataRows);
+        assumeThat(calculator.getCorrectCount(), is(2));
+        assumeThat(calculator.getIncorrectCount(), is(1));
+        assumeThat(calculator.getTotalCount(), is(3));
+
+        double accuracy = calculator.calculateAccuracy(null, false);
+        assertEquals(2.0/3.0, accuracy, 0.0001);
+        assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(2));
+        assertThat(calculator.getIncorrectCount(), is(1));
+        assertThat(calculator.getTotalCount(), is(3));
+    }
+
+    @Test
+    public void testCalculateAccuracy_AggregateCounts_Empty() {
+        List<IDataRow<String>> mockDataRows = createMockDataRows();
+
+        List<IClassification<String>> classificationList1 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList2 = createClassificationList(CATEGORY_1);
+        List<IClassification<String>> classificationList3 = createClassificationList(CATEGORY_3);
+
+        when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1);
+        when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2);
+        when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3);
+
+        calculator.calculateAccuracy(mockDataRows);
+        assumeThat(calculator.getCorrectCount(), is(2));
+        assumeThat(calculator.getIncorrectCount(), is(1));
+        assumeThat(calculator.getTotalCount(), is(3));
+
+        List<IDataRow<String>> emptyDataRows = Collections.emptyList();
+        double accuracy = calculator.calculateAccuracy(emptyDataRows, false);
+        assertEquals(2.0/3.0, accuracy, 0.0001);
+        assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001);
+        assertThat(calculator.getCorrectCount(), is(2));
+        assertThat(calculator.getIncorrectCount(), is(1));
+        assertThat(calculator.getTotalCount(), is(3));
+    }
+
+    @SuppressWarnings("unchecked")
+    private List<IDataRow<String>> createMockDataRows() {
+        List<IDataRow<String>> mockDataRows = new ArrayList<IDataRow<String>>();
+        mockRow1 = mock(IDataRow.class);
+        mockRow2 = mock(IDataRow.class);
+        mockRow3 = mock(IDataRow.class);
+
+        when(mockRow1.getCategory()).thenReturn(CATEGORY_1);
+        when(mockRow2.getCategory()).thenReturn(CATEGORY_2);
+        when(mockRow3.getCategory()).thenReturn(CATEGORY_3);
+
+        mockDataRows.add(mockRow1);
+        mockDataRows.add(mockRow2);
+        mockDataRows.add(mockRow3);
+
+        return mockDataRows;
+    }
+
+    private List<IClassification<String>> createClassificationList(String category) {
+        List<IClassification<String>> classifications = new ArrayList<IClassification<String>>();
+
+        @SuppressWarnings("unchecked")
+        IClassification<String> mockClassification1 = mock(IClassification.class);
+        classifications.add(mockClassification1);
+        when(mockClassification1.getCategory()).thenReturn(category);
+
+        return classifications;
+    }
+}