Skip to content

test: add unit tests for EmbeddingModel, GAN, metrics, and label processors#956

Open
sacredvoid wants to merge 1 commit intosunlabuiuc:masterfrom
sacredvoid:add-unit-tests-embedding-model-gan
Open

test: add unit tests for EmbeddingModel, GAN, metrics, and label processors#956
sacredvoid wants to merge 1 commit intosunlabuiuc:masterfrom
sacredvoid:add-unit-tests-embedding-model-gan

Conversation

@sacredvoid
Copy link
Copy Markdown

@sacredvoid sacredvoid commented Apr 8, 2026

Summary

Add 6 new test files with 73 test cases covering modules that previously had no unit tests: EmbeddingModel, GAN, binary_metrics_fn, multiclass_metrics_fn, regression_metrics_fn, BinaryLabelProcessor, and MultiClassLabelProcessor.

All tests use synthetic data, complete in under 5 seconds, and follow the established patterns from test_transformer.py / test_retain.py.

test_embedding_model.py (14 tests)

Test What it verifies
test_initialization Correct embedding_dim, layers created for each feature
test_embedding_layers_are_correct_type nn.Embedding for sequences, nn.Linear for tensors/multi_hot
test_forward_output_shapes Embedded output has correct batch and embedding dimensions
test_forward_with_output_mask output_mask=True returns masks dict alongside embeddings
test_gradients_flow Gradients propagate through embedding layers
Tensor/MultiHot/NestedSequence variants Same checks across all 4 processor types
test_mixed_forward Mixed sequence + tensor inputs work together
test_custom_embedding_dim Non-default embedding_dim produces correct output sizes

test_gan.py (14 tests)

Test What it verifies
test_initialization Model attributes and submodules created correctly
test_discriminator_output_shape Discriminator output is (batch, 1) for 32/64/128 inputs
test_discriminator_output_range Scores in [0, 1] (sigmoid output)
test_generate_fake_shape Generator produces correct spatial dimensions
test_generate_fake_pixel_range Pixel values in [0, 1]
test_sampling_shape Latent noise has shape (n, hidden_dim, 1, 1)
test_discriminator_backward / test_generator_backward Gradients flow through both networks
test_end_to_end / test_multichannel_end_to_end Generate then discriminate pipeline works

test_binary_metrics.py (10 tests)

Tests binary_metrics_fn with all 9 supported metrics (pr_auc, roc_auc, accuracy, balanced_accuracy, f1, precision, recall, cohen_kappa, jaccard), plus default behavior, custom thresholds, value range validation, perfect predictions, and error handling. One test is skipped documenting a bug in ece_confidence_binary (calibration.py:150) where 1D arrays are indexed as 2D.

test_multiclass_metrics.py (11 tests)

Tests multiclass_metrics_fn with ROC-AUC variants (macro/weighted, ovo/ovr), F1 variants (micro/macro/weighted), Jaccard variants, calibration metrics (ECE, brier_top1, cwECEt), cohen_kappa, hits@n, and mean_rank.

test_regression_metrics.py (9 tests)

Tests regression_metrics_fn with MSE, MAE, KL divergence, perfect reconstruction (zero error), shape mismatch error handling, and 2D array flattening.

test_label_processor.py (15 tests)

Tests BinaryLabelProcessor (fitting with int/bool/string labels, non-binary raises error, process output, size, schema) and MultiClassLabelProcessor (fitting, correct index mapping, size, schema).

Pattern

Follows the established test patterns:

  • Synthetic data via create_sample_dataset() or small numpy arrays
  • torch.manual_seed(42) / np.random.seed(42) for determinism
  • Standard unittest.TestCase structure
  • Google-style docstrings on every method

How to test

python -m pytest tests/core/test_embedding_model.py tests/core/test_gan.py tests/core/test_binary_metrics.py tests/core/test_multiclass_metrics.py tests/core/test_regression_metrics.py tests/core/test_label_processor.py -v

Ref #425

…essors

Add 6 new test files with 73 test cases covering previously untested
modules. All tests use synthetic data, complete in under 5 seconds,
and follow the established test patterns.

Models:
- test_embedding_model.py: EmbeddingModel with 5 processor types
  (sequence, tensor, multi_hot, nested_sequence, mixed)
- test_gan.py: GAN discriminator/generator for 32x32, 64x64, 128x128

Metrics:
- test_binary_metrics.py: binary_metrics_fn (9 metrics + calibration)
- test_multiclass_metrics.py: multiclass_metrics_fn (ROC-AUC, F1,
  Jaccard, ECE, hits@n, mean_rank)
- test_regression_metrics.py: regression_metrics_fn (MSE, MAE, KL)

Processors:
- test_label_processor.py: BinaryLabelProcessor, MultiClassLabelProcessor

Note: test_binary_metrics includes a skipped test documenting a bug in
ece_confidence_binary (calibration.py:150) where 1D arrays are indexed
as 2D.

Ref sunlabuiuc#425
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant