diff --git a/.gitignore b/.gitignore index adef5d4..fb448f0 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ *.jar *.war *.ear -.DS_Store \ No newline at end of file +.DS_Store +/target diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java new file mode 100644 index 0000000..c85da6d --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/calc/IAccuracyCalculator.java @@ -0,0 +1,53 @@ +package com.nickferraro.bayesian.report.calc; + +import java.util.List; + +import com.nickferraro.bayesian.IDataRow; + +/** + * The IAccuracyCalculator interface is a simple contract for calculating the accuracy of + * a bayesian system in classifying a list of data rows. + * @author Nick Ferraro + * + * @param The category data type the bayesian system will be using + */ +public interface IAccuracyCalculator { + /** + * Classify the list of data rows and calculate the accuracy of classifying the row category. + * @param dataRows A list of data rows to classify and use in a clean set of calculations. + * @return The accuracy of classifying this data set + */ + public double calculateAccuracy(List> dataRows); + + /** + * Classify the list of data rows and calculate the accuracy of classifying the row category. + * @param dataRows A list of data rows to classify and use in a clean or aggregated set of calculations. + * @param cleanSlate Whether or not to aggregate previous calculations with this new data set or start from a clean slate. Defaults to TRUE. + * @return The accuracy of classifying this data set + */ + public double calculateAccuracy(List> dataRows, boolean cleanSlate); + + /** + * Get the last accuracy calculated. + * @return The last accuracy calculated. + */ + public double getAccuracy(); + + /** + * Get the number of correct classifications from the last calculation. + * @return The number of correct classifications + */ + public int getCorrectCount(); + + /** + * Get the number of incorrect classifications from the last calculation. + * @return The number of incorrect classifications + */ + public int getIncorrectCount(); + + /** + * Get the number of total classifications calculated + * @return The number of total classifications + */ + public int getTotalCount(); +} diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/IConfusionMatrix.java b/src/main/java/com/nickferraro/bayesian/report/calc/IConfusionMatrix.java new file mode 100644 index 0000000..3cf024f --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/calc/IConfusionMatrix.java @@ -0,0 +1,17 @@ +package com.nickferraro.bayesian.report.calc; + +/** + * An interface for calculating a confusion matrix. + * @author Nick Ferraro + * + * @param The category data type of the bayesian model + */ +public interface IConfusionMatrix extends IAccuracyCalculator { + /** + * Get the count value of a matrix cell. + * @param actualCategory The actual category (row) + * @param classifiedCategory The classified category (column) + * @return The count value of the cell. Will never be less than 0. + */ + public int getCellCount(T actualCategory, T classifiedCategory); +} diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java new file mode 100644 index 0000000..e65b8f1 --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculator.java @@ -0,0 +1,209 @@ +package com.nickferraro.bayesian.report.calc.core; + +import java.security.InvalidParameterException; +import java.util.List; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +import com.nickferraro.bayesian.IBayesianSystem; +import com.nickferraro.bayesian.IClassification; +import com.nickferraro.bayesian.IDataRow; +import com.nickferraro.bayesian.report.calc.IAccuracyCalculator; + +/** + * A thread-safe implementation of IAccuracyCalculator. This class will classify data rows + * and return the accuracy rate calculated. Multiple sets of data rows can be classified and aggregated + * into a single accuracy calculation. This implementation does not use the IDataRow id, nor does + * it verify that duplicate ids are only counted once. + * @author Nick Ferraro + * + * @param The category data type the bayesian system will be using + */ +public class AccuracyCalculator implements IAccuracyCalculator { + protected final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); + protected final Lock readLock = readWriteLock.readLock(); + protected final Lock writeLock = readWriteLock.writeLock(); + protected IBayesianSystem bayesianSystem; + protected int total = 0; + protected int correct = 0; + + /** + * Constructor for AccuracyCalculator. Requires a non-null IBayesianSystem parameter. + * @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL. + * @throws InvalidParameterException Thrown when IBayesianSystem is NULL. + */ + public AccuracyCalculator(IBayesianSystem bayesianSystem) throws InvalidParameterException { + // DRY: Set the bayesian system and reset counts + setBayesianSystem(bayesianSystem); + resetCounts(); + } + + /** + * Set the IBayesianSystem used for classification. + * @param bayesianSystem The IBayesianSystem to use for classification. Must not be NULL. + * @throws InvalidParameterException Thrown when IBayesianSystem is NULL. + */ + public void setBayesianSystem(IBayesianSystem bayesianSystem) throws InvalidParameterException { + // Validate the bayesianSystem parameter + if( bayesianSystem == null ) { + throw new InvalidParameterException("AccuracyCalculator cannot set a NULL IBayesianSystem"); + } + + // Lock + writeLock.lock(); + + try { + this.bayesianSystem = bayesianSystem; + } finally { + // Unlock + writeLock.unlock(); + } + } + + @Override + public double calculateAccuracy(List> dataRows) { + // DRY: Call calculateAccuracy with cleanSlate = TRUE + return calculateAccuracy(dataRows, true); + } + + @Override + public double calculateAccuracy(List> dataRows, boolean cleanSlate) { + // Lock + writeLock.lock(); + + try { + // Call thread-unsafe private method + return _calculateAccuracy(dataRows, cleanSlate); + } finally { + // Unlock + writeLock.unlock(); + } + } + + @Override + public double getAccuracy() { + // Lock + readLock.lock(); + + try { + // Call thead-unsafe private method + return _getAccuracy(); + } finally { + // Unlock + readLock.unlock(); + } + } + + @Override + public int getCorrectCount() { + // Lock + readLock.lock(); + + try { + return correct; + } finally { + // Unlock + readLock.unlock(); + } + } + + @Override + public int getIncorrectCount() { + // Lock + readLock.lock(); + + try { + // Calculate and return incorrect + return total - correct; + } finally { + // Unlock + readLock.unlock(); + } + } + + @Override + public int getTotalCount() { + // Lock + readLock.lock(); + + try { + return total; + } finally { + // Unlock + readLock.unlock(); + } + } + + /** + * Reset the current calculations and counts + */ + public void resetCounts() { + // Lock + writeLock.lock(); + + try { + // Call thread-unsafe private method + _resetCounts(); + } finally { + writeLock.unlock(); + } + } + + /** + * Private thread-unsafe method for calculating the accuracy of a bayesian system. + * @param dataRows The data rows to classify and use in accuracy calculations + * @param cleanSlate Whether or not to aggregate calculations or start from a clean slate + * @return The current accuracy of the bayesian system + */ + protected double _calculateAccuracy(List> dataRows, boolean cleanSlate) { + // Reset our counts if cleanSlate is TRUE + if( cleanSlate ) { + _resetCounts(); + } + + // Return 0% accuracy for null dataRows + if( dataRows == null ) { + return _getAccuracy(); + } + + // Iterate over data rows + for(IDataRow dataRow : dataRows) { + // If a row is null, skip it + if( dataRow == null ) { + continue; + } + + // Get the first classification and check if it is correct. NULL classifications are ignored. + List> classifications = bayesianSystem.classifyRow(dataRow); + if( classifications != null && classifications.size() > 0 ) { + IClassification classification = classifications.get(0); + // Increase the total rows calculated + ++total; + + // If the classification category matches the data row category, increase the correct count + if( classification != null && classification.getCategory().equals(dataRow.getCategory()) ) { + ++correct; + } + } + } + + // Return the current accuracy using the thread-unsafe private method + return _getAccuracy(); + } + + /** + * Private thread-unsafe method for calculating the accuracy + * @return The last calculated accuracy + */ + protected double _getAccuracy() { + return total == 0 ? 0.0 : ((double)correct / (double)total); + } + + /** + * Private thread-unsafe method for reseting the calculations and counts + */ + protected void _resetCounts() { + total = 0; + correct = 0; + } +} diff --git a/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java b/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java new file mode 100644 index 0000000..ee717b6 --- /dev/null +++ b/src/main/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrix.java @@ -0,0 +1,137 @@ +package com.nickferraro.bayesian.report.calc.core; + +import java.security.InvalidParameterException; +import java.util.HashMap; +import java.util.List; + +import com.nickferraro.bayesian.IBayesianSystem; +import com.nickferraro.bayesian.IClassification; +import com.nickferraro.bayesian.IDataRow; +import com.nickferraro.bayesian.report.calc.IConfusionMatrix; + +/** + * A class for calculating the accuracy of a bayesian system and creating a confusion matrix. + * This class is thread-safe. + * @author Nick Ferraro + * + * @param The category data type used by the bayesian system + */ +public class ConfusionMatrix extends AccuracyCalculator implements IConfusionMatrix { + private final HashMap> matrix = new HashMap>(); + + /** + * Constructor of a ConfusionMatrix. Requires a non-null bayesian system. + * @param bayesianSystem The bayesian system to use for calculating a confusion matrix + * @throws InvalidParameterException Thrown when the bayesian system is null + */ + public ConfusionMatrix(IBayesianSystem bayesianSystem) throws InvalidParameterException { + super(bayesianSystem); + } + + @Override + public int getCellCount(T actualCategory, T classifiedCategory) { + // Lock + readLock.lock(); + + try { + // Call the thread-unsafe private method + return _getCellCount(actualCategory, classifiedCategory); + } finally { + // Unlock + readLock.unlock(); + } + } + + /** + * Get the current count for a matrix cell. Not thread-safe. + * @param actualCategory The actual category (row) + * @param classifiedCategory The classified category (column) + * @return The current count of the cell + */ + protected int _getCellCount(T actualCategory, T classifiedCategory) { + // Get the matrix row + HashMap row = matrix.get(actualCategory); + if( row == null ) { + // If the row doesn't exist, the count is 0 + return 0; + } + + // Get the cell count, if it doesn't exist the count is 0 + Integer cellCount = row.get(classifiedCategory); + return cellCount == null ? 0 : cellCount; + } + + @Override + protected double _calculateAccuracy(List> dataRows, boolean cleanSlate) { + // Reset our counts if cleanSlate is TRUE + if( cleanSlate ) { + _resetCounts(); + } + + // Return 0% accuracy for null dataRows + if( dataRows == null ) { + return _getAccuracy(); + } + + // Iterate over data rows + for(IDataRow dataRow : dataRows) { + // If a row is null, skip it + if( dataRow == null ) { + continue; + } + + // Get the first classification and check if it is correct. NULL classifications are ignored. + List> classifications = bayesianSystem.classifyRow(dataRow); + if( classifications != null && classifications.size() > 0 ) { + IClassification classification = classifications.get(0); + // Increase the total rows calculated + ++total; + + // If the classification category matches the data row category, increase the correct count + if( classification != null && classification.getCategory().equals(dataRow.getCategory()) ) { + ++correct; + } + + // Add classification result to matrix + _updateMatrix(dataRow.getCategory(), classification.getCategory()); + } + } + + // Return the current accuracy using the thread-unsafe private method + return _getAccuracy(); + } + + /** + * Increase a cell's count on the matrix + * @param actualCategory The actual category (row) + * @param classifiedCategory The classified category (column) + */ + protected void _updateMatrix(T actualCategory, T classifiedCategory) { + // Get the row for the actual category + HashMap actualRow = matrix.get(actualCategory); + if( actualRow == null ) { + // If the row doesn't exist yet, create it + actualRow = new HashMap(); + matrix.put(actualCategory, actualRow); + } + + // Get the row cell for the classified category + Integer classifiedCell = actualRow.get(classifiedCategory); + if( classifiedCell == null ) { + // If it doesn't exist, create it (set it to 0) + classifiedCell = 0; + } + + // Increase the count for this cell + actualRow.put(classifiedCategory, classifiedCell + 1); + } + + @Override + protected void _resetCounts() { + super._resetCounts(); + // During construction, matrix is NULL + if( matrix != null ) { + matrix.clear(); + } + } +} diff --git a/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java new file mode 100644 index 0000000..0d90d8a --- /dev/null +++ b/src/test/java/com/nickferraro/bayesian/report/calc/core/AccuracyCalculatorTest.java @@ -0,0 +1,245 @@ +package com.nickferraro.bayesian.report.calc.core; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assume.assumeThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.security.InvalidParameterException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import com.nickferraro.bayesian.IBayesianSystem; +import com.nickferraro.bayesian.IClassification; +import com.nickferraro.bayesian.IDataRow; + +public class AccuracyCalculatorTest { + protected static final String CATEGORY_1 = "cat1"; + protected static final String CATEGORY_2 = "cat2"; + protected static final String CATEGORY_3 = "cat3"; + + protected AccuracyCalculator calculator; + protected IBayesianSystem mockSystem; + protected IDataRow mockRow1; + protected IDataRow mockRow2; + protected IDataRow mockRow3; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + mockSystem = mock(IBayesianSystem.class); + calculator = new AccuracyCalculator(mockSystem); + } + + @Test(expected = InvalidParameterException.class) + public void testConstructor_NullSystem() { + new AccuracyCalculator(null); + } + + @Test + public void testInitialState() { + assertEquals(calculator.getAccuracy(), 0.0, 0.0001); + assertThat(calculator.getCorrectCount(), is(0)); + assertThat(calculator.getIncorrectCount(), is(0)); + assertThat(calculator.getTotalCount(), is(0)); + } + + @Test + public void testCalculateAccuracy() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + double accuracy = calculator.calculateAccuracy(mockDataRows); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + } + + @Test + public void testCalculateAccuracy_NullList() { + double accuracy = calculator.calculateAccuracy(null); + assertEquals(0.0, accuracy, 0.0001); + assertEquals(0.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(0)); + assertThat(calculator.getIncorrectCount(), is(0)); + assertThat(calculator.getTotalCount(), is(0)); + } + + @Test + public void testCalculateAccuracy_EmptyList() { + List> emptyList = Collections.emptyList(); + double accuracy = calculator.calculateAccuracy(emptyList); + assertEquals(0.0, accuracy, 0.0001); + assertEquals(0.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(0)); + assertThat(calculator.getIncorrectCount(), is(0)); + assertThat(calculator.getTotalCount(), is(0)); + } + + @Test + public void testCalculateAccuracy_ListWithNull() { + List> mockDataRows = createMockDataRows(); + mockDataRows.add(1, null); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + double accuracy = calculator.calculateAccuracy(mockDataRows); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + } + + @Test + public void testResetCounts() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + double accuracy = calculator.calculateAccuracy(mockDataRows); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertThat(calculator.getAccuracy(), is(2.0/3.0)); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + + calculator.resetCounts(); + + assertEquals(0.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(0)); + assertThat(calculator.getIncorrectCount(), is(0)); + assertThat(calculator.getTotalCount(), is(0)); + } + + @Test + public void testCalculateAccuracy_AggregateCounts() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + calculator.calculateAccuracy(mockDataRows); + assumeThat(calculator.getCorrectCount(), is(2)); + assumeThat(calculator.getIncorrectCount(), is(1)); + assumeThat(calculator.getTotalCount(), is(3)); + + double accuracy = calculator.calculateAccuracy(mockDataRows, false); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(4)); + assertThat(calculator.getIncorrectCount(), is(2)); + assertThat(calculator.getTotalCount(), is(6)); + } + + @Test + public void testCalculateAccuracy_AggregateCounts_Null() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + calculator.calculateAccuracy(mockDataRows); + assumeThat(calculator.getCorrectCount(), is(2)); + assumeThat(calculator.getIncorrectCount(), is(1)); + assumeThat(calculator.getTotalCount(), is(3)); + + double accuracy = calculator.calculateAccuracy(null, false); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + } + + @Test + public void testCalculateAccuracy_AggregateCounts_Empty() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + calculator.calculateAccuracy(mockDataRows); + assumeThat(calculator.getCorrectCount(), is(2)); + assumeThat(calculator.getIncorrectCount(), is(1)); + assumeThat(calculator.getTotalCount(), is(3)); + + List> emptyDataRows = Collections.emptyList(); + double accuracy = calculator.calculateAccuracy(emptyDataRows, false); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertEquals(2.0/3.0, calculator.getAccuracy(), 0.0001); + assertThat(calculator.getCorrectCount(), is(2)); + assertThat(calculator.getIncorrectCount(), is(1)); + assertThat(calculator.getTotalCount(), is(3)); + } + + @SuppressWarnings("unchecked") + protected List> createMockDataRows() { + List> mockDataRows = new ArrayList>(); + mockRow1 = mock(IDataRow.class); + mockRow2 = mock(IDataRow.class); + mockRow3 = mock(IDataRow.class); + + when(mockRow1.getCategory()).thenReturn(CATEGORY_1); + when(mockRow2.getCategory()).thenReturn(CATEGORY_2); + when(mockRow3.getCategory()).thenReturn(CATEGORY_3); + + mockDataRows.add(mockRow1); + mockDataRows.add(mockRow2); + mockDataRows.add(mockRow3); + + return mockDataRows; + } + + protected List> createClassificationList(String category) { + List> classifications = new ArrayList>(); + + @SuppressWarnings("unchecked") + IClassification mockClassification1 = mock(IClassification.class); + classifications.add(mockClassification1); + when(mockClassification1.getCategory()).thenReturn(category); + + return classifications; + } +} diff --git a/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java b/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java new file mode 100644 index 0000000..583919a --- /dev/null +++ b/src/test/java/com/nickferraro/bayesian/report/calc/core/ConfusionMatrixTest.java @@ -0,0 +1,50 @@ +package com.nickferraro.bayesian.report.calc.core; + +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.when; + +import java.util.List; + +import org.junit.Before; +import org.junit.Test; + +import com.nickferraro.bayesian.IClassification; +import com.nickferraro.bayesian.IDataRow; + +public class ConfusionMatrixTest extends AccuracyCalculatorTest { + private ConfusionMatrix matrix; + + @Before + public void setup() { + super.setup(); + matrix = new ConfusionMatrix(mockSystem); + calculator = matrix; + } + + @Test + public void testGetCellCount() { + List> mockDataRows = createMockDataRows(); + + List> classificationList1 = createClassificationList(CATEGORY_1); + List> classificationList2 = createClassificationList(CATEGORY_1); + List> classificationList3 = createClassificationList(CATEGORY_3); + + when(mockSystem.classifyRow(mockRow1)).thenReturn(classificationList1); + when(mockSystem.classifyRow(mockRow2)).thenReturn(classificationList2); + when(mockSystem.classifyRow(mockRow3)).thenReturn(classificationList3); + + double accuracy = calculator.calculateAccuracy(mockDataRows); + assertEquals(2.0/3.0, accuracy, 0.0001); + assertThat(matrix.getCellCount(CATEGORY_1, CATEGORY_1), is(1)); + assertThat(matrix.getCellCount(CATEGORY_1, CATEGORY_2), is(0)); + assertThat(matrix.getCellCount(CATEGORY_1, CATEGORY_3), is(0)); + assertThat(matrix.getCellCount(CATEGORY_2, CATEGORY_1), is(1)); + assertThat(matrix.getCellCount(CATEGORY_2, CATEGORY_2), is(0)); + assertThat(matrix.getCellCount(CATEGORY_2, CATEGORY_3), is(0)); + assertThat(matrix.getCellCount(CATEGORY_3, CATEGORY_1), is(0)); + assertThat(matrix.getCellCount(CATEGORY_3, CATEGORY_2), is(0)); + assertThat(matrix.getCellCount(CATEGORY_3, CATEGORY_3), is(1)); + } +}