Source code for test_aide.equality

"""
This module contains helper functions that simplify asserting equality for types
where it is not possible to simply assert a == b (e.g. pandas.DataFrame) or
nested data structures containing these types.

Note, if pandas or numpy is not available when the module is imported, the
behaviour of assert_equal_dispatch changes so that it does not try to check
types from whichever of those libraries is unavailable.

"""

try:

    import pandas as pd

    has_pandas = True

except ModuleNotFoundError:

    has_pandas = False

try:

    import numpy as np

    has_numpy = True

except ModuleNotFoundError:

    has_numpy = False


def assert_equal_dispatch(expected, actual, msg):
    """Assert that expected and actual are equal, dispatching on their type.

    Objects such as pandas.DataFrame or pandas.Series cannot be compared with
    a plain ``==`` inside an assert. This function picks the appropriate
    assert helper for the type of ``expected`` and recurses into containers
    that may hold such objects.

    The first check is that actual and expected are of exactly the same type.
    If this passes, the following types have specific assert functions;
    - pd.DataFrame (assert_frame_equal_msg)
    - pd.Series (assert_series_equal_msg)
    - pd.Index (assert_index_equal_msg)
    - float NaN values, detected with np.isnan (assert_np_nan_eqal_msg)
    - np.ndarray (assert_array_equal_msg)

    Subclasses of the pandas and numpy types above are dispatched in the same
    way as the base types.

    If the inputs are not of the above types then in the case of;
    - list or tuple (exact types only; assert_list_tuple_equal_msg)
    - dict (assert_dict_equal_msg)
    recursive calls are made on the elements of the object, again according
    to their types.

    Finally, any other object is compared with the standard equality assert
    (assert_equal_msg).

    Note, if pandas or numpy are not available then the types from those
    libraries are not considered, i.e. if neither is installed the function
    only uses the standard equality assertion, while still recursing into
    list, tuple and dict objects.

    Parameters
    ----------
    expected : object
        The expected object.

    actual : object
        The actual object.

    msg : string
        A message to be used in the assert, passed on to the specific assert
        equality function that is called.

    Raises
    ------
    TypeError
        If actual and expected are not of the same type.

    """
    if type(actual) is not type(expected):
        raise TypeError(
            f"expected ({type(expected)}) and actual ({type(actual)}) type mismatch"
        )

    # isinstance (rather than an exact type match) keeps subclasses of the
    # pandas / numpy containers away from the plain == fallback, where the
    # element-wise comparison result cannot be truth-tested.
    if has_pandas and isinstance(expected, pd.DataFrame):
        assert_frame_equal_msg(actual, expected, msg)
    elif has_pandas and isinstance(expected, pd.Series):
        assert_series_equal_msg(actual, expected, msg)
    elif has_pandas and isinstance(expected, pd.Index):
        assert_index_equal_msg(actual, expected, msg)
    elif has_numpy and isinstance(expected, float) and np.isnan(expected):
        # NaN != NaN, so the generic equality assert would always fail here.
        assert_np_nan_eqal_msg(actual, expected, msg)
    elif has_numpy and isinstance(expected, np.ndarray):
        assert_array_equal_msg(actual, expected, msg)
    elif type(expected) in (list, tuple):
        assert_list_tuple_equal_msg(actual, expected, msg)
    elif isinstance(expected, dict):
        assert_dict_equal_msg(actual, expected, msg)
    else:
        assert_equal_msg(actual, expected, msg)
def assert_equal_msg(actual, expected, msg_tag):
    """Compare actual and expected with ``==`` and raise if they differ.

    The failure message contains msg_tag followed by the expected and actual
    values.

    Parameters
    ----------
    actual : object
        The actual object.

    expected : object
        The expected object.

    msg_tag : string
        A tag for the AssertionError message. Use this to identify
        mismatching arguments in test output.

    Raises
    ------
    AssertionError
        If ``actual == expected`` is not truthy.

    """
    if actual == expected:
        return

    # Raised explicitly instead of using an assert statement so the check is
    # not stripped when Python runs with -O; the message is only built when
    # the comparison has actually failed.
    raise AssertionError(f"{msg_tag} -\n Expected: {expected}\n Actual: {actual}")
def assert_np_nan_eqal_msg(actual, expected, msg):
    """Check that both actual and expected are NaN, according to np.isnan.

    Parameters
    ----------
    actual : object
        The actual object. Must be a numeric type for np.isnan to run.

    expected : object
        The expected object. Must be a numeric type for np.isnan to run.

    msg : string
        A tag for the AssertionError message; it is placed at the start of
        the message.

    Raises
    ------
    AssertionError
        If either value is not NaN.

    """
    if np.isnan(actual) and np.isnan(expected):
        return

    # msg is prefixed so the failing item can be located in nested
    # comparisons; the remainder of the message is unchanged.
    raise AssertionError(
        f"{msg} -\n Both values are not equal to np.NaN -\n Expected: {expected}\n Actual: {actual}"
    )
def assert_list_tuple_equal_msg(actual, expected, msg_tag):
    """Compare actual and expected list or tuple objects and assert equality.

    Checks that actual and expected are each exactly a list or a tuple, that
    they are the same type and that they have equal length, then loops
    through the pairwise elements and calls assert_equal_dispatch on each
    pair, tagging the message with the index of the pair.

    Parameters
    ----------
    actual : list or tuple
        The actual list or tuple to compare.

    expected : list or tuple
        The expected list or tuple to compare to actual.

    msg_tag : string
        A tag for the AssertionError message.

    Raises
    ------
    TypeError
        If expected or actual is not a list or tuple, or if their types
        differ.

    AssertionError
        If expected and actual have different lengths.

    """
    if type(expected) not in (list, tuple):
        raise TypeError(
            f"expected should be of type list or tuple, but got {type(expected)}"
        )

    if type(actual) not in (list, tuple):
        raise TypeError(
            f"actual should be of type list or tuple, but got {type(actual)}"
        )

    if type(actual) is not type(expected):
        raise TypeError(
            f"expect ({type(expected)}) and actual ({type(actual)}) type mismatch"
        )

    # Raised explicitly (not via an assert statement) so the check survives
    # python -O. msg_tag is prefixed so that a length mismatch in a nested
    # container can be located.
    if len(expected) != len(actual):
        raise AssertionError(
            f"{msg_tag} -\n Unequal lengths -\n Expected: {len(expected)}\n Actual: {len(actual)}"
        )

    for i, (e, a) in enumerate(zip(expected, actual)):
        assert_equal_dispatch(e, a, f"{msg_tag} index {i}")
def assert_dict_equal_msg(actual, expected, msg_tag):
    """Compare actual and expected dict objects and assert equality.

    Checks that actual and expected are both dicts, that they have the same
    number of keys and that every key in expected is also in actual, then
    loops through the keys and calls assert_equal_dispatch on the pair of
    values for each key, tagging the message with the key.

    Parameters
    ----------
    actual : dict
        The actual dict to compare.

    expected : dict
        The expected dict to compare to actual.

    msg_tag : string
        A tag for the AssertionError message.

    Raises
    ------
    TypeError
        If expected or actual is not a dict.

    AssertionError
        If the number of keys differs, or if the sets of keys differ.

    """
    if not isinstance(expected, dict):
        raise TypeError(f"expected should be of type dict, but got {type(expected)}")

    if not isinstance(actual, dict):
        raise TypeError(f"actual should be of type dict, but got {type(actual)}")

    # Checks are raised explicitly (not via assert statements) so they
    # survive python -O; msg_tag is prefixed so that a mismatch in a nested
    # dict can be located.
    if len(expected) != len(actual):
        raise AssertionError(
            f"{msg_tag} -\n Unequal number of keys -\n Expected: {len(expected)}\n Actual: {len(actual)}"
        )

    keys_diff_e_a = set(expected) - set(actual)
    keys_diff_a_e = set(actual) - set(expected)

    # With equal key counts, expected having no extra keys implies the two
    # key sets are identical.
    if keys_diff_e_a:
        raise AssertionError(
            f"{msg_tag} -\n Keys in expected not in actual: {keys_diff_e_a}\nKeys in actual not in expected: {keys_diff_a_e}"
        )

    for k in actual:
        assert_equal_dispatch(expected[k], actual[k], f"{msg_tag} key {k}")
def assert_frame_equal_msg(
    actual, expected, msg_tag, print_actual_and_expected=False, **kwargs
):
    """Assert that two pandas.DataFrames are equal, reporting msg_tag on failure.

    Delegates to pd.testing.assert_frame_equal. Any exception it raises is
    re-raised as an AssertionError whose message is msg_tag, optionally
    followed by the expected and actual DataFrames; the original exception is
    chained as the cause.

    Parameters
    ----------
    actual : pandas DataFrame
        The actual dataframe.

    expected : pandas DataFrame
        The expected dataframe.

    msg_tag : string
        A tag for the assert error message.

    print_actual_and_expected : Boolean
        If True, the expected and actual DataFrames are appended to the error
        message after msg_tag.

    **kwargs:
        Keyword args passed to pd.testing.assert_frame_equal.

    """
    try:
        pd.testing.assert_frame_equal(expected, actual, **kwargs)
    except Exception as err:
        failure_text = (
            f"""{msg_tag}\nexpected:\n{expected}\nactual:\n{actual}"""
            if print_actual_and_expected
            else msg_tag
        )
        raise AssertionError(failure_text) from err
def assert_series_equal_msg(
    actual, expected, msg_tag, print_actual_and_expected=False, **kwargs
):
    """Assert that two pandas.Series are equal, reporting msg_tag on failure.

    Delegates to pd.testing.assert_series_equal. Any exception it raises is
    re-raised as an AssertionError whose message is msg_tag, optionally
    followed by the expected and actual Series; the original exception is
    chained as the cause.

    Parameters
    ----------
    actual : pandas Series
        The actual Series.

    expected : pandas Series
        The expected Series.

    msg_tag : string
        A tag for the assert error message.

    print_actual_and_expected : Boolean
        If True, the expected and actual Series are appended to the error
        message after msg_tag.

    **kwargs:
        Keyword args passed to pd.testing.assert_series_equal.

    """
    try:
        pd.testing.assert_series_equal(expected, actual, **kwargs)
    except Exception as err:
        failure_text = (
            f"""{msg_tag}\nexpected:\n{expected}\nactual:\n{actual}"""
            if print_actual_and_expected
            else msg_tag
        )
        raise AssertionError(failure_text) from err
def assert_index_equal_msg(
    actual, expected, msg_tag, print_actual_and_expected=False, **kwargs
):
    """Assert that two pandas.Index objects are equal, reporting msg_tag on failure.

    Delegates to pd.testing.assert_index_equal. Any exception it raises is
    re-raised as an AssertionError whose message is msg_tag, optionally
    followed by the expected and actual Index objects; the original exception
    is chained as the cause.

    Parameters
    ----------
    actual : pd.Index
        The actual index.

    expected : pd.Index
        The expected index.

    msg_tag : string
        A tag for the assert error message.

    print_actual_and_expected : Boolean
        If True, the expected and actual values are appended to the error
        message after msg_tag.

    **kwargs:
        Keyword args passed to pd.testing.assert_index_equal.

    """
    try:
        pd.testing.assert_index_equal(expected, actual, **kwargs)
    except Exception as err:
        failure_text = (
            f"""{msg_tag}\nexpected:\n{expected}\nactual:\n{actual}"""
            if print_actual_and_expected
            else msg_tag
        )
        raise AssertionError(failure_text) from err
def assert_array_equal_msg(
    actual, expected, msg_tag, print_actual_and_expected=False, **kwargs
):
    """Assert that two numpy arrays are equal, reporting msg_tag on failure.

    Delegates to np.testing.assert_array_equal. Any exception it raises is
    re-raised as an AssertionError whose message is msg_tag, optionally
    followed by the expected and actual arrays; the original exception is
    chained as the cause.

    Parameters
    ----------
    actual : numpy array
        The actual array.

    expected : numpy array
        The expected array.

    msg_tag : string
        A tag for the assert error message.

    print_actual_and_expected : Boolean
        If True, the expected and actual arrays are appended to the error
        message after msg_tag.

    **kwargs:
        Keyword args passed to np.testing.assert_array_equal.

    Raises
    ------
    TypeError
        If expected or actual is not a numpy ndarray.

    """
    # numpy compares a scalar against every entry of the other array, so both
    # inputs are required to be arrays before comparing; expected is checked
    # first.
    for label, candidate in (("expected", expected), ("actual", actual)):
        if not isinstance(candidate, np.ndarray):
            raise TypeError(
                f"{label} should be of type numpy ndarray, but got {type(candidate)}"
            )

    try:
        np.testing.assert_array_equal(expected, actual, **kwargs)
    except Exception as err:
        failure_text = (
            f"""{msg_tag}\nexpected:\n{expected}\nactual:\n{actual}"""
            if print_actual_and_expected
            else msg_tag
        )
        raise AssertionError(failure_text) from err