diff --git a/imagewect/wect.py b/imagewect/wect.py index f88a760..a28dd6d 100644 --- a/imagewect/wect.py +++ b/imagewect/wect.py @@ -72,30 +72,7 @@ def __get_maxfe_weight(s, weights): if d == 1: return max([weights[s[0]], weights[s[1]]]) return max([weights[s[0]], weights[s[1]], weights[s[2]]]) - -def __get_simplex_weight(s, weights, fe): - """ - Computes the weight of a simplex based on the function extension specified. - Args: - s (dionysus.Simplex): The simplex. - weights (np.ndarray): The weights array. - fe (str): The function extension to use for computing weights. - - Returns: - float: The weight of the simplex. - """ - if fe == "MAX": - return __get_maxfe_weight(s, weights) - elif fe == "MIN": - return __get_minfe_weight(s, weights) - elif fe == "AVG": - return __get_avgfe_weight(s, weights) - elif fe == "EC": - return 1 - else: - raise NotImplementedError(f"unknown function extension \"{fe}\"") - def __get_minfe_weight(s, weights): """ @@ -143,6 +120,56 @@ def __get_avgfe_weight(s, weights): return sum([weights[s[0]], weights[s[1]], weights[s[2]]]) / 3 +def __get_productfe_weight(s, weights): + """ + Computes the weight of a simplex using the Product function extension. + (i.e. the simplex weight is the product of the weight of each 0-dimensional component of the simplex) + + Args: + s (dionysus.Simplex): The simplex. + weights (np.ndarray): The weights array. + + Returns: + float: The weight of the simplex under the product extension. + """ + if __contains_zero_weights(s, weights): + return 0 + + d = s.dimension() + if d == 0: + return weights[s[0]] + elif d == 1: + return weights[s[0]] * weights[s[1]] + else: + return weights[s[0]] * weights[s[1]] * weights[s[2]] + + +def __get_simplex_weight(s, weights, fe): + """ + Computes the weight of a simplex based on the function extension specified. + + Args: + s (dionysus.Simplex): The simplex. + weights (np.ndarray): The weights array. + fe (str): The function extension to use for computing weights. + + Returns: + float: The weight of the simplex. + """ + if fe == "MAX": + return __get_maxfe_weight(s, weights) + elif fe == "MIN": + return __get_minfe_weight(s, weights) + elif fe == "AVG": + return __get_avgfe_weight(s, weights) + elif fe == "PRODUCT": + return __get_productfe_weight(s, weights) + elif fe == "EC": + return 1 + else: + raise NotImplementedError(f"unknown function extension \"{fe}\"") + + def vectorize_wect(wect, height_vals): """ Vectorizes the WECT by computing its values at specified height values.