diff --git a/datacompy/spark.py b/datacompy/spark.py
index 9df8852f..7769d170 100644
--- a/datacompy/spark.py
+++ b/datacompy/spark.py
@@ -249,6 +249,20 @@ def _validate_dataframe(
             self._df1 = dataframe.toDF(*[str(c).lower() for c in dataframe.columns])
         if index == "df2":
             self._df2 = dataframe.toDF(*[str(c).lower() for c in dataframe.columns])
+        else:
+            # Don't allow case sensitive columns
+            lower_cols = [c.lower() for c in dataframe.columns]
+            if len(set(lower_cols)) < len(lower_cols):
+                dupes = {
+                    c for c in dataframe.columns if lower_cols.count(c.lower()) > 1
+                }
+                raise ValueError(
+                    f"{index} has columns that differ only by case: {dupes}. "
+                    "Spark strongly discourages use of case sensitive column names. "
+                    "Rename columns to be unique regardless of case. "
+                    "See: https://spark.apache.org/docs/latest/api/python/tutorial/"
+                    "pandas_on_spark/best_practices.html#do-not-use-duplicated-column-names"
+                )
 
         # Check if join_columns are present in the dataframe
         dataframe = getattr(self, index)  # refresh
diff --git a/tests/test_spark.py b/tests/test_spark.py
index b0ae5bad..fa1973fc 100644
--- a/tests/test_spark.py
+++ b/tests/test_spark.py
@@ -2433,3 +2433,29 @@ def test_columns_with_mismatches_multiple_join_columns(spark_session):
     assert "id1" not in result
     assert "id2" not in result
     assert sorted(result) == ["value1", "value2"]
+
+
+def test_forbid_case_sensitive_columns(spark_session):
+    """Test error case for case sensitive columns in dataframes."""
+    df1 = spark_session.createDataFrame(
+        [{"a": 1, "b": 2, "B": 1}, {"a": 3, "b": 1, "B": 0}]
+    )
+    df2 = spark_session.createDataFrame(
+        [{"a": 1, "b": 2, "B": 2}, {"a": 2, "b": 0, "B": 0}]
+    )
+
+    with pytest.raises(
+        ValueError,
+        match=r"df1 has columns that differ only by case: \{(?:'b', 'B'|'B', 'b')\}. "
+        "Spark strongly discourages use of case sensitive column names. "
+        "Rename columns to be unique regardless of case. "
+        "See: https://spark.apache.org/docs/latest/api/python/tutorial/"
+        "pandas_on_spark/best_practices.html#do-not-use-duplicated-column-names",
+    ):
+        SparkSQLCompare(
+            spark_session,
+            df1,
+            df2,
+            join_columns=["a"],
+            cast_column_names_lower=False,
+        )