diff --git a/prnn/utils/predictiveNet.py b/prnn/utils/predictiveNet.py index acbd6cc7..8d9c388f 100644 --- a/prnn/utils/predictiveNet.py +++ b/prnn/utils/predictiveNet.py @@ -837,16 +837,17 @@ def calculateSpatialRepresentation( height = env.height nb_bins_x, nb_bins_y, minmax = env.get_map_bins() - place_fields, xy = nap.compute_2d_tuning_curves_continuous( - rates, position, ep=rates.time_support, nb_bins=(nb_bins_x, nb_bins_y), minmax=minmax - ) - SI = nap.compute_2d_mutual_info( - place_fields, position, position.time_support, bitssec=bitsec - ) - # Remove units that aren't active in enough timepoints - numactiveT = np.sum((h > 0).numpy(), axis=1) - inactive_cells = numactiveT < activeTimeThreshold - SI.iloc[inactive_cells.flatten()] = 0 + place_fields,xy = nap.compute_2d_tuning_curves_continuous(rates,position, + ep=rates.time_support, + nb_bins=(nb_bins_x, nb_bins_y), + minmax=minmax + ) + SI = nap.compute_2d_mutual_info(place_fields, position, position.time_support, + minmax=minmax, bitssec=bitsec) + #Remove units that aren't active in enough timepoints + numactiveT = np.sum((h>0).numpy(),axis=1) + inactive_cells = numactiveT