From d8928d38593288a535964e9613486ed5f01fb8df Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Wed, 18 Sep 2024 11:10:47 +0200 Subject: [PATCH 01/19] Update LibraryBaseChecker implementation --- README.rst | 2 +- .../matplotlib/matplotlib_parameter.py | 10 ++--- pylint_ml/checkers/numpy/numpy_dot.py | 7 +-- .../checkers/numpy/numpy_nan_comparison.py | 8 +++- pylint_ml/checkers/numpy/numpy_parameter.py | 9 +++- .../checkers/pandas/pandas_dataframe_bool.py | 9 ++-- pylint_ml/util/config.py | 10 +++++ pylint_ml/util/library_handler.py | 45 ++++++++++++------- .../test_numpy/test_numpy_nan_comparison.py | 2 + .../test_numpy/test_numpy_parameter.py | 11 ++--- 10 files changed, 75 insertions(+), 38 deletions(-) create mode 100644 pylint_ml/util/config.py diff --git a/README.rst b/README.rst index a6bce47..8a4bdb3 100644 --- a/README.rst +++ b/README.rst @@ -3,4 +3,4 @@ pylint-ml About ----- -``pylint-ml`` is a pylint plugin for enhancing code analysis for machine learning and data science +``pylint-ml`` is a pylint plugin for enhancing code analysis for machine learning and data science projects diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index d19fbfe..6fc66d6 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -8,10 +8,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.library_handler import LibraryHandler +from pylint_ml.util.config import LIB_MATPLOTLIB +from pylint_ml.util.library_handler import LibraryBaseChecker -class MatplotlibParameterChecker(LibraryHandler): +class MatplotlibParameterChecker(LibraryBaseChecker): name = "matplotlib-parameter" msgs = { "W8111": ( @@ -47,9 +48,8 @@ class MatplotlibParameterChecker(LibraryHandler): @only_required_for_messages("matplotlib-parameter") def visit_call(self, node: nodes.Call) -> None: - # 
TODO Update - # if not self.is_library_imported('matplotlib') and self.is_library_version_valid(lib_version=): - # return + if not self.is_library_imported_and_version_valid(lib_name=LIB_MATPLOTLIB, required_version=None): + return method_name = self._get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: diff --git a/pylint_ml/checkers/numpy/numpy_dot.py b/pylint_ml/checkers/numpy/numpy_dot.py index 955656f..6620a0c 100644 --- a/pylint_ml/checkers/numpy/numpy_dot.py +++ b/pylint_ml/checkers/numpy/numpy_dot.py @@ -10,10 +10,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.library_handler import LibraryHandler +from pylint_ml.util.config import LIB_NUMPY +from pylint_ml.util.library_handler import LibraryBaseChecker -class NumpyDotChecker(LibraryHandler): +class NumpyDotChecker(LibraryBaseChecker): name = "numpy-dot-checker" msgs = { "W8122": ( @@ -29,7 +30,7 @@ def visit_import(self, node: nodes.Import): @only_required_for_messages("numpy-dot-usage") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported("numpy"): + if not self.is_library_imported_and_version_valid(lib_name=LIB_NUMPY, required_version=None): return # Check if the function being called is np.dot diff --git a/pylint_ml/checkers/numpy/numpy_nan_comparison.py b/pylint_ml/checkers/numpy/numpy_nan_comparison.py index 4a1a7ad..3b6bf71 100644 --- a/pylint_ml/checkers/numpy/numpy_nan_comparison.py +++ b/pylint_ml/checkers/numpy/numpy_nan_comparison.py @@ -7,15 +7,17 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_NUMPY +from pylint_ml.util.library_handler import LibraryBaseChecker + COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "==")) NUMPY_NAN = frozenset(("nan", "NaN", "NAN")) -class 
NumpyNaNComparisonChecker(BaseChecker): +class NumpyNaNComparisonChecker(LibraryBaseChecker): name = "numpy-nan-compare" msgs = { "W8001": ( @@ -32,6 +34,8 @@ def __is_np_nan_call(cls, node: nodes.Attribute) -> bool: @only_required_for_messages("numpy-nan-compare") def visit_compare(self, node: nodes.Compare) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_NUMPY, required_version=None): + return if isinstance(node.left, nodes.Attribute) and self.__is_np_nan_call(node.left): self.add_message("numpy-nan-compare", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index e045e83..ac98852 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -5,12 +5,14 @@ """Check for proper usage of numpy functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_NUMPY +from pylint_ml.util.library_handler import LibraryBaseChecker -class NumPyParameterChecker(BaseChecker): + +class NumPyParameterChecker(LibraryBaseChecker): name = "numpy-parameter" msgs = { "W8111": ( @@ -71,6 +73,9 @@ class NumPyParameterChecker(BaseChecker): @only_required_for_messages("numpy-parameter") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_NUMPY, required_version=None): + return + method_name = self._get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py index a519ca8..7183d7b 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py @@ -7,14 +7,14 @@ from __future__ import annotations from astroid import nodes 
-from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -# Todo add version deprecated +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_handler import LibraryBaseChecker -class PandasDataFrameBoolChecker(BaseChecker): +class PandasDataFrameBoolChecker(LibraryBaseChecker): name = "pandas-dataframe-bool" msgs = { "W8104": ( @@ -26,6 +26,9 @@ class PandasDataFrameBoolChecker(BaseChecker): @only_required_for_messages("pandas-dataframe-bool") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version="2.1.0"): + return + if isinstance(node.func, nodes.Attribute): method_name = getattr(node.func, "attrname", None) diff --git a/pylint_ml/util/config.py b/pylint_ml/util/config.py new file mode 100644 index 0000000..59caa7b --- /dev/null +++ b/pylint_ml/util/config.py @@ -0,0 +1,10 @@ + + +# Library names +LIB_PANDAS = "pandas" +LIB_NUMPY = "numpy" +LIB_TENSORFLOW = "tensorflow" +LIB_SCIPY = "scipy" +LIB_SKLEARN = "sklearn" +LIB_PYTORCH = "torch" +LIB_MATPLOTLIB = "matplotlib" diff --git a/pylint_ml/util/library_handler.py b/pylint_ml/util/library_handler.py index 2d54203..a5d1498 100644 --- a/pylint_ml/util/library_handler.py +++ b/pylint_ml/util/library_handler.py @@ -1,7 +1,9 @@ +from importlib.metadata import PackageNotFoundError, version + from pylint.checkers import BaseChecker -class LibraryHandler(BaseChecker): +class LibraryBaseChecker(BaseChecker): def __init__(self, linter): super().__init__(linter) @@ -11,24 +13,36 @@ def visit_import(self, node): for name, alias in node.names: self.imports[alias or name] = name - def visit_importfrom( - self, - node, - ): - # TODO Update method to handle either: - # 1. Check of specific method-name imported? - # 2. Store all method names importfrom libname?
- + def visit_importfrom(self, node): module = node.modname for name, alias in node.names: full_name = f"{module}.{name}" self.imports[alias or name] = full_name - def is_library_imported(self, library_name): - return any(mod.startswith(library_name) for mod in self.imports.values()) - - # def is_library_version_valid(self, lib_version): - # # TODO update solution - # if lib_version is None: - # pass - # return + def is_library_imported_and_version_valid(self, lib_name, required_version): + """ + Checks if the library is imported and whether the installed version is valid (greater than or equal to the + required version). + + param lib_name: Name of the library (as a string). + param required_version: The required minimum version (as a string). + return: True if the library is imported and the version is valid, otherwise False. + """ + # Check if the library is imported + if not any(mod.startswith(lib_name) for mod in self.imports.values()): + return False + + # Check if the library version is valid + try: + installed_version = version(lib_name) + except PackageNotFoundError: + return False + + # Compare numeric release segments (e.g. '1.2.3'); raw string comparison would rank '10.0.0' below '2.1.0' + if required_version is not None: + installed_release = [*(int(p) for p in installed_version.split(".")[:3] if p.isdigit()), 0, 0, 0][:3] + required_release = [*(int(p) for p in required_version.split(".")[:3] if p.isdigit()), 0, 0, 0][:3] + if installed_release < required_release: + return False + + return True diff --git a/tests/checkers/test_numpy/test_numpy_nan_comparison.py b/tests/checkers/test_numpy/test_numpy_nan_comparison.py index 6191d2e..5f9f2a7 100644 --- a/tests/checkers/test_numpy/test_numpy_nan_comparison.py +++ b/tests/checkers/test_numpy/test_numpy_nan_comparison.py @@ -11,6 +11,8 @@ class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase): def test_singleton_nan_compare(self): singleton_node, chained_node, great_than_node = astroid.extract_node( """ + import numpy as np + a_nan = np.array([0, 1, np.nan]) np.nan == a_nan #@ 1 == 1 == np.nan #@ diff --git a/tests/checkers/test_numpy/test_numpy_parameter.py b/tests/checkers/test_numpy/test_numpy_parameter.py index
40b8d52..d600bac 100644 --- a/tests/checkers/test_numpy/test_numpy_parameter.py +++ b/tests/checkers/test_numpy/test_numpy_parameter.py @@ -9,25 +9,26 @@ class TestNumPyParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumPyParameterChecker def test_array_missing_object(self): - node = astroid.extract_node( + import_node, call_node = astroid.extract_node( """ - import numpy as np + import numpy as np #@ arr = np.array() #@ """ ) - array_call = node.value + call_node = call_node.value with self.assertAddsMessages( pylint.testutils.MessageTest( msg_id="numpy-parameter", confidence=HIGH, - node=array_call, + node=call_node, args=("object", "array"), ), ignore_position=True, ): - self.checker.visit_call(array_call) + self.checker.visit_import(import_node) + self.checker.visit_call(call_node) def test_zeros_without_shape(self): node = astroid.extract_node( From 54895ff6219e3db4955682d47e1e558d79e2b42a Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Thu, 19 Sep 2024 15:22:31 +0200 Subject: [PATCH 02/19] Update tests --- .../matplotlib/matplotlib_parameter.py | 2 +- pylint_ml/checkers/numpy/numpy_dot.py | 2 +- .../checkers/numpy/numpy_nan_comparison.py | 2 +- pylint_ml/checkers/numpy/numpy_parameter.py | 2 +- .../checkers/pandas/pandas_dataframe_bool.py | 2 +- .../pandas_dataframe_column_selection.py | 10 ++- .../pandas/pandas_dataframe_empty_column.py | 8 ++- .../pandas/pandas_dataframe_iterrows.py | 8 ++- .../pandas/pandas_dataframe_naming.py | 8 ++- .../pandas/pandas_dataframe_values.py | 8 ++- pylint_ml/checkers/pandas/pandas_inplace.py | 8 ++- pylint_ml/checkers/pandas/pandas_parameter.py | 9 ++- .../checkers/pandas/pandas_series_bool.py | 7 +- .../checkers/pandas/pandas_series_naming.py | 8 ++- pylint_ml/checkers/scipy/scipy_parameter.py | 8 ++- .../checkers/sklearn/sklearn_parameter.py | 8 ++- .../checkers/tensorflow/tensor_parameter.py | 8 ++- pylint_ml/checkers/torch/torch_parameter.py | 8 ++- ...ary_handler.py => library_base_checker.py} | 
10 +++ .../test_numpy/test_numpy_parameter.py | 26 ++++--- .../pandas_dataframe_column_selection.py | 11 +-- .../test_pandas/test_pandas_dataframe_bool.py | 14 ++-- .../test_pandas_dataframe_empty_column.py | 16 +++-- .../test_pandas_dataframe_iterrows.py | 5 +- .../test_pandas_dataframe_naming.py | 15 ++-- .../test_pandas_dataframe_values.py | 5 +- .../test_pandas/test_pandas_inplace.py | 35 ++++++---- .../test_pandas/test_pandas_parameter.py | 70 +++++++++++-------- .../test_pandas/test_pandas_series_bool.py | 14 ++-- .../test_pandas/test_pandas_series_naming.py | 21 +++--- .../test_scipy/test_scipy_parameter.py | 35 ++++++---- 31 files changed, 260 insertions(+), 133 deletions(-) rename pylint_ml/util/{library_handler.py => library_base_checker.py} (89%) diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index 6fc66d6..40113c9 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.config import LIB_MATPLOTLIB -from pylint_ml.util.library_handler import LibraryBaseChecker +from pylint_ml.util.library_base_checker import LibraryBaseChecker class MatplotlibParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/numpy/numpy_dot.py b/pylint_ml/checkers/numpy/numpy_dot.py index 6620a0c..e32e96b 100644 --- a/pylint_ml/checkers/numpy/numpy_dot.py +++ b/pylint_ml/checkers/numpy/numpy_dot.py @@ -11,7 +11,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.config import LIB_NUMPY -from pylint_ml.util.library_handler import LibraryBaseChecker +from pylint_ml.util.library_base_checker import LibraryBaseChecker class NumpyDotChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/numpy/numpy_nan_comparison.py b/pylint_ml/checkers/numpy/numpy_nan_comparison.py index 3b6bf71..6f9d330 100644 --- 
a/pylint_ml/checkers/numpy/numpy_nan_comparison.py +++ b/pylint_ml/checkers/numpy/numpy_nan_comparison.py @@ -11,7 +11,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.config import LIB_NUMPY -from pylint_ml.util.library_handler import LibraryBaseChecker +from pylint_ml.util.library_base_checker import LibraryBaseChecker COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "==")) NUMPY_NAN = frozenset(("nan", "NaN", "NAN")) diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index ac98852..a3921de 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.config import LIB_NUMPY -from pylint_ml.util.library_handler import LibraryBaseChecker +from pylint_ml.util.library_base_checker import LibraryBaseChecker class NumPyParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py index 7183d7b..cf0c2ae 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py @@ -11,7 +11,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.config import LIB_PANDAS -from pylint_ml.util.library_handler import LibraryBaseChecker +from pylint_ml.util.library_base_checker import LibraryBaseChecker class PandasDataFrameBoolChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py index 3f3e480..f34be18 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py @@ -7,12 +7,14 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from 
pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasColumnSelectionChecker(BaseChecker): + +class PandasColumnSelectionChecker(LibraryBaseChecker): name = "pandas-column-selection" msgs = { "W8118": ( @@ -25,6 +27,10 @@ class PandasColumnSelectionChecker(BaseChecker): @only_required_for_messages("pandas-column-selection") def visit_attribute(self, node: nodes.Attribute) -> None: """Check for attribute access that might be a column selection.""" + + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + return + if isinstance(node.expr, nodes.Name) and node.expr.name.startswith("df_"): # Issue a warning for property-like access self.add_message("pandas-column-selection", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py index 3427f1b..db82a34 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py @@ -11,8 +11,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasEmptyColumnChecker(BaseChecker): + +class PandasEmptyColumnChecker(LibraryBaseChecker): name = "pandas-dataframe-empty-column" msgs = { "W8113": ( @@ -25,6 +28,9 @@ class PandasEmptyColumnChecker(BaseChecker): @only_required_for_messages("pandas-dataframe-empty-column") def visit_subscript(self, node: nodes.Subscript) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + return + if isinstance(node.value, nodes.Name) and node.value.name.startswith("df_"): if isinstance(node.slice, nodes.Const) and isinstance(node.parent, nodes.Assign): if 
isinstance(node.parent.value, nodes.Const): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py index 99d2b0c..b1c24a3 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py @@ -11,8 +11,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasIterrowsChecker(BaseChecker): + +class PandasIterrowsChecker(LibraryBaseChecker): name = "pandas-iterrows" msgs = { "W8106": ( @@ -25,6 +28,9 @@ class PandasIterrowsChecker(BaseChecker): @only_required_for_messages("pandas-iterrows") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + return + if isinstance(node.func, nodes.Attribute): method_name = getattr(node.func, "attrname", None) if method_name == "iterrows": diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py index a0aaf2d..fe08e3b 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py @@ -11,8 +11,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasDataFrameNamingChecker(BaseChecker): + +class PandasDataFrameNamingChecker(LibraryBaseChecker): name = "pandas-dataframe-naming" msgs = { "W8103": ( @@ -24,6 +27,9 @@ class PandasDataFrameNamingChecker(BaseChecker): @only_required_for_messages("pandas-dataframe-naming") def visit_assign(self, node: nodes.Assign) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, 
required_version=None): + return + if isinstance(node.value, nodes.Call): func_name = getattr(node.value.func, "attrname", None) module_name = getattr(node.value.func.expr, "name", None) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_values.py b/pylint_ml/checkers/pandas/pandas_dataframe_values.py index 13b382f..689daf5 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_values.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_values.py @@ -11,8 +11,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasValuesChecker(BaseChecker): + +class PandasValuesChecker(LibraryBaseChecker): name = "pandas-dataframe-values" msgs = { "W8112": ( @@ -25,6 +28,9 @@ class PandasValuesChecker(BaseChecker): @only_required_for_messages("pandas-dataframe-values") def visit_attribute(self, node: nodes.Attribute) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + return + if isinstance(node.expr, nodes.Name): if node.attrname == "values" and node.expr.name.startswith("df_"): self.add_message("pandas-dataframe-values", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/pandas/pandas_inplace.py b/pylint_ml/checkers/pandas/pandas_inplace.py index c1eded6..9426065 100644 --- a/pylint_ml/checkers/pandas/pandas_inplace.py +++ b/pylint_ml/checkers/pandas/pandas_inplace.py @@ -11,8 +11,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasInplaceChecker(BaseChecker): + +class PandasInplaceChecker(LibraryBaseChecker): name = "pandas-inplace" msgs = { "W8109": ( @@ -39,6 +42,9 @@ class PandasInplaceChecker(BaseChecker): @only_required_for_messages("pandas-inplace") def 
visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + return + # Check if the call is to a method that supports 'inplace' if isinstance(node.func, nodes.Attribute): method_name = node.func.attrname diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index 3efee90..c6e629d 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -5,12 +5,14 @@ """Check for proper usage of Pandas functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasParameterChecker(BaseChecker): + +class PandasParameterChecker(LibraryBaseChecker): name = "pandas-parameter" msgs = { "W8111": ( @@ -64,6 +66,9 @@ class PandasParameterChecker(BaseChecker): @only_required_for_messages("pandas-parameter") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + return + method_name = self._get_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} diff --git a/pylint_ml/checkers/pandas/pandas_series_bool.py b/pylint_ml/checkers/pandas/pandas_series_bool.py index dafac68..d42d5d2 100644 --- a/pylint_ml/checkers/pandas/pandas_series_bool.py +++ b/pylint_ml/checkers/pandas/pandas_series_bool.py @@ -12,9 +12,11 @@ from pylint.interfaces import HIGH # Todo add version deprecated +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasSeriesBoolChecker(BaseChecker): +class 
PandasSeriesBoolChecker(LibraryBaseChecker): name = "pandas-series-bool" msgs = { "W8105": ( @@ -26,6 +28,9 @@ class PandasSeriesBoolChecker(BaseChecker): @only_required_for_messages("pandas-series-bool") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + return + if isinstance(node.func, nodes.Attribute): method_name = getattr(node.func, "attrname", None) diff --git a/pylint_ml/checkers/pandas/pandas_series_naming.py b/pylint_ml/checkers/pandas/pandas_series_naming.py index 8e5e3e2..780a426 100644 --- a/pylint_ml/checkers/pandas/pandas_series_naming.py +++ b/pylint_ml/checkers/pandas/pandas_series_naming.py @@ -11,8 +11,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PandasSeriesNamingChecker(BaseChecker): + +class PandasSeriesNamingChecker(LibraryBaseChecker): name = "pandas-series-naming" msgs = { "W8103": ( @@ -24,6 +27,9 @@ class PandasSeriesNamingChecker(BaseChecker): @only_required_for_messages("pandas-series-naming") def visit_assign(self, node: nodes.Assign) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + return + if isinstance(node.value, nodes.Call): func_name = getattr(node.value.func, "attrname", None) module_name = getattr(node.value.func.expr, "name", None) diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index 9b9464e..7af4a08 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -9,8 +9,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_SCIPY +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class 
ScipyParameterChecker(BaseChecker): + +class ScipyParameterChecker(LibraryBaseChecker): name = "scipy-parameter" msgs = { "W8111": ( @@ -42,6 +45,9 @@ class ScipyParameterChecker(BaseChecker): @only_required_for_messages("scipy-parameter") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_SCIPY, required_version=None): + return + method_name = self._get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} diff --git a/pylint_ml/checkers/sklearn/sklearn_parameter.py b/pylint_ml/checkers/sklearn/sklearn_parameter.py index c9ef152..8acd01d 100644 --- a/pylint_ml/checkers/sklearn/sklearn_parameter.py +++ b/pylint_ml/checkers/sklearn/sklearn_parameter.py @@ -9,8 +9,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_SKLEARN +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class SklearnParameterChecker(BaseChecker): + +class SklearnParameterChecker(LibraryBaseChecker): name = "sklearn-parameter" msgs = { "W8111": ( @@ -37,6 +40,9 @@ class SklearnParameterChecker(BaseChecker): @only_required_for_messages("sklearn-parameter") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_SKLEARN, required_version=None): + return + method_name = self._get_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index c9a334f..8adb795 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -9,8 +9,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from 
pylint_ml.util.config import LIB_TENSORFLOW +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class TensorFlowParameterChecker(BaseChecker): + +class TensorFlowParameterChecker(LibraryBaseChecker): name = "tensor-parameter" msgs = { "W8111": ( @@ -35,6 +38,9 @@ class TensorFlowParameterChecker(BaseChecker): @only_required_for_messages("tensor-parameter") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_TENSORFLOW, required_version=None): + return + method_name = self._get_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} diff --git a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index 75d6b4a..14e373f 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -9,8 +9,11 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.config import LIB_PYTORCH +from pylint_ml.util.library_base_checker import LibraryBaseChecker -class PyTorchParameterChecker(BaseChecker): + +class PyTorchParameterChecker(LibraryBaseChecker): name = "pytorch-parameter" msgs = { "W8111": ( @@ -34,6 +37,9 @@ class PyTorchParameterChecker(BaseChecker): @only_required_for_messages("pytorch-parameter") def visit_call(self, node: nodes.Call) -> None: + if not self.is_library_imported_and_version_valid(lib_name=LIB_PYTORCH, required_version=None): + return + method_name = self._get_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} diff --git a/pylint_ml/util/library_handler.py b/pylint_ml/util/library_base_checker.py similarity index 89% rename from pylint_ml/util/library_handler.py rename to pylint_ml/util/library_base_checker.py index a5d1498..ba4cd83 100644 ---
a/pylint_ml/util/library_handler.py +++ b/pylint_ml/util/library_base_checker.py @@ -15,10 +15,17 @@ def visit_import(self, node): def visit_importfrom(self, node): module = node.modname + print(module) + for name, alias in node.names: + print(name) + print(alias) + print("-------------") full_name = f"{module}.{name}" self.imports[alias or name] = full_name + print(self.imports) + def is_library_imported_and_version_valid(self, lib_name, required_version): """ Checks if the library is imported and whether the installed version is valid (greater than or equal to the @@ -29,6 +36,9 @@ def is_library_imported_and_version_valid(self, lib_name, required_version): return: True if the library is imported and the version is valid, otherwise False. """ # Check if the library is imported + print("xxxxxxxxxxxx") + print(lib_name) + if not any(mod.startswith(lib_name) for mod in self.imports.values()): return False diff --git a/tests/checkers/test_numpy/test_numpy_parameter.py b/tests/checkers/test_numpy/test_numpy_parameter.py index d600bac..fffcc0c 100644 --- a/tests/checkers/test_numpy/test_numpy_parameter.py +++ b/tests/checkers/test_numpy/test_numpy_parameter.py @@ -31,9 +31,9 @@ def test_array_missing_object(self): self.checker.visit_call(call_node) def test_zeros_without_shape(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import numpy as np + import numpy as np #@ arr = np.zeros() #@ """ ) @@ -49,13 +49,14 @@ def test_zeros_without_shape(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(zeros_call) def test_random_rand_without_shape(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import numpy as np - arr = np.random.rand() #@ + import numpy as np #@ + arr = np.random.rand() #@ """ ) @@ -70,13 +71,14 @@ def test_random_rand_without_shape(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) 
self.checker.visit_call(rand_call) def test_dot_without_b(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import numpy as np - arr = np.dot(a=[1, 2, 3]) #@ + import numpy as np #@ + arr = np.dot(a=[1, 2, 3]) #@ """ ) @@ -91,13 +93,14 @@ def test_dot_without_b(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(dot_call) def test_percentile_without_q(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import numpy as np - result = np.percentile(a=[1, 2, 3]) #@ + import numpy as np #@ + result = np.percentile(a=[1, 2, 3]) #@ """ ) @@ -112,4 +115,5 @@ def test_percentile_without_q(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(percentile_call) diff --git a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py b/tests/checkers/test_pandas/pandas_dataframe_column_selection.py index 0bd3592..511ef2d 100644 --- a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py +++ b/tests/checkers/test_pandas/pandas_dataframe_column_selection.py @@ -9,22 +9,23 @@ class TestPandasColumnSelectionChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasColumnSelectionChecker def test_incorrect_column_selection(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) value = df_sales.A #@ """ ) - column_attribute = node.value + attribute_node = node.value with self.assertAddsMessages( pylint.testutils.MessageTest( msg_id="pandas-column-selection", confidence=HIGH, - node=column_attribute, + node=attribute_node, ), ignore_position=True, ): - self.checker.visit_attribute(column_attribute) + self.checker.visit_import(import_node) + self.checker.visit_attribute(attribute_node) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py 
b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py index ec1ae81..44cd400 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py @@ -9,9 +9,9 @@ class TestDataFrameBoolChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasDataFrameBoolChecker def test_dataframe_bool_usage(self): - node = astroid.extract_node( + import_node, call_node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_customers = pd.DataFrame(data) df_customers.bool() #@ """ @@ -20,19 +20,21 @@ def test_dataframe_bool_usage(self): pylint.testutils.MessageTest( msg_id="pandas-dataframe-bool", confidence=HIGH, - node=node, + node=call_node, ), ignore_position=True, ): - self.checker.visit_call(node) + self.checker.visit_import(import_node) + self.checker.visit_call(call_node) def test_no_bool_usage(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_customers = pd.DataFrame(data) df_customers.sum() #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(node) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py index 37875f6..1b1c7e0 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py @@ -9,21 +9,21 @@ class TestPandasEmptyColumnChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasEmptyColumnChecker def test_correct_empty_column_initialization(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import numpy as np - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame() df_sales['new_col_str'] = pd.Series(dtype='object') #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) 
self.checker.visit_subscript(node) def test_incorrect_empty_column_initialization_with_zero(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame() df_sales['new_col_int'] = 0 #@ """ @@ -39,12 +39,13 @@ def test_incorrect_empty_column_initialization_with_zero(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_subscript(subscript_node) def test_incorrect_empty_column_initialization_with_empty_string(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame() df_sales['new_col_str'] = '' #@ """ @@ -60,4 +61,5 @@ def test_incorrect_empty_column_initialization_with_empty_string(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_subscript(subscript_node) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py index 721a75e..b796b69 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py @@ -9,9 +9,9 @@ class TestPandasIterrowsChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasIterrowsChecker def test_iterrows_used(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame({ "Product": ["A", "B", "C"], "Sales": [100, 200, 300] @@ -32,4 +32,5 @@ def test_iterrows_used(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(iterrows_call) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py index 558df80..91a30dd 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py +++ 
b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py @@ -9,19 +9,20 @@ class TestPandasDataFrameNamingChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasDataFrameNamingChecker def test_correct_dataframe_naming(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_customers = pd.DataFrame(data) #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_assign(node) def test_incorrect_dataframe_naming(self): - pandas_dataframe_node = astroid.extract_node( + import_node, pandas_dataframe_node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ customers = pd.DataFrame(data) #@ """ ) @@ -33,12 +34,13 @@ def test_incorrect_dataframe_naming(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_assign(pandas_dataframe_node) def test_incorrect_dataframe_name_length(self): - pandas_dataframe_node = astroid.extract_node( + import_node, pandas_dataframe_node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_ = pd.DataFrame(data) #@ """ ) @@ -50,4 +52,5 @@ def test_incorrect_dataframe_name_length(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_assign(pandas_dataframe_node) diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_values.py b/tests/checkers/test_pandas/test_pandas_dataframe_values.py index 232d9bf..83a44aa 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_values.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_values.py @@ -9,9 +9,9 @@ class TestPandasValuesChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasValuesChecker def test_values_usage_with_correct_naming(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_sales = pd.DataFrame({ "A": [1, 2, 3], "B": [4, 5, 
6] @@ -31,4 +31,5 @@ def test_values_usage_with_correct_naming(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_attribute(attribute_node) diff --git a/tests/checkers/test_pandas/test_pandas_inplace.py b/tests/checkers/test_pandas/test_pandas_inplace.py index 20ed034..43fd594 100644 --- a/tests/checkers/test_pandas/test_pandas_inplace.py +++ b/tests/checkers/test_pandas/test_pandas_inplace.py @@ -9,14 +9,14 @@ class TestPandasInplaceChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasInplaceChecker def test_inplace_used_in_drop(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [1, 2, 3], "B": [4, 5, 6] }) - df.drop(columns=["A"], inplace=True) #@ + df.drop(columns=["A"], inplace=True) #@ """ ) with self.assertAddsMessages( @@ -27,17 +27,18 @@ def test_inplace_used_in_drop(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) def test_inplace_used_in_fillna(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [1, None, 3], "B": [4, 5, None] }) - df.fillna(0, inplace=True) #@ + df.fillna(0, inplace=True) #@ """ ) with self.assertAddsMessages( @@ -48,17 +49,18 @@ def test_inplace_used_in_fillna(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) def test_inplace_used_in_sort_values(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [3, 2, 1], "B": [4, 5, 6] }) - df.sort_values(by="A", inplace=True) #@ + df.sort_values(by="A", inplace=True) #@ """ ) with self.assertAddsMessages( @@ -69,36 +71,39 @@ def test_inplace_used_in_sort_values(self): ), ignore_position=True, ): + 
self.checker.visit_import(import_node) self.checker.visit_call(node) def test_no_inplace(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [1, 2, 3], "B": [4, 5, 6] }) - df = df.drop(columns=["A"]) #@ + df = df.drop(columns=["A"]) #@ """ ) inplace_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(inplace_call) def test_inplace_used_in_unsupported_method(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df = pd.DataFrame({ "A": [1, 2, 3], "B": [4, 5, 6] }) - df.append({"A": 4, "B": 7}, inplace=True) #@ + df.append({"A": 4, "B": 7}, inplace=True) #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(node) diff --git a/tests/checkers/test_pandas/test_pandas_parameter.py b/tests/checkers/test_pandas/test_pandas_parameter.py index 6cfe8ca..a51e736 100644 --- a/tests/checkers/test_pandas/test_pandas_parameter.py +++ b/tests/checkers/test_pandas/test_pandas_parameter.py @@ -9,10 +9,10 @@ class TestPandasParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasParameterChecker def test_dataframe_missing_data(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd - df_yoda = pd.DataFrame() #@ + import pandas as pd #@ + df_yoda = pd.DataFrame() #@ """ ) @@ -27,15 +27,16 @@ def test_dataframe_missing_data(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(dataframe_call) def test_merge_without_required_params(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda1 = pd.DataFrame({'A': [1, 2]}) df_yoda2 = pd.DataFrame({'A': [3, 4]}) - df_yoda_merged = df_yoda1.merge(df_yoda2) 
#@ + df_yoda_merged = df_yoda1.merge(df_yoda2) #@ """ ) @@ -50,13 +51,14 @@ def test_merge_without_required_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(merge_call) def test_read_csv_without_filepath(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd - df_yoda = pd.read_csv() #@ + import pandas as pd #@ + df_yoda = pd.read_csv() #@ """ ) @@ -71,14 +73,15 @@ def test_read_csv_without_filepath(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(read_csv_call) def test_to_csv_without_path(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda = pd.DataFrame({'A': [1, 2]}) - df_yoda.to_csv() #@ + df_yoda.to_csv() #@ """ ) @@ -93,14 +96,15 @@ def test_to_csv_without_path(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(to_csv_call) def test_groupby_without_by(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda = pd.DataFrame({'A': [1, 2]}) - df_yoda.groupby() #@ + df_yoda.groupby() #@ """ ) @@ -115,14 +119,15 @@ def test_groupby_without_by(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(groupby_call) def test_fillna_without_value(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ df_yoda = pd.DataFrame({'A': [1, None]}) - df_yoda.fillna() #@ + df_yoda.fillna() #@ """ ) @@ -137,14 +142,15 @@ def test_fillna_without_value(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(fillna_call) def test_sort_values_without_by(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - 
import pandas as pd + import pandas as pd #@ df_yoda = pd.DataFrame({'A': [1, 2]}) - df_yoda.sort_values() #@ + df_yoda.sort_values() #@ """ ) @@ -159,13 +165,14 @@ def test_sort_values_without_by(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(sort_values_call) def test_merge_with_missing_validate(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd - df_3 = df_1.merge(right=df_2, how='inner', on='col1') #@ + import pandas as pd #@ + df_3 = df_1.merge(right=df_2, how='inner', on='col1') #@ """ ) @@ -180,13 +187,14 @@ def test_merge_with_missing_validate(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(merge_call) def test_merge_with_wrong_naming_and_missing_params(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd - merged_df = df_1.merge(right=df_2) #@ + import pandas as pd #@ + merged_df = df_1.merge(right=df_2) #@ """ ) @@ -198,17 +206,19 @@ def test_merge_with_wrong_naming_and_missing_params(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(merge_call) def test_merge_with_all_params_and_correct_naming(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd - df_merged = df_1.merge(right=df_2, how='inner', on='col1', validate='1:1') #@ + import pandas as pd #@ + df_merged = df_1.merge(right=df_2, how='inner', on='col1', validate='1:1') #@ """ ) merge_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(merge_call) diff --git a/tests/checkers/test_pandas/test_pandas_series_bool.py b/tests/checkers/test_pandas/test_pandas_series_bool.py index b1d5b42..8189fda 100644 --- a/tests/checkers/test_pandas/test_pandas_series_bool.py +++ b/tests/checkers/test_pandas/test_pandas_series_bool.py @@ -9,11 
+9,11 @@ class TestSeriesBoolChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasSeriesBoolChecker def test_series_bool_usage(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ ser_customer = pd.Series(data) - ser_customer.bool() #@ + ser_customer.bool() #@ """ ) with self.assertAddsMessages( @@ -24,15 +24,17 @@ def test_series_bool_usage(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) def test_no_bool_usage(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd + import pandas as pd #@ ser_customer = pd.Series(data) - ser_customer.sum() #@ + ser_customer.sum() #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(node) diff --git a/tests/checkers/test_pandas/test_pandas_series_naming.py b/tests/checkers/test_pandas/test_pandas_series_naming.py index c76d928..5560be5 100644 --- a/tests/checkers/test_pandas/test_pandas_series_naming.py +++ b/tests/checkers/test_pandas/test_pandas_series_naming.py @@ -9,20 +9,21 @@ class TestPandasSeriesNamingChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasSeriesNamingChecker def test_series_correct_naming(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd - ser_sales = pd.Series([100, 200, 300]) + import pandas as pd #@ + ser_sales = pd.Series([100, 200, 300]) #@ """ ) with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_assign(node) def test_series_incorrect_naming(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd - df_sales = pd.Series([100, 200, 300]) + import pandas as pd #@ + df_sales = pd.Series([100, 200, 300]) #@ """ ) with self.assertAddsMessages( @@ -33,13 +34,14 @@ def test_series_incorrect_naming(self): 
), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_assign(node) def test_series_invalid_length_naming(self): - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import pandas as pd - ser_ = pd.Series([True]) + import pandas as pd #@ + ser_ = pd.Series([True]) #@ """ ) with self.assertAddsMessages( @@ -50,4 +52,5 @@ def test_series_invalid_length_naming(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_assign(node) diff --git a/tests/checkers/test_scipy/test_scipy_parameter.py b/tests/checkers/test_scipy/test_scipy_parameter.py index 1bec24b..d5560c2 100644 --- a/tests/checkers/test_scipy/test_scipy_parameter.py +++ b/tests/checkers/test_scipy/test_scipy_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,11 +10,15 @@ class TestScipyParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = ScipyParameterChecker - def test_minimize_params(self): - node = astroid.extract_node( + # TODO CONTINUE WITH MOCK FOR ALL TESTS + @patch("pylint_ml.util.library_base_checker.version") + def test_minimize_params(self, mock_version): + mock_version.return_value = "1.7.0" + + importfrom_node, node = astroid.extract_node( """ - from scipy.optimize import minimize - result = minimize(x0=[1, 2, 3]) #@ + from scipy.optimize import minimize #@ + result = minimize(x0=[1, 2, 3]) #@ """ ) minimize_call = node.value @@ -26,12 +32,13 @@ def test_minimize_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(minimize_call) def test_curve_fit_params(self): - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from scipy.optimize import curve_fit + from scipy.optimize import curve_fit #@ params = curve_fit(xdata=[1, 2, 3], ydata=[4, 5, 6]) #@ """ ) @@ -49,9 +56,9 @@ def 
test_curve_fit_params(self): self.checker.visit_call(curve_fit_call) def test_quad_params(self): - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from scipy.integrate import quad + from scipy.integrate import quad #@ result = quad(a=0, b=1) #@ """ ) @@ -69,9 +76,9 @@ def test_quad_params(self): self.checker.visit_call(quad_call) def test_solve_ivp_params(self): - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from scipy.integrate import solve_ivp + from scipy.integrate import solve_ivp #@ result = solve_ivp(fun=None, t_span=[0, 1]) #@ """ ) @@ -89,9 +96,9 @@ def test_solve_ivp_params(self): self.checker.visit_call(solve_ivp_call) def test_ttest_ind_params(self): - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from scipy.stats import ttest_ind + from scipy.stats import ttest_ind #@ result = ttest_ind(a=[1, 2]) #@ """ ) @@ -109,9 +116,9 @@ def test_ttest_ind_params(self): self.checker.visit_call(ttest_ind_call) def test_euclidean_params(self): - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from scipy.spatial.distance import euclidean + from scipy.spatial.distance import euclidean #@ dist = euclidean(u=[1, 2, 3]) #@ """ ) From e3e2ae968bb41893c467592f5b856253800416d9 Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Thu, 19 Sep 2024 15:52:22 +0200 Subject: [PATCH 03/19] Update tests --- .../matplotlib/matplotlib_parameter.py | 15 ++--------- pylint_ml/checkers/numpy/numpy_parameter.py | 23 ++-------------- .../pandas/pandas_dataframe_empty_column.py | 1 - .../pandas/pandas_dataframe_iterrows.py | 1 - .../pandas/pandas_dataframe_naming.py | 1 - .../pandas/pandas_dataframe_values.py | 1 - pylint_ml/checkers/pandas/pandas_inplace.py | 1 - pylint_ml/checkers/pandas/pandas_parameter.py | 16 ++---------- .../checkers/pandas/pandas_series_bool.py | 1 - .../checkers/pandas/pandas_series_naming.py | 1 - 
pylint_ml/checkers/scipy/scipy_parameter.py | 24 ++--------------- .../checkers/sklearn/sklearn_parameter.py | 17 ++---------- .../checkers/tensorflow/tensor_parameter.py | 17 ++---------- pylint_ml/checkers/torch/torch_parameter.py | 21 +++------------ pylint_ml/util/common.py | 26 +++++++++++++++++++ pylint_ml/util/config.py | 2 -- .../test_scipy/test_scipy_parameter.py | 26 ++++++++++++++----- 17 files changed, 62 insertions(+), 132 deletions(-) create mode 100644 pylint_ml/util/common.py diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index 40113c9..c4e2a97 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -8,6 +8,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.common import get_full_method_name from pylint_ml.util.config import LIB_MATPLOTLIB from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -51,7 +52,7 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=LIB_MATPLOTLIB, required_version=None): return - method_name = self._get_full_method_name(node) + method_name = get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] @@ -62,15 +63,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - def _get_full_method_name(self, node: nodes.Call) -> str: - func = node.func - method_chain = [] - - while isinstance(func, nodes.Attribute): - method_chain.insert(0, func.attrname) - func = func.expr - if isinstance(func, nodes.Name): - method_chain.insert(0, func.name) - - return ".".join(method_chain) diff --git 
a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index a3921de..575d82d 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -8,6 +8,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.common import get_full_method_name from pylint_ml.util.config import LIB_NUMPY from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -76,11 +77,9 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=LIB_NUMPY, required_version=None): return - method_name = self._get_full_method_name(node) - + method_name = get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -89,21 +88,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_full_method_name(node: nodes.Call) -> str: - """ - Extracts the full method name, including chained attributes (e.g., np.random.rand). - """ - func = node.func - method_chain = [] - - # Traverse the attribute chain - while isinstance(func, nodes.Attribute): - method_chain.insert(0, func.attrname) - func = func.expr - - # Check if the root of the chain is "np" (as NumPy functions are expected to use np. 
prefix) - if isinstance(func, nodes.Name) and func.name == "np": - return ".".join(method_chain) - return "" diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py index db82a34..670e9a7 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py @@ -7,7 +7,6 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py index b1c24a3..54ab445 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py @@ -7,7 +7,6 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py index fe08e3b..16550ea 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py @@ -7,7 +7,6 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_values.py b/pylint_ml/checkers/pandas/pandas_dataframe_values.py index 689daf5..52d083a 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_values.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_values.py @@ -7,7 +7,6 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from 
pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH diff --git a/pylint_ml/checkers/pandas/pandas_inplace.py b/pylint_ml/checkers/pandas/pandas_inplace.py index 9426065..128212d 100644 --- a/pylint_ml/checkers/pandas/pandas_inplace.py +++ b/pylint_ml/checkers/pandas/pandas_inplace.py @@ -7,7 +7,6 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index c6e629d..9eaa29c 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -8,6 +8,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.common import get_method_name from pylint_ml.util.config import LIB_PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -69,10 +70,9 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): return - method_name = self._get_method_name(node) + method_name = get_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -81,15 +81,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - 
if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) diff --git a/pylint_ml/checkers/pandas/pandas_series_bool.py b/pylint_ml/checkers/pandas/pandas_series_bool.py index d42d5d2..e31f06e 100644 --- a/pylint_ml/checkers/pandas/pandas_series_bool.py +++ b/pylint_ml/checkers/pandas/pandas_series_bool.py @@ -7,7 +7,6 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH diff --git a/pylint_ml/checkers/pandas/pandas_series_naming.py b/pylint_ml/checkers/pandas/pandas_series_naming.py index 780a426..93e7c3e 100644 --- a/pylint_ml/checkers/pandas/pandas_series_naming.py +++ b/pylint_ml/checkers/pandas/pandas_series_naming.py @@ -7,7 +7,6 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index 7af4a08..d6117ff 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -5,10 +5,10 @@ """Check for proper usage of Scipy functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.common import get_full_method_name from pylint_ml.util.config import LIB_SCIPY from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -48,10 +48,9 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=LIB_SCIPY, required_version=None): return - method_name = self._get_full_method_name(node) + method_name = get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: 
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -60,22 +59,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - def _get_full_method_name(self, node: nodes.Call) -> str: - """ - Extracts the full method name, including handling chained attributes (e.g., scipy.spatial.distance.euclidean) - and also handles direct imports like euclidean. - """ - func = node.func - method_chain = [] - - # Traverse the attribute chain to get the full method name - while isinstance(func, nodes.Attribute): - method_chain.insert(0, func.attrname) - func = func.expr - - # If it's a direct function name, like `euclidean`, return it - if isinstance(func, nodes.Name): - method_chain.insert(0, func.name) - - return ".".join(method_chain) diff --git a/pylint_ml/checkers/sklearn/sklearn_parameter.py b/pylint_ml/checkers/sklearn/sklearn_parameter.py index 8acd01d..cbe0d2e 100644 --- a/pylint_ml/checkers/sklearn/sklearn_parameter.py +++ b/pylint_ml/checkers/sklearn/sklearn_parameter.py @@ -5,10 +5,10 @@ """Check for proper usage of Scikit-learn functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.common import get_method_name from pylint_ml.util.config import LIB_SKLEARN from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -43,10 +43,9 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=LIB_SKLEARN, required_version=None): return - method_name = self._get_method_name(node) + method_name = get_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw 
in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -55,15 +54,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index 8adb795..ae3bdd9 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -5,10 +5,10 @@ """Check for proper usage of Tensorflow functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH +from pylint_ml.util.common import get_method_name from pylint_ml.util.config import LIB_TENSORFLOW from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -41,10 +41,9 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=LIB_TENSORFLOW, required_version=None): return - method_name = self._get_method_name(node) + method_name = get_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -53,15 +52,3 @@ def visit_call(self, node: 
nodes.Call) -> None: confidence=HIGH, args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) diff --git a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index 14e373f..463867c 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -5,11 +5,11 @@ """Check for proper usage of PyTorch functions with required parameters.""" from astroid import nodes -from pylint.checkers import BaseChecker from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_TENSORFLOW +from pylint_ml.util.common import get_method_name +from pylint_ml.util.config import LIB_PYTORCH from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -37,13 +37,12 @@ class PyTorchParameterChecker(LibraryBaseChecker): @only_required_for_messages("pytorch-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_TENSORFLOW, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=LIB_PYTORCH, required_version=None): return - method_name = self._get_method_name(node) + method_name = get_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - # Collect all missing parameters missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( @@ -52,15 +51,3 @@ def visit_call(self, node: nodes.Call) -> None: confidence=HIGH, 
args=(", ".join(missing_params), method_name), ) - - @staticmethod - def _get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) diff --git a/pylint_ml/util/common.py b/pylint_ml/util/common.py new file mode 100644 index 0000000..bff1069 --- /dev/null +++ b/pylint_ml/util/common.py @@ -0,0 +1,26 @@ +from astroid import nodes + + +def get_method_name(node: nodes.Call) -> str: + """Extracts the method name from a Call node, including handling chained calls.""" + func = node.func + while isinstance(func, nodes.Attribute): + func = func.expr + return ( + node.func.attrname + if isinstance(node.func, nodes.Attribute) + else func.name if isinstance(func, nodes.Name) else "" + ) + + +def get_full_method_name(node: nodes.Call) -> str: + func = node.func + method_chain = [] + + while isinstance(func, nodes.Attribute): + method_chain.insert(0, func.attrname) + func = func.expr + if isinstance(func, nodes.Name): + method_chain.insert(0, func.name) + + return ".".join(method_chain) diff --git a/pylint_ml/util/config.py b/pylint_ml/util/config.py index 59caa7b..2632fb2 100644 --- a/pylint_ml/util/config.py +++ b/pylint_ml/util/config.py @@ -1,5 +1,3 @@ - - # Library names LIB_PANDAS = "pandas" LIB_NUMPY = "numpy" diff --git a/tests/checkers/test_scipy/test_scipy_parameter.py b/tests/checkers/test_scipy/test_scipy_parameter.py index d5560c2..bf95249 100644 --- a/tests/checkers/test_scipy/test_scipy_parameter.py +++ b/tests/checkers/test_scipy/test_scipy_parameter.py @@ -14,7 +14,6 @@ class TestScipyParameterChecker(pylint.testutils.CheckerTestCase): @patch("pylint_ml.util.library_base_checker.version") def test_minimize_params(self, mock_version): mock_version.return_value = "1.7.0" - 
importfrom_node, node = astroid.extract_node( """ from scipy.optimize import minimize #@ @@ -35,7 +34,9 @@ def test_minimize_params(self, mock_version): self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(minimize_call) - def test_curve_fit_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_curve_fit_params(self, mock_version): + mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( """ from scipy.optimize import curve_fit #@ @@ -53,9 +54,12 @@ def test_curve_fit_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(curve_fit_call) - def test_quad_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_quad_params(self, mock_version): + mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( """ from scipy.integrate import quad #@ @@ -73,9 +77,12 @@ def test_quad_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(quad_call) - def test_solve_ivp_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_solve_ivp_params(self, mock_version): + mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( """ from scipy.integrate import solve_ivp #@ @@ -93,9 +100,12 @@ def test_solve_ivp_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(solve_ivp_call) - def test_ttest_ind_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_ttest_ind_params(self, mock_version): + mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( """ from scipy.stats import ttest_ind #@ @@ -113,9 +123,12 @@ def test_ttest_ind_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(ttest_ind_call) - def test_euclidean_params(self): 
+ @patch("pylint_ml.util.library_base_checker.version") + def test_euclidean_params(self, mock_version): + mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( """ from scipy.spatial.distance import euclidean #@ @@ -133,4 +146,5 @@ def test_euclidean_params(self): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(euclidean_call) From bfd41dcc7131b501732f15de19a772b68c180a3c Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Fri, 20 Sep 2024 15:25:58 +0200 Subject: [PATCH 04/19] Mock tests to use library version --- .../matplotlib/matplotlib_parameter.py | 7 ++-- pylint_ml/checkers/numpy/numpy_dot.py | 4 +- .../checkers/numpy/numpy_nan_comparison.py | 6 +-- pylint_ml/checkers/numpy/numpy_parameter.py | 6 +-- .../checkers/pandas/pandas_dataframe_bool.py | 4 +- .../pandas_dataframe_column_selection.py | 4 +- .../pandas/pandas_dataframe_empty_column.py | 4 +- .../pandas/pandas_dataframe_iterrows.py | 4 +- .../pandas/pandas_dataframe_naming.py | 4 +- .../pandas/pandas_dataframe_values.py | 4 +- pylint_ml/checkers/pandas/pandas_inplace.py | 4 +- pylint_ml/checkers/pandas/pandas_parameter.py | 4 +- .../checkers/pandas/pandas_series_bool.py | 4 +- .../checkers/pandas/pandas_series_naming.py | 4 +- pylint_ml/checkers/scipy/scipy_parameter.py | 6 +-- .../checkers/sklearn/sklearn_parameter.py | 4 +- .../checkers/tensorflow/tensor_parameter.py | 4 +- pylint_ml/checkers/torch/torch_parameter.py | 4 +- pylint_ml/util/common.py | 13 ++++-- pylint_ml/util/config.py | 23 +++++++---- pylint_ml/util/library_base_checker.py | 5 +-- tests/checkers/test_numpy/test_numpy_dot.py | 6 ++- .../test_numpy/test_numpy_nan_comparison.py | 11 +++-- .../test_numpy/test_numpy_parameter.py | 22 +++++++--- .../pandas_dataframe_column_selection.py | 4 +- .../test_pandas/test_pandas_dataframe_bool.py | 8 +++- .../test_pandas_dataframe_empty_column.py | 12 ++++-- .../test_pandas_dataframe_iterrows.py | 4 +- 
.../test_pandas_dataframe_naming.py | 12 ++++-- .../test_pandas_dataframe_values.py | 4 +- .../test_pandas/test_pandas_inplace.py | 5 +++ .../test_pandas/test_pandas_parameter.py | 40 ++++++++++++++----- .../test_pandas/test_pandas_series_bool.py | 8 +++- .../test_pandas/test_pandas_series_naming.py | 12 ++++-- .../test_scipy/test_scipy_parameter.py | 1 - .../test_sklearn/test_sklearn_parameter.py | 24 ++++++++--- .../test_tensorflow/test_tensor_parameter.py | 40 ++++++++++++++----- .../test_torch/test_torch_parameter.py | 40 ++++++++++++++----- 38 files changed, 258 insertions(+), 117 deletions(-) diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index c4e2a97..225d717 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import LIB_MATPLOTLIB +from pylint_ml.util.config import MATPLOTLIB from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -49,10 +49,11 @@ class MatplotlibParameterChecker(LibraryBaseChecker): @only_required_for_messages("matplotlib-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_MATPLOTLIB, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=MATPLOTLIB, required_version=None): return - method_name = get_full_method_name(node) + # TODO UPDATE + method_name = get_full_method_name(lib_alias="", node=node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/numpy/numpy_dot.py b/pylint_ml/checkers/numpy/numpy_dot.py index e32e96b..048e105 100644 --- 
a/pylint_ml/checkers/numpy/numpy_dot.py +++ b/pylint_ml/checkers/numpy/numpy_dot.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_NUMPY +from pylint_ml.util.config import NUMPY from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -30,7 +30,7 @@ def visit_import(self, node: nodes.Import): @only_required_for_messages("numpy-dot-usage") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_NUMPY, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return # Check if the function being called is np.dot diff --git a/pylint_ml/checkers/numpy/numpy_nan_comparison.py b/pylint_ml/checkers/numpy/numpy_nan_comparison.py index 6f9d330..e07eb25 100644 --- a/pylint_ml/checkers/numpy/numpy_nan_comparison.py +++ b/pylint_ml/checkers/numpy/numpy_nan_comparison.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_NUMPY +from pylint_ml.util.config import NUMPY, NUMPY_ALIAS from pylint_ml.util.library_base_checker import LibraryBaseChecker COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "==")) @@ -30,11 +30,11 @@ class NumpyNaNComparisonChecker(LibraryBaseChecker): @classmethod def __is_np_nan_call(cls, node: nodes.Attribute) -> bool: """Check if the node represents a call to np.nan.""" - return node.attrname in NUMPY_NAN and isinstance(node.expr, nodes.Name) and node.expr.name == "np" + return node.attrname in NUMPY_NAN and isinstance(node.expr, nodes.Name) and node.expr.name == NUMPY_ALIAS @only_required_for_messages("numpy-nan-compare") def visit_compare(self, node: nodes.Compare) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_NUMPY, required_version=None): + if not 
self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return if isinstance(node.left, nodes.Attribute) and self.__is_np_nan_call(node.left): diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index 575d82d..c9fda0b 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import LIB_NUMPY +from pylint_ml.util.config import NUMPY, NUMPY_ALIAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -74,10 +74,10 @@ class NumPyParameterChecker(LibraryBaseChecker): @only_required_for_messages("numpy-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_NUMPY, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return - method_name = get_full_method_name(node) + method_name = get_full_method_name(lib_alias=NUMPY_ALIAS, node=node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py index cf0c2ae..fc17bce 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -26,7 +26,7 @@ class PandasDataFrameBoolChecker(LibraryBaseChecker): 
@only_required_for_messages("pandas-dataframe-bool") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version="2.1.0"): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version="2.1.0"): return if isinstance(node.func, nodes.Attribute): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py index f34be18..1748d24 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -28,7 +28,7 @@ class PandasColumnSelectionChecker(LibraryBaseChecker): def visit_attribute(self, node: nodes.Attribute) -> None: """Check for attribute access that might be a column selection.""" - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return if isinstance(node.expr, nodes.Name) and node.expr.name.startswith("df_"): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py index 670e9a7..fb37145 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -27,7 +27,7 @@ class 
PandasEmptyColumnChecker(LibraryBaseChecker): @only_required_for_messages("pandas-dataframe-empty-column") def visit_subscript(self, node: nodes.Subscript) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return if isinstance(node.value, nodes.Name) and node.value.name.startswith("df_"): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py index 54ab445..b83dee3 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -27,7 +27,7 @@ class PandasIterrowsChecker(LibraryBaseChecker): @only_required_for_messages("pandas-iterrows") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return if isinstance(node.func, nodes.Attribute): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py index 16550ea..67644f0 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -26,7 +26,7 @@ class PandasDataFrameNamingChecker(LibraryBaseChecker): 
@only_required_for_messages("pandas-dataframe-naming") def visit_assign(self, node: nodes.Assign) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return if isinstance(node.value, nodes.Call): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_values.py b/pylint_ml/checkers/pandas/pandas_dataframe_values.py index 52d083a..d4c107a 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_values.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_values.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -27,7 +27,7 @@ class PandasValuesChecker(LibraryBaseChecker): @only_required_for_messages("pandas-dataframe-values") def visit_attribute(self, node: nodes.Attribute) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return if isinstance(node.expr, nodes.Name): diff --git a/pylint_ml/checkers/pandas/pandas_inplace.py b/pylint_ml/checkers/pandas/pandas_inplace.py index 128212d..804b113 100644 --- a/pylint_ml/checkers/pandas/pandas_inplace.py +++ b/pylint_ml/checkers/pandas/pandas_inplace.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -41,7 +41,7 @@ class PandasInplaceChecker(LibraryBaseChecker): @only_required_for_messages("pandas-inplace") def visit_call(self, node: nodes.Call) -> None: - if not 
self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return # Check if the call is to a method that supports 'inplace' diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index 9eaa29c..b72994e 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_method_name -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -67,7 +67,7 @@ class PandasParameterChecker(LibraryBaseChecker): @only_required_for_messages("pandas-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return method_name = get_method_name(node) diff --git a/pylint_ml/checkers/pandas/pandas_series_bool.py b/pylint_ml/checkers/pandas/pandas_series_bool.py index e31f06e..24747a8 100644 --- a/pylint_ml/checkers/pandas/pandas_series_bool.py +++ b/pylint_ml/checkers/pandas/pandas_series_bool.py @@ -11,7 +11,7 @@ from pylint.interfaces import HIGH # Todo add version deprecated -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -27,7 +27,7 @@ class PandasSeriesBoolChecker(LibraryBaseChecker): @only_required_for_messages("pandas-series-bool") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return if 
isinstance(node.func, nodes.Attribute): diff --git a/pylint_ml/checkers/pandas/pandas_series_naming.py b/pylint_ml/checkers/pandas/pandas_series_naming.py index 93e7c3e..7fe1f32 100644 --- a/pylint_ml/checkers/pandas/pandas_series_naming.py +++ b/pylint_ml/checkers/pandas/pandas_series_naming.py @@ -10,7 +10,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import LIB_PANDAS +from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -26,7 +26,7 @@ class PandasSeriesNamingChecker(LibraryBaseChecker): @only_required_for_messages("pandas-series-naming") def visit_assign(self, node: nodes.Assign) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PANDAS, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return if isinstance(node.value, nodes.Call): diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index d6117ff..eb9c12f 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import LIB_SCIPY +from pylint_ml.util.config import SCIPY, PANDAS_ALIAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -45,10 +45,10 @@ class ScipyParameterChecker(LibraryBaseChecker): @only_required_for_messages("scipy-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_SCIPY, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=SCIPY, required_version=None): return - method_name = get_full_method_name(node) + method_name = get_full_method_name(lib_alias=PANDAS_ALIAS, node=node) if method_name in self.REQUIRED_PARAMS: 
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/sklearn/sklearn_parameter.py b/pylint_ml/checkers/sklearn/sklearn_parameter.py index cbe0d2e..8824a4d 100644 --- a/pylint_ml/checkers/sklearn/sklearn_parameter.py +++ b/pylint_ml/checkers/sklearn/sklearn_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_method_name -from pylint_ml.util.config import LIB_SKLEARN +from pylint_ml.util.config import SKLEARN from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -40,7 +40,7 @@ class SklearnParameterChecker(LibraryBaseChecker): @only_required_for_messages("sklearn-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_SKLEARN, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=SKLEARN, required_version=None): return method_name = get_method_name(node) diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index ae3bdd9..40704b3 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_method_name -from pylint_ml.util.config import LIB_TENSORFLOW +from pylint_ml.util.config import TENSORFLOW from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -38,7 +38,7 @@ class TensorFlowParameterChecker(LibraryBaseChecker): @only_required_for_messages("tensor-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_TENSORFLOW, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=TENSORFLOW, required_version=None): return method_name = 
get_method_name(node) diff --git a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index 463867c..b7d38e0 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_method_name -from pylint_ml.util.config import LIB_PYTORCH +from pylint_ml.util.config import PYTORCH from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -37,7 +37,7 @@ class PyTorchParameterChecker(LibraryBaseChecker): @only_required_for_messages("pytorch-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=LIB_PYTORCH, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=PYTORCH, required_version=None): return method_name = get_method_name(node) diff --git a/pylint_ml/util/common.py b/pylint_ml/util/common.py index bff1069..846278f 100644 --- a/pylint_ml/util/common.py +++ b/pylint_ml/util/common.py @@ -13,14 +13,19 @@ def get_method_name(node: nodes.Call) -> str: ) -def get_full_method_name(node: nodes.Call) -> str: +def get_full_method_name(lib_alias: str, node: nodes.Call) -> str: + """ + Extracts the full method name, including chained attributes (e.g., np.random.rand). + """ func = node.func method_chain = [] + # Traverse the attribute chain while isinstance(func, nodes.Attribute): method_chain.insert(0, func.attrname) func = func.expr - if isinstance(func, nodes.Name): - method_chain.insert(0, func.name) - return ".".join(method_chain) + # Check if the root of the chain is "np" (as NumPy functions are expected to use np. 
prefix) + if isinstance(func, nodes.Name) and func.name == lib_alias: + return ".".join(method_chain) + return "" diff --git a/pylint_ml/util/config.py b/pylint_ml/util/config.py index 2632fb2..6b99e86 100644 --- a/pylint_ml/util/config.py +++ b/pylint_ml/util/config.py @@ -1,8 +1,17 @@ # Library names -LIB_PANDAS = "pandas" -LIB_NUMPY = "numpy" -LIB_TENSORFLOW = "tensor" -LIB_SCIPY = "scipy" -LIB_SKLEARN = "sklearn" -LIB_PYTORCH = "torch" -LIB_MATPLOTLIB = "matplotlib" +PANDAS = "pandas" +PANDAS_ALIAS = "pd" + +NUMPY = "numpy" +NUMPY_ALIAS = "np" + +TENSORFLOW = "tensor" + +SCIPY = "scipy" + +SKLEARN = "sklearn" + +PYTORCH = "torch" + +MATPLOTLIB = "matplotlib" + diff --git a/pylint_ml/util/library_base_checker.py b/pylint_ml/util/library_base_checker.py index ba4cd83..6c1e110 100644 --- a/pylint_ml/util/library_base_checker.py +++ b/pylint_ml/util/library_base_checker.py @@ -31,14 +31,11 @@ def is_library_imported_and_version_valid(self, lib_name, required_version): Checks if the library is imported and whether the installed version is valid (greater than or equal to the required version). - param lib_name: Name of the library (as a string). + param lib_alias: Name of the library (as a string). param required_version: The required minimum version (as a string). return: True if the library is imported and the version is valid, otherwise False. 
""" # Check if the library is imported - print("xxxxxxxxxxxx") - print(lib_name) - if not any(mod.startswith(lib_name) for mod in self.imports.values()): return False diff --git a/tests/checkers/test_numpy/test_numpy_dot.py b/tests/checkers/test_numpy/test_numpy_dot.py index f01b811..6c7626b 100644 --- a/tests/checkers/test_numpy/test_numpy_dot.py +++ b/tests/checkers/test_numpy/test_numpy_dot.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,7 +10,9 @@ class TestNumpyDotChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumpyDotChecker - def test_warning_for_dot(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_warning_for_dot(self, mock_version): + mock_version.return_value = "1.7.0" import_np, node = astroid.extract_node( """ import numpy as np #@ diff --git a/tests/checkers/test_numpy/test_numpy_nan_comparison.py b/tests/checkers/test_numpy/test_numpy_nan_comparison.py index 5f9f2a7..00562af 100644 --- a/tests/checkers/test_numpy/test_numpy_nan_comparison.py +++ b/tests/checkers/test_numpy/test_numpy_nan_comparison.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,10 +10,12 @@ class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumpyNaNComparisonChecker - def test_singleton_nan_compare(self): - singleton_node, chained_node, great_than_node = astroid.extract_node( + @patch("pylint_ml.util.library_base_checker.version") + def test_singleton_nan_compare(self, mock_version): + mock_version.return_value = "2.1.1" + import_node, singleton_node, chained_node, great_than_node = astroid.extract_node( """ - import numpy as np + import numpy as np #@ a_nan = np.array([0, 1, np.nan]) np.nan == a_nan #@ @@ -38,6 +42,7 @@ def test_singleton_nan_compare(self): ), ignore_position=True, ): + self.checker.visit_import(import_node) 
self.checker.visit_compare(singleton_node) self.checker.visit_compare(chained_node) self.checker.visit_compare(great_than_node) diff --git a/tests/checkers/test_numpy/test_numpy_parameter.py b/tests/checkers/test_numpy/test_numpy_parameter.py index fffcc0c..3919c14 100644 --- a/tests/checkers/test_numpy/test_numpy_parameter.py +++ b/tests/checkers/test_numpy/test_numpy_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -8,7 +10,9 @@ class TestNumPyParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumPyParameterChecker - def test_array_missing_object(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_array_missing_object(self, mock_version): + mock_version.return_value = "2.1.1" import_node, call_node = astroid.extract_node( """ import numpy as np #@ @@ -30,7 +34,9 @@ def test_array_missing_object(self): self.checker.visit_import(import_node) self.checker.visit_call(call_node) - def test_zeros_without_shape(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_zeros_without_shape(self, mock_version): + mock_version.return_value = "2.1.1" import_node, node = astroid.extract_node( """ import numpy as np #@ @@ -52,7 +58,9 @@ def test_zeros_without_shape(self): self.checker.visit_import(import_node) self.checker.visit_call(zeros_call) - def test_random_rand_without_shape(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_random_rand_without_shape(self, mock_version): + mock_version.return_value = "2.1.1" import_node, node = astroid.extract_node( """ import numpy as np #@ @@ -74,7 +82,9 @@ def test_random_rand_without_shape(self): self.checker.visit_import(import_node) self.checker.visit_call(rand_call) - def test_dot_without_b(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_dot_without_b(self, mock_version): + mock_version.return_value = "2.1.1" import_node, node = 
astroid.extract_node( """ import numpy as np #@ @@ -96,7 +106,9 @@ def test_dot_without_b(self): self.checker.visit_import(import_node) self.checker.visit_call(dot_call) - def test_percentile_without_q(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_percentile_without_q(self, mock_version): + mock_version.return_value = "2.1.1" import_node, node = astroid.extract_node( """ import numpy as np #@ diff --git a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py b/tests/checkers/test_pandas/pandas_dataframe_column_selection.py index 511ef2d..15fc321 100644 --- a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py +++ b/tests/checkers/test_pandas/pandas_dataframe_column_selection.py @@ -8,7 +8,9 @@ class TestPandasColumnSelectionChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasColumnSelectionChecker - def test_incorrect_column_selection(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_incorrect_column_selection(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py index 44cd400..d47a3be 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py @@ -8,7 +8,9 @@ class TestDataFrameBoolChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasDataFrameBoolChecker - def test_dataframe_bool_usage(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_dataframe_bool_usage(self, mock_version): + mock_version.return_value = "2.2.2" import_node, call_node = astroid.extract_node( """ import pandas as pd #@ @@ -27,7 +29,9 @@ def test_dataframe_bool_usage(self): self.checker.visit_import(import_node) self.checker.visit_call(call_node) - def test_no_bool_usage(self): + 
@patch("pylint_ml.util.library_base_checker.version") + def test_no_bool_usage(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py index 1b1c7e0..c6c26ab 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py @@ -8,7 +8,9 @@ class TestPandasEmptyColumnChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasEmptyColumnChecker - def test_correct_empty_column_initialization(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_correct_empty_column_initialization(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -20,7 +22,9 @@ def test_correct_empty_column_initialization(self): self.checker.visit_import(import_node) self.checker.visit_subscript(node) - def test_incorrect_empty_column_initialization_with_zero(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_incorrect_empty_column_initialization_with_zero(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -42,7 +46,9 @@ def test_incorrect_empty_column_initialization_with_zero(self): self.checker.visit_import(import_node) self.checker.visit_subscript(subscript_node) - def test_incorrect_empty_column_initialization_with_empty_string(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_incorrect_empty_column_initialization_with_empty_string(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py 
b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py index b796b69..95be7b1 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py @@ -8,7 +8,9 @@ class TestPandasIterrowsChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasIterrowsChecker - def test_iterrows_used(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_iterrows_used(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py index 91a30dd..45bdf99 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py @@ -8,7 +8,9 @@ class TestPandasDataFrameNamingChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasDataFrameNamingChecker - def test_correct_dataframe_naming(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_correct_dataframe_naming(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -19,7 +21,9 @@ def test_correct_dataframe_naming(self): self.checker.visit_import(import_node) self.checker.visit_assign(node) - def test_incorrect_dataframe_naming(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_incorrect_dataframe_naming(self, mock_version): + mock_version.return_value = "2.2.2" import_node, pandas_dataframe_node = astroid.extract_node( """ import pandas as pd #@ @@ -37,7 +41,9 @@ def test_incorrect_dataframe_naming(self): self.checker.visit_import(import_node) self.checker.visit_assign(pandas_dataframe_node) - def test_incorrect_dataframe_name_length(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_incorrect_dataframe_name_length(self, 
mock_version): + mock_version.return_value = "2.2.2" import_node, pandas_dataframe_node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_values.py b/tests/checkers/test_pandas/test_pandas_dataframe_values.py index 83a44aa..b12d355 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_values.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_values.py @@ -8,7 +8,9 @@ class TestPandasValuesChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasValuesChecker - def test_values_usage_with_correct_naming(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_values_usage_with_correct_naming(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_inplace.py b/tests/checkers/test_pandas/test_pandas_inplace.py index 43fd594..fee643f 100644 --- a/tests/checkers/test_pandas/test_pandas_inplace.py +++ b/tests/checkers/test_pandas/test_pandas_inplace.py @@ -8,6 +8,7 @@ class TestPandasInplaceChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasInplaceChecker + @patch("pylint_ml.util.library_base_checker.version") def test_inplace_used_in_drop(self): import_node, node = astroid.extract_node( """ @@ -30,6 +31,7 @@ def test_inplace_used_in_drop(self): self.checker.visit_import(import_node) self.checker.visit_call(node) + @patch("pylint_ml.util.library_base_checker.version") def test_inplace_used_in_fillna(self): import_node, node = astroid.extract_node( """ @@ -52,6 +54,7 @@ def test_inplace_used_in_fillna(self): self.checker.visit_import(import_node) self.checker.visit_call(node) + @patch("pylint_ml.util.library_base_checker.version") def test_inplace_used_in_sort_values(self): import_node, node = astroid.extract_node( """ @@ -74,6 +77,7 @@ def test_inplace_used_in_sort_values(self): self.checker.visit_import(import_node) 
self.checker.visit_call(node) + @patch("pylint_ml.util.library_base_checker.version") def test_no_inplace(self): import_node, node = astroid.extract_node( """ @@ -92,6 +96,7 @@ def test_no_inplace(self): self.checker.visit_import(import_node) self.checker.visit_call(inplace_call) + @patch("pylint_ml.util.library_base_checker.version") def test_inplace_used_in_unsupported_method(self): import_node, node = astroid.extract_node( """ diff --git a/tests/checkers/test_pandas/test_pandas_parameter.py b/tests/checkers/test_pandas/test_pandas_parameter.py index a51e736..17ef98d 100644 --- a/tests/checkers/test_pandas/test_pandas_parameter.py +++ b/tests/checkers/test_pandas/test_pandas_parameter.py @@ -8,7 +8,9 @@ class TestPandasParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasParameterChecker - def test_dataframe_missing_data(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_dataframe_missing_data(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -30,7 +32,9 @@ def test_dataframe_missing_data(self): self.checker.visit_import(import_node) self.checker.visit_call(dataframe_call) - def test_merge_without_required_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_merge_without_required_params(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -54,7 +58,9 @@ def test_merge_without_required_params(self): self.checker.visit_import(import_node) self.checker.visit_call(merge_call) - def test_read_csv_without_filepath(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_read_csv_without_filepath(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -76,7 +82,9 @@ def test_read_csv_without_filepath(self): self.checker.visit_import(import_node) 
self.checker.visit_call(read_csv_call) - def test_to_csv_without_path(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_to_csv_without_path(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -99,7 +107,9 @@ def test_to_csv_without_path(self): self.checker.visit_import(import_node) self.checker.visit_call(to_csv_call) - def test_groupby_without_by(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_groupby_without_by(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -122,7 +132,9 @@ def test_groupby_without_by(self): self.checker.visit_import(import_node) self.checker.visit_call(groupby_call) - def test_fillna_without_value(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_fillna_without_value(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -145,7 +157,9 @@ def test_fillna_without_value(self): self.checker.visit_import(import_node) self.checker.visit_call(fillna_call) - def test_sort_values_without_by(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_sort_values_without_by(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -168,7 +182,9 @@ def test_sort_values_without_by(self): self.checker.visit_import(import_node) self.checker.visit_call(sort_values_call) - def test_merge_with_missing_validate(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_merge_with_missing_validate(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -190,7 +206,9 @@ def test_merge_with_missing_validate(self): self.checker.visit_import(import_node) 
self.checker.visit_call(merge_call) - def test_merge_with_wrong_naming_and_missing_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_merge_with_wrong_naming_and_missing_params(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -209,7 +227,9 @@ def test_merge_with_wrong_naming_and_missing_params(self): self.checker.visit_import(import_node) self.checker.visit_call(merge_call) - def test_merge_with_all_params_and_correct_naming(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_merge_with_all_params_and_correct_naming(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_series_bool.py b/tests/checkers/test_pandas/test_pandas_series_bool.py index 8189fda..464b25e 100644 --- a/tests/checkers/test_pandas/test_pandas_series_bool.py +++ b/tests/checkers/test_pandas/test_pandas_series_bool.py @@ -8,7 +8,9 @@ class TestSeriesBoolChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasSeriesBoolChecker - def test_series_bool_usage(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_series_bool_usage(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -27,7 +29,9 @@ def test_series_bool_usage(self): self.checker.visit_import(import_node) self.checker.visit_call(node) - def test_no_bool_usage(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_no_bool_usage(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_series_naming.py b/tests/checkers/test_pandas/test_pandas_series_naming.py index 5560be5..e755005 100644 --- 
a/tests/checkers/test_pandas/test_pandas_series_naming.py +++ b/tests/checkers/test_pandas/test_pandas_series_naming.py @@ -8,7 +8,9 @@ class TestPandasSeriesNamingChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasSeriesNamingChecker - def test_series_correct_naming(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_series_correct_naming(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -19,7 +21,9 @@ def test_series_correct_naming(self): self.checker.visit_import(import_node) self.checker.visit_assign(node) - def test_series_incorrect_naming(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_series_incorrect_naming(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -37,7 +41,9 @@ def test_series_incorrect_naming(self): self.checker.visit_import(import_node) self.checker.visit_assign(node) - def test_series_invalid_length_naming(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_series_invalid_length_naming(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_scipy/test_scipy_parameter.py b/tests/checkers/test_scipy/test_scipy_parameter.py index bf95249..6d9da75 100644 --- a/tests/checkers/test_scipy/test_scipy_parameter.py +++ b/tests/checkers/test_scipy/test_scipy_parameter.py @@ -10,7 +10,6 @@ class TestScipyParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = ScipyParameterChecker - # TODO CONTINUE WITH MOCK FOR ALL TESTS @patch("pylint_ml.util.library_base_checker.version") def test_minimize_params(self, mock_version): mock_version.return_value = "1.7.0" diff --git a/tests/checkers/test_sklearn/test_sklearn_parameter.py b/tests/checkers/test_sklearn/test_sklearn_parameter.py index 9612965..4c80e08 
100644 --- a/tests/checkers/test_sklearn/test_sklearn_parameter.py +++ b/tests/checkers/test_sklearn/test_sklearn_parameter.py @@ -8,7 +8,9 @@ class TestSklearnParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = SklearnParameterChecker - def test_random_forest_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_random_forest_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ from sklearn.ensemble import RandomForestClassifier @@ -29,7 +31,9 @@ def test_random_forest_params(self): ): self.checker.visit_call(forest_call) - def test_random_forest_with_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_random_forest_with_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ from sklearn.ensemble import RandomForestClassifier @@ -42,7 +46,9 @@ def test_random_forest_with_params(self): with self.assertNoMessages(): self.checker.visit_call(forest_call) - def test_svc_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_svc_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ from sklearn.svm import SVC @@ -63,7 +69,9 @@ def test_svc_params(self): ): self.checker.visit_call(svc_call) - def test_svc_with_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_svc_with_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ from sklearn.svm import SVC @@ -76,7 +84,9 @@ def test_svc_with_params(self): with self.assertNoMessages(): self.checker.visit_call(svc_call) - def test_kmeans_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_kmeans_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ from sklearn.cluster import KMeans @@ -97,7 +107,9 @@ def test_kmeans_params(self): ): 
self.checker.visit_call(kmeans_call) - def test_kmeans_with_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_kmeans_with_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ from sklearn.cluster import KMeans diff --git a/tests/checkers/test_tensorflow/test_tensor_parameter.py b/tests/checkers/test_tensorflow/test_tensor_parameter.py index 48197dd..d2f11fd 100644 --- a/tests/checkers/test_tensorflow/test_tensor_parameter.py +++ b/tests/checkers/test_tensorflow/test_tensor_parameter.py @@ -8,7 +8,9 @@ class TestTensorParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = TensorFlowParameterChecker - def test_sequential_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_sequential_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -29,7 +31,9 @@ def test_sequential_params(self): ): self.checker.visit_call(sequential_call) - def test_sequential_with_layers(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_sequential_with_layers(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -45,7 +49,9 @@ def test_sequential_with_layers(self): with self.assertNoMessages(): self.checker.visit_call(sequential_call) - def test_compile_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_compile_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -65,7 +71,9 @@ def test_compile_params(self): ): self.checker.visit_call(node) - def test_compile_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_compile_with_all_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -79,7 +87,9 @@ def 
test_compile_with_all_params(self): with self.assertNoMessages(): self.checker.visit_call(compile_call) - def test_fit_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_fit_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -102,7 +112,9 @@ def test_fit_params(self): ): self.checker.visit_call(fit_call) - def test_fit_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_fit_with_all_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -117,7 +129,9 @@ def test_fit_with_all_params(self): with self.assertNoMessages(): self.checker.visit_call(fit_call) - def test_conv2d_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_conv2d_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -138,7 +152,9 @@ def test_conv2d_params(self): ): self.checker.visit_call(conv2d_call) - def test_conv2d_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_conv2d_with_all_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -151,7 +167,9 @@ def test_conv2d_with_all_params(self): with self.assertNoMessages(): self.checker.visit_call(conv2d_call) - def test_dense_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_dense_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ import tensorflow as tf @@ -172,7 +190,9 @@ def test_dense_params(self): ): self.checker.visit_call(dense_call) - def test_dense_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_dense_with_all_params(self, mock_version): + mock_version.return_value = "1.5.2" node = astroid.extract_node( """ 
import tensorflow as tf diff --git a/tests/checkers/test_torch/test_torch_parameter.py b/tests/checkers/test_torch/test_torch_parameter.py index 6c81205..fe2ffb4 100644 --- a/tests/checkers/test_torch/test_torch_parameter.py +++ b/tests/checkers/test_torch/test_torch_parameter.py @@ -8,7 +8,9 @@ class TestTorchParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PyTorchParameterChecker - def test_sgd_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_sgd_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.optim as optim @@ -29,7 +31,9 @@ def test_sgd_params(self): ): self.checker.visit_call(sgd_call) - def test_sgd_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_sgd_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.optim as optim @@ -42,7 +46,9 @@ def test_sgd_with_all_params(self): with self.assertNoMessages(): self.checker.visit_call(sgd_call) - def test_adam_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_adam_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.optim as optim @@ -63,7 +69,9 @@ def test_adam_params(self): ): self.checker.visit_call(adam_call) - def test_adam_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_adam_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.optim as optim @@ -76,7 +84,9 @@ def test_adam_with_all_params(self): with self.assertNoMessages(): self.checker.visit_call(adam_call) - def test_conv2d_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_conv2d_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.nn as nn @@ -97,7 +107,9 
@@ def test_conv2d_params(self): ): self.checker.visit_call(conv2d_call) - def test_conv2d_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_conv2d_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.nn as nn @@ -110,7 +122,9 @@ def test_conv2d_with_all_params(self): with self.assertNoMessages(): self.checker.visit_call(conv2d_call) - def test_linear_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_linear_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.nn as nn @@ -131,7 +145,9 @@ def test_linear_params(self): ): self.checker.visit_call(linear_call) - def test_linear_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_linear_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.nn as nn @@ -144,7 +160,9 @@ def test_linear_with_all_params(self): with self.assertNoMessages(): self.checker.visit_call(linear_call) - def test_lstm_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_lstm_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.nn as nn @@ -165,7 +183,9 @@ def test_lstm_params(self): ): self.checker.visit_call(lstm_call) - def test_lstm_with_all_params(self): + @patch("pylint_ml.util.library_base_checker.version") + def test_lstm_with_all_params(self, mock_version): + mock_version.return_value = "2.4.1" node = astroid.extract_node( """ import torch.nn as nn From 6b9e23511e31028622f7000043202d781237aa52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Sep 2024 13:26:18 +0000 Subject: [PATCH 05/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci 
--- pylint_ml/checkers/scipy/scipy_parameter.py | 2 +- pylint_ml/util/config.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index eb9c12f..00916c8 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import SCIPY, PANDAS_ALIAS +from pylint_ml.util.config import PANDAS_ALIAS, SCIPY from pylint_ml.util.library_base_checker import LibraryBaseChecker diff --git a/pylint_ml/util/config.py b/pylint_ml/util/config.py index 6b99e86..f145c23 100644 --- a/pylint_ml/util/config.py +++ b/pylint_ml/util/config.py @@ -14,4 +14,3 @@ PYTORCH = "torch" MATPLOTLIB = "matplotlib" - From 9ed9a0bfbcc7e90f77d1066314601b3c24901ad7 Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Mon, 23 Sep 2024 14:25:55 +0200 Subject: [PATCH 06/19] Update tests --- .../matplotlib/matplotlib_parameter.py | 3 +- pylint_ml/checkers/numpy/numpy_parameter.py | 13 +++++---- pylint_ml/checkers/pandas/pandas_parameter.py | 4 +-- pylint_ml/checkers/scipy/scipy_parameter.py | 4 +-- .../checkers/sklearn/sklearn_parameter.py | 4 +-- .../checkers/tensorflow/tensor_parameter.py | 4 +-- pylint_ml/checkers/torch/torch_parameter.py | 4 +-- pylint_ml/util/common.py | 29 +++++++------------ pylint_ml/util/config.py | 1 - pylint_ml/util/library_base_checker.py | 16 ++++------ .../pandas_dataframe_column_selection.py | 2 ++ .../test_pandas/test_pandas_dataframe_bool.py | 2 ++ .../test_pandas_dataframe_empty_column.py | 2 ++ .../test_pandas_dataframe_iterrows.py | 2 ++ .../test_pandas_dataframe_naming.py | 2 ++ .../test_pandas_dataframe_values.py | 2 ++ .../test_pandas/test_pandas_inplace.py | 17 +++++++---- .../test_pandas/test_pandas_parameter.py | 2 ++ .../test_pandas/test_pandas_series_bool.py | 2 ++ 
.../test_pandas/test_pandas_series_naming.py | 2 ++ .../test_sklearn/test_sklearn_parameter.py | 2 ++ .../test_tensorflow/test_tensor_parameter.py | 2 ++ .../test_torch/test_torch_parameter.py | 2 ++ 23 files changed, 70 insertions(+), 53 deletions(-) diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index 225d717..0e3dc8c 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -52,8 +52,7 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=MATPLOTLIB, required_version=None): return - # TODO UPDATE - method_name = get_full_method_name(lib_alias="", node=node) + method_name = get_full_method_name(node=node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index c9fda0b..b5c17a9 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import NUMPY, NUMPY_ALIAS +from pylint_ml.util.config import NUMPY from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -77,14 +77,17 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return - method_name = get_full_method_name(lib_alias=NUMPY_ALIAS, node=node) - if method_name in self.REQUIRED_PARAMS: + method_name = get_full_method_name(node=node) + extracted_method = method_name[len("np.") :] + if method_name.startswith("np.") and extracted_method in self.REQUIRED_PARAMS: 
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] + missing_params = [ + param for param in self.REQUIRED_PARAMS[extracted_method] if param not in provided_keywords + ] if missing_params: self.add_message( "numpy-parameter", node=node, confidence=HIGH, - args=(", ".join(missing_params), method_name), + args=(", ".join(missing_params), extracted_method), ) diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index b72994e..ccea76a 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -8,7 +8,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.common import get_method_name +from pylint_ml.util.common import get_full_method_name from pylint_ml.util.config import PANDAS from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -70,7 +70,7 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - method_name = get_method_name(node) + method_name = get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index eb9c12f..3d6fedb 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -9,7 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import SCIPY, PANDAS_ALIAS +from pylint_ml.util.config import SCIPY from 
pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -48,7 +48,7 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=SCIPY, required_version=None): return - method_name = get_full_method_name(lib_alias=PANDAS_ALIAS, node=node) + method_name = get_full_method_name(node=node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/sklearn/sklearn_parameter.py b/pylint_ml/checkers/sklearn/sklearn_parameter.py index 8824a4d..c5b567d 100644 --- a/pylint_ml/checkers/sklearn/sklearn_parameter.py +++ b/pylint_ml/checkers/sklearn/sklearn_parameter.py @@ -8,7 +8,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.common import get_method_name +from pylint_ml.util.common import get_full_method_name from pylint_ml.util.config import SKLEARN from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -43,7 +43,7 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=SKLEARN, required_version=None): return - method_name = get_method_name(node) + method_name = get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index 40704b3..2649dd4 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -8,7 +8,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from 
pylint_ml.util.common import get_method_name +from pylint_ml.util.common import get_full_method_name from pylint_ml.util.config import TENSORFLOW from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -41,7 +41,7 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=TENSORFLOW, required_version=None): return - method_name = get_method_name(node) + method_name = get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index b7d38e0..49888fe 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -8,7 +8,7 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.common import get_method_name +from pylint_ml.util.common import get_full_method_name from pylint_ml.util.config import PYTORCH from pylint_ml.util.library_base_checker import LibraryBaseChecker @@ -40,7 +40,7 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=PYTORCH, required_version=None): return - method_name = get_method_name(node) + method_name = get_full_method_name(node) if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/util/common.py b/pylint_ml/util/common.py index 846278f..b435a9f 100644 --- a/pylint_ml/util/common.py +++ b/pylint_ml/util/common.py @@ -1,31 +1,22 @@ from astroid import nodes -def get_method_name(node: nodes.Call) -> str: - """Extracts the method name from a 
Call node, including handling chained calls.""" - func = node.func - while isinstance(func, nodes.Attribute): - func = func.expr - return ( - node.func.attrname - if isinstance(node.func, nodes.Attribute) - else func.name if isinstance(func, nodes.Name) else "" - ) - - -def get_full_method_name(lib_alias: str, node: nodes.Call) -> str: +def get_full_method_name(node: nodes.Call) -> str: """ - Extracts the full method name, including chained attributes (e.g., np.random.rand). + Extracts the full method name from a Call node, including handling chained calls. """ func = node.func method_chain = [] - # Traverse the attribute chain + # Traverse the attribute chain to build the full method chain while isinstance(func, nodes.Attribute): method_chain.insert(0, func.attrname) func = func.expr - # Check if the root of the chain is "np" (as NumPy functions are expected to use np. prefix) - if isinstance(func, nodes.Name) and func.name == lib_alias: - return ".".join(method_chain) - return "" + # Check if the root of the chain is a Name node (like a module or base name) + if isinstance(func, nodes.Name): + method_chain.insert(0, func.name) # Add the base name + + print(method_chain) + # Join the method chain to create the full method name + return ".".join(method_chain) diff --git a/pylint_ml/util/config.py b/pylint_ml/util/config.py index 6b99e86..f145c23 100644 --- a/pylint_ml/util/config.py +++ b/pylint_ml/util/config.py @@ -14,4 +14,3 @@ PYTORCH = "torch" MATPLOTLIB = "matplotlib" - diff --git a/pylint_ml/util/library_base_checker.py b/pylint_ml/util/library_base_checker.py index 6c1e110..fef5145 100644 --- a/pylint_ml/util/library_base_checker.py +++ b/pylint_ml/util/library_base_checker.py @@ -11,27 +11,21 @@ def __init__(self, linter): def visit_import(self, node): for name, alias in node.names: - self.imports[alias or name] = name + self.imports[alias or name] = name # E.g. 
{'pd': 'pandas'} def visit_importfrom(self, node): - module = node.modname - print(module) + base_module = node.modname.split(".")[0] # Extract the first part of the module name for name, alias in node.names: - print(name) - print(alias) - print("-------------") - full_name = f"{module}.{name}" - self.imports[alias or name] = full_name - - print(self.imports) + full_name = f"{node.modname}.{name}" + self.imports[base_module] = full_name # E.g. {'scipy': 'scipy.optimize.minimize'} def is_library_imported_and_version_valid(self, lib_name, required_version): """ Checks if the library is imported and whether the installed version is valid (greater than or equal to the required version). - param lib_alias: Name of the library (as a string). + param lib_name: Name of the library (as a string). param required_version: The required minimum version (as a string). return: True if the library is imported and the version is valid, otherwise False. """ diff --git a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py b/tests/checkers/test_pandas/pandas_dataframe_column_selection.py index 15fc321..22e295b 100644 --- a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py +++ b/tests/checkers/test_pandas/pandas_dataframe_column_selection.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py index d47a3be..c1c3b0b 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py index c6c26ab..db2bb73 100644 --- 
a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py index 95be7b1..145931d 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py index 45bdf99..2c13d6f 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_values.py b/tests/checkers/test_pandas/test_pandas_dataframe_values.py index b12d355..2fbcebb 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_values.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_values.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_pandas/test_pandas_inplace.py b/tests/checkers/test_pandas/test_pandas_inplace.py index fee643f..54a0527 100644 --- a/tests/checkers/test_pandas/test_pandas_inplace.py +++ b/tests/checkers/test_pandas/test_pandas_inplace.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH @@ -9,7 +11,8 @@ class TestPandasInplaceChecker(pylint.testutils.CheckerTestCase): 
CHECKER_CLASS = PandasInplaceChecker @patch("pylint_ml.util.library_base_checker.version") - def test_inplace_used_in_drop(self): + def test_inplace_used_in_drop(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -32,7 +35,8 @@ def test_inplace_used_in_drop(self): self.checker.visit_call(node) @patch("pylint_ml.util.library_base_checker.version") - def test_inplace_used_in_fillna(self): + def test_inplace_used_in_fillna(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -55,7 +59,8 @@ def test_inplace_used_in_fillna(self): self.checker.visit_call(node) @patch("pylint_ml.util.library_base_checker.version") - def test_inplace_used_in_sort_values(self): + def test_inplace_used_in_sort_values(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -78,7 +83,8 @@ def test_inplace_used_in_sort_values(self): self.checker.visit_call(node) @patch("pylint_ml.util.library_base_checker.version") - def test_no_inplace(self): + def test_no_inplace(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ @@ -97,7 +103,8 @@ def test_no_inplace(self): self.checker.visit_call(inplace_call) @patch("pylint_ml.util.library_base_checker.version") - def test_inplace_used_in_unsupported_method(self): + def test_inplace_used_in_unsupported_method(self, mock_version): + mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( """ import pandas as pd #@ diff --git a/tests/checkers/test_pandas/test_pandas_parameter.py b/tests/checkers/test_pandas/test_pandas_parameter.py index 17ef98d..7f87e3b 100644 --- a/tests/checkers/test_pandas/test_pandas_parameter.py +++ b/tests/checkers/test_pandas/test_pandas_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import 
astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_pandas/test_pandas_series_bool.py b/tests/checkers/test_pandas/test_pandas_series_bool.py index 464b25e..5da72a6 100644 --- a/tests/checkers/test_pandas/test_pandas_series_bool.py +++ b/tests/checkers/test_pandas/test_pandas_series_bool.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_pandas/test_pandas_series_naming.py b/tests/checkers/test_pandas/test_pandas_series_naming.py index e755005..4510d7b 100644 --- a/tests/checkers/test_pandas/test_pandas_series_naming.py +++ b/tests/checkers/test_pandas/test_pandas_series_naming.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_sklearn/test_sklearn_parameter.py b/tests/checkers/test_sklearn/test_sklearn_parameter.py index 4c80e08..f7be58c 100644 --- a/tests/checkers/test_sklearn/test_sklearn_parameter.py +++ b/tests/checkers/test_sklearn/test_sklearn_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_tensorflow/test_tensor_parameter.py b/tests/checkers/test_tensorflow/test_tensor_parameter.py index d2f11fd..99c26db 100644 --- a/tests/checkers/test_tensorflow/test_tensor_parameter.py +++ b/tests/checkers/test_tensorflow/test_tensor_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid import pylint.testutils from pylint.interfaces import HIGH diff --git a/tests/checkers/test_torch/test_torch_parameter.py b/tests/checkers/test_torch/test_torch_parameter.py index fe2ffb4..5e8ffb1 100644 --- a/tests/checkers/test_torch/test_torch_parameter.py +++ b/tests/checkers/test_torch/test_torch_parameter.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import astroid 
import pylint.testutils from pylint.interfaces import HIGH From 0eb5477c6ab9608fa5e6c78f51daff7c174d53fe Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Mon, 23 Sep 2024 14:27:42 +0200 Subject: [PATCH 07/19] Update tests --- pylint_ml/checkers/scipy/scipy_parameter.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index 08fd5f9..3d6fedb 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -9,11 +9,7 @@ from pylint.interfaces import HIGH from pylint_ml.util.common import get_full_method_name -<<<<<<< HEAD from pylint_ml.util.config import SCIPY -======= -from pylint_ml.util.config import PANDAS_ALIAS, SCIPY ->>>>>>> 6b9e23511e31028622f7000043202d781237aa52 from pylint_ml.util.library_base_checker import LibraryBaseChecker From d0176dea49f73e676355cfb4e4959421fc34dc14 Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Mon, 23 Sep 2024 16:18:09 +0200 Subject: [PATCH 08/19] Add safe_infer --- pylint_ml/{util => checkers}/config.py | 0 .../library_base_checker.py | 0 .../matplotlib/matplotlib_parameter.py | 6 ++--- pylint_ml/checkers/numpy/numpy_dot.py | 4 +-- .../checkers/numpy/numpy_nan_comparison.py | 4 +-- pylint_ml/checkers/numpy/numpy_parameter.py | 18 +++++++++---- .../checkers/pandas/pandas_dataframe_bool.py | 4 +-- .../pandas_dataframe_column_selection.py | 4 +-- .../pandas/pandas_dataframe_empty_column.py | 4 +-- .../pandas/pandas_dataframe_iterrows.py | 4 +-- .../pandas/pandas_dataframe_naming.py | 4 +-- .../pandas/pandas_dataframe_values.py | 4 +-- pylint_ml/checkers/pandas/pandas_inplace.py | 4 +-- pylint_ml/checkers/pandas/pandas_parameter.py | 27 +++++++++++++------ .../checkers/pandas/pandas_series_bool.py | 4 +-- .../checkers/pandas/pandas_series_naming.py | 4 +-- pylint_ml/checkers/scipy/scipy_parameter.py | 17 +++++++++--- .../checkers/sklearn/sklearn_parameter.py | 6 ++--- 
.../checkers/tensorflow/tensor_parameter.py | 6 ++--- pylint_ml/checkers/torch/torch_parameter.py | 6 ++--- .../{util/common.py => checkers/utils.py} | 14 ++++++++++ ...test_pandas_dataframe_column_selection.py} | 0 22 files changed, 93 insertions(+), 51 deletions(-) rename pylint_ml/{util => checkers}/config.py (100%) rename pylint_ml/{util => checkers}/library_base_checker.py (100%) rename pylint_ml/{util/common.py => checkers/utils.py} (57%) rename tests/checkers/test_pandas/{pandas_dataframe_column_selection.py => test_pandas_dataframe_column_selection.py} (100%) diff --git a/pylint_ml/util/config.py b/pylint_ml/checkers/config.py similarity index 100% rename from pylint_ml/util/config.py rename to pylint_ml/checkers/config.py diff --git a/pylint_ml/util/library_base_checker.py b/pylint_ml/checkers/library_base_checker.py similarity index 100% rename from pylint_ml/util/library_base_checker.py rename to pylint_ml/checkers/library_base_checker.py diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index 0e3dc8c..6b655f0 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import MATPLOTLIB -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.config import MATPLOTLIB +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class MatplotlibParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/numpy/numpy_dot.py b/pylint_ml/checkers/numpy/numpy_dot.py index 048e105..e255f96 100644 --- a/pylint_ml/checkers/numpy/numpy_dot.py +++ b/pylint_ml/checkers/numpy/numpy_dot.py @@ -10,8 +10,8 @@ from 
pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import NUMPY -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import NUMPY +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class NumpyDotChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/numpy/numpy_nan_comparison.py b/pylint_ml/checkers/numpy/numpy_nan_comparison.py index e07eb25..0ef4254 100644 --- a/pylint_ml/checkers/numpy/numpy_nan_comparison.py +++ b/pylint_ml/checkers/numpy/numpy_nan_comparison.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import NUMPY, NUMPY_ALIAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import NUMPY, NUMPY_ALIAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "==")) NUMPY_NAN = frozenset(("nan", "NaN", "NAN")) diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index b5c17a9..f5f8d66 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -5,12 +5,12 @@ """Check for proper usage of numpy functions with required parameters.""" from astroid import nodes -from pylint.checkers.utils import only_required_for_messages +from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import NUMPY -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.config import NUMPY +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class NumPyParameterChecker(LibraryBaseChecker): 
@@ -78,7 +78,15 @@ def visit_call(self, node: nodes.Call) -> None: return method_name = get_full_method_name(node=node) - extracted_method = method_name[len("np.") :] + extracted_method = method_name[len("np."):] + + infer_node = safe_infer(node=node) + infer_object = safe_infer(node.func.expr) + print(node.func.expr) + print(infer_object) + print("------") + print(infer_node) + if method_name.startswith("np.") and extracted_method in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [ diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py index fc17bce..d1b0119 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasDataFrameBoolChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py index 1748d24..447fce9 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasColumnSelectionChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py 
b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py index fb37145..52e32dc 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasEmptyColumnChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py index b83dee3..0f6d424 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasIterrowsChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py index 67644f0..15af29e 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasDataFrameNamingChecker(LibraryBaseChecker): diff --git 
a/pylint_ml/checkers/pandas/pandas_dataframe_values.py b/pylint_ml/checkers/pandas/pandas_dataframe_values.py index d4c107a..1ea4133 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_values.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_values.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasValuesChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_inplace.py b/pylint_ml/checkers/pandas/pandas_inplace.py index 804b113..0d66689 100644 --- a/pylint_ml/checkers/pandas/pandas_inplace.py +++ b/pylint_ml/checkers/pandas/pandas_inplace.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasInplaceChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index ccea76a..a1d34fb 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -5,12 +5,12 @@ """Check for proper usage of Pandas functions with required parameters.""" from astroid import nodes -from pylint.checkers.utils import only_required_for_messages +from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils 
import get_full_method_name +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasParameterChecker(LibraryBaseChecker): @@ -70,14 +70,25 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - method_name = get_full_method_name(node) - if method_name in self.REQUIRED_PARAMS: + method_name = get_full_method_name(node=node) + extracted_method = method_name[len("pd."):] + + infer_node = safe_infer(node=node) + infer_object = safe_infer(node.func.expr) + print(node.func.expr) + print(infer_object) + print("------") + print(infer_node) + + if method_name.startswith("pd.") and extracted_method in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] + missing_params = [ + param for param in self.REQUIRED_PARAMS[extracted_method] if param not in provided_keywords + ] if missing_params: self.add_message( "pandas-parameter", node=node, confidence=HIGH, - args=(", ".join(missing_params), method_name), + args=(", ".join(missing_params), extracted_method), ) diff --git a/pylint_ml/checkers/pandas/pandas_series_bool.py b/pylint_ml/checkers/pandas/pandas_series_bool.py index 24747a8..aae0512 100644 --- a/pylint_ml/checkers/pandas/pandas_series_bool.py +++ b/pylint_ml/checkers/pandas/pandas_series_bool.py @@ -11,8 +11,8 @@ from pylint.interfaces import HIGH # Todo add version deprecated -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasSeriesBoolChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/pandas/pandas_series_naming.py 
b/pylint_ml/checkers/pandas/pandas_series_naming.py index 7fe1f32..4e30aa5 100644 --- a/pylint_ml/checkers/pandas/pandas_series_naming.py +++ b/pylint_ml/checkers/pandas/pandas_series_naming.py @@ -10,8 +10,8 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.config import PANDAS -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.config import PANDAS +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PandasSeriesNamingChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index 3d6fedb..a793b7b 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -5,12 +5,12 @@ """Check for proper usage of Scipy functions with required parameters.""" from astroid import nodes -from pylint.checkers.utils import only_required_for_messages +from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import SCIPY -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.config import SCIPY +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class ScipyParameterChecker(LibraryBaseChecker): @@ -49,6 +49,15 @@ def visit_call(self, node: nodes.Call) -> None: return method_name = get_full_method_name(node=node) + + infer_node = safe_infer(node=node) + print("------") + print(infer_node) + infer_object = safe_infer(node.func.expr) + print(node.func.expr) + print(infer_object) + + if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in 
provided_keywords] diff --git a/pylint_ml/checkers/sklearn/sklearn_parameter.py b/pylint_ml/checkers/sklearn/sklearn_parameter.py index c5b567d..924cc36 100644 --- a/pylint_ml/checkers/sklearn/sklearn_parameter.py +++ b/pylint_ml/checkers/sklearn/sklearn_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import SKLEARN -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.config import SKLEARN +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class SklearnParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index 2649dd4..d9cde68 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import TENSORFLOW -from pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.config import TENSORFLOW +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class TensorFlowParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index 49888fe..bfcf0af 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.util.common import get_full_method_name -from pylint_ml.util.config import PYTORCH -from 
pylint_ml.util.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.config import PYTORCH +from pylint_ml.checkers.library_base_checker import LibraryBaseChecker class PyTorchParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/util/common.py b/pylint_ml/checkers/utils.py similarity index 57% rename from pylint_ml/util/common.py rename to pylint_ml/checkers/utils.py index b435a9f..1086526 100644 --- a/pylint_ml/util/common.py +++ b/pylint_ml/checkers/utils.py @@ -20,3 +20,17 @@ def get_full_method_name(node: nodes.Call) -> str: print(method_chain) # Join the method chain to create the full method name return ".".join(method_chain) + + +def is_specific_library_object(node: nodes.NodeNG, library_name: str) -> bool: + """ + Returns True if the given node is an object from the specified library/module. + + Args: + node: The AST node to check. + library_name: The name of the library/module to check (e.g., 'pandas', 'numpy'). + + Returns: + bool: True if the node belongs to the specified library, False otherwise. 
+ """ + return node and node.root().name == library_name # Checks if the root module matches the library name diff --git a/tests/checkers/test_pandas/pandas_dataframe_column_selection.py b/tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py similarity index 100% rename from tests/checkers/test_pandas/pandas_dataframe_column_selection.py rename to tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py From 3ccef1b0c18d39b9af3be88ddc77a2e84d5c1f8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 14:18:25 +0000 Subject: [PATCH 09/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pylint_ml/checkers/matplotlib/matplotlib_parameter.py | 2 +- pylint_ml/checkers/numpy/numpy_parameter.py | 4 ++-- pylint_ml/checkers/pandas/pandas_parameter.py | 4 ++-- pylint_ml/checkers/scipy/scipy_parameter.py | 3 +-- pylint_ml/checkers/sklearn/sklearn_parameter.py | 2 +- pylint_ml/checkers/tensorflow/tensor_parameter.py | 2 +- pylint_ml/checkers/torch/torch_parameter.py | 2 +- 7 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index 6b655f0..e630dcf 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import MATPLOTLIB from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class MatplotlibParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index f5f8d66..c717824 100644 
--- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import NUMPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class NumPyParameterChecker(LibraryBaseChecker): @@ -78,7 +78,7 @@ def visit_call(self, node: nodes.Call) -> None: return method_name = get_full_method_name(node=node) - extracted_method = method_name[len("np."):] + extracted_method = method_name[len("np.") :] infer_node = safe_infer(node=node) infer_object = safe_infer(node.func.expr) diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index a1d34fb..dabf5cb 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class PandasParameterChecker(LibraryBaseChecker): @@ -71,7 +71,7 @@ def visit_call(self, node: nodes.Call) -> None: return method_name = get_full_method_name(node=node) - extracted_method = method_name[len("pd."):] + extracted_method = method_name[len("pd.") :] infer_node = safe_infer(node=node) infer_object = safe_infer(node.func.expr) diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index a793b7b..2d7010a 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -8,9 +8,9 @@ from 
pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import SCIPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class ScipyParameterChecker(LibraryBaseChecker): @@ -57,7 +57,6 @@ def visit_call(self, node: nodes.Call) -> None: print(node.func.expr) print(infer_object) - if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/sklearn/sklearn_parameter.py b/pylint_ml/checkers/sklearn/sklearn_parameter.py index 924cc36..b8f3a95 100644 --- a/pylint_ml/checkers/sklearn/sklearn_parameter.py +++ b/pylint_ml/checkers/sklearn/sklearn_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import SKLEARN from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class SklearnParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index d9cde68..e143354 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import TENSORFLOW from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class 
TensorFlowParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index bfcf0af..ac37021 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import PYTORCH from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class PyTorchParameterChecker(LibraryBaseChecker): From 8c17663207cd13eccb9eb16dc519947a7093d002 Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Tue, 24 Sep 2024 15:46:48 +0200 Subject: [PATCH 10/19] Add utilization of safe_infer --- .../matplotlib/matplotlib_parameter.py | 2 +- pylint_ml/checkers/numpy/numpy_dot.py | 18 +++-- .../checkers/numpy/numpy_nan_comparison.py | 7 +- pylint_ml/checkers/numpy/numpy_parameter.py | 4 +- pylint_ml/checkers/pandas/pandas_parameter.py | 4 +- pylint_ml/checkers/scipy/scipy_parameter.py | 3 +- .../checkers/sklearn/sklearn_parameter.py | 2 +- .../checkers/tensorflow/tensor_parameter.py | 2 +- pylint_ml/checkers/torch/torch_parameter.py | 2 +- pylint_ml/checkers/utils.py | 67 ++++++++++++++++++- tests/checkers/test_numpy/test_numpy_dot.py | 2 +- .../test_numpy/test_numpy_nan_comparison.py | 2 +- .../test_numpy/test_numpy_parameter.py | 10 +-- .../test_pandas/test_pandas_dataframe_bool.py | 4 +- .../test_pandas_dataframe_column_selection.py | 2 +- .../test_pandas_dataframe_empty_column.py | 6 +- .../test_pandas_dataframe_iterrows.py | 2 +- .../test_pandas_dataframe_naming.py | 6 +- .../test_pandas_dataframe_values.py | 2 +- .../test_pandas/test_pandas_inplace.py | 10 +-- .../test_pandas/test_pandas_parameter.py | 20 +++--- .../test_pandas/test_pandas_series_bool.py | 4 +- 
.../test_pandas/test_pandas_series_naming.py | 6 +- .../test_scipy/test_scipy_parameter.py | 12 ++-- .../test_sklearn/test_sklearn_parameter.py | 12 ++-- .../test_tensorflow/test_tensor_parameter.py | 20 +++--- .../test_torch/test_torch_parameter.py | 20 +++--- 27 files changed, 158 insertions(+), 93 deletions(-) diff --git a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py index 6b655f0..e630dcf 100644 --- a/pylint_ml/checkers/matplotlib/matplotlib_parameter.py +++ b/pylint_ml/checkers/matplotlib/matplotlib_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import MATPLOTLIB from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class MatplotlibParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/numpy/numpy_dot.py b/pylint_ml/checkers/numpy/numpy_dot.py index e255f96..3c905d1 100644 --- a/pylint_ml/checkers/numpy/numpy_dot.py +++ b/pylint_ml/checkers/numpy/numpy_dot.py @@ -12,6 +12,7 @@ from pylint_ml.checkers.config import NUMPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call class NumpyDotChecker(LibraryBaseChecker): @@ -25,19 +26,16 @@ class NumpyDotChecker(LibraryBaseChecker): ), } - def visit_import(self, node: nodes.Import): - super().visit_import(node=node) - @only_required_for_messages("numpy-dot-usage") def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return # Check if the function being called is np.dot - if isinstance(node.func, nodes.Attribute): - func_name = node.func.attrname - module_name = getattr(node.func.expr, "name", None) - - if func_name == "dot" and 
module_name == "np": - # Suggest using np.matmul() instead - self.add_message("numpy-dot-usage", node=node, confidence=HIGH) + if ( + isinstance(node.func, nodes.Attribute) + and node.func.attrname == "dot" + and infer_specific_module_from_call(node=node, module_name=NUMPY) + ): + # Suggest using np.matmul() instead + self.add_message("numpy-dot-usage", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/numpy/numpy_nan_comparison.py b/pylint_ml/checkers/numpy/numpy_nan_comparison.py index 0ef4254..0d6ee1f 100644 --- a/pylint_ml/checkers/numpy/numpy_nan_comparison.py +++ b/pylint_ml/checkers/numpy/numpy_nan_comparison.py @@ -10,8 +10,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.config import NUMPY, NUMPY_ALIAS +from pylint_ml.checkers.config import NUMPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_attribute COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "==")) NUMPY_NAN = frozenset(("nan", "NaN", "NAN")) @@ -30,17 +31,19 @@ class NumpyNaNComparisonChecker(LibraryBaseChecker): @classmethod def __is_np_nan_call(cls, node: nodes.Attribute) -> bool: """Check if the node represents a call to np.nan.""" - return node.attrname in NUMPY_NAN and isinstance(node.expr, nodes.Name) and node.expr.name == NUMPY_ALIAS + return node.attrname in NUMPY_NAN and (infer_specific_module_from_attribute(node=node, module_name="numpy")) @only_required_for_messages("numpy-nan-compare") def visit_compare(self, node: nodes.Compare) -> None: if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return + # Check node.left first for numpy nan usage if isinstance(node.left, nodes.Attribute) and self.__is_np_nan_call(node.left): self.add_message("numpy-nan-compare", node=node, confidence=HIGH) return + # Check remaining nodes and operators for numpy nan usage for op, 
comparator in node.ops: if op in COMPARISON_OP and isinstance(comparator, nodes.Attribute) and self.__is_np_nan_call(comparator): self.add_message("numpy-nan-compare", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index f5f8d66..c717824 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import NUMPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class NumPyParameterChecker(LibraryBaseChecker): @@ -78,7 +78,7 @@ def visit_call(self, node: nodes.Call) -> None: return method_name = get_full_method_name(node=node) - extracted_method = method_name[len("np."):] + extracted_method = method_name[len("np.") :] infer_node = safe_infer(node=node) infer_object = safe_infer(node.func.expr) diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index a1d34fb..dabf5cb 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class PandasParameterChecker(LibraryBaseChecker): @@ -71,7 +71,7 @@ def visit_call(self, node: nodes.Call) -> None: return method_name = get_full_method_name(node=node) - extracted_method = method_name[len("pd."):] + extracted_method = method_name[len("pd.") :] infer_node = 
safe_infer(node=node) infer_object = safe_infer(node.func.expr) diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index a793b7b..2d7010a 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import SCIPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class ScipyParameterChecker(LibraryBaseChecker): @@ -57,7 +57,6 @@ def visit_call(self, node: nodes.Call) -> None: print(node.func.expr) print(infer_object) - if method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] diff --git a/pylint_ml/checkers/sklearn/sklearn_parameter.py b/pylint_ml/checkers/sklearn/sklearn_parameter.py index 924cc36..b8f3a95 100644 --- a/pylint_ml/checkers/sklearn/sklearn_parameter.py +++ b/pylint_ml/checkers/sklearn/sklearn_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import SKLEARN from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class SklearnParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index d9cde68..e143354 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import 
only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import TENSORFLOW from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class TensorFlowParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index bfcf0af..ac37021 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.utils import get_full_method_name from pylint_ml.checkers.config import PYTORCH from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import get_full_method_name class PyTorchParameterChecker(LibraryBaseChecker): diff --git a/pylint_ml/checkers/utils.py b/pylint_ml/checkers/utils.py index 1086526..3c21819 100644 --- a/pylint_ml/checkers/utils.py +++ b/pylint_ml/checkers/utils.py @@ -1,4 +1,5 @@ from astroid import nodes +from pylint.checkers.utils import safe_infer def get_full_method_name(node: nodes.Call) -> str: @@ -17,7 +18,6 @@ def get_full_method_name(node: nodes.Call) -> str: if isinstance(func, nodes.Name): method_chain.insert(0, func.name) # Add the base name - print(method_chain) # Join the method chain to create the full method name return ".".join(method_chain) @@ -34,3 +34,68 @@ def is_specific_library_object(node: nodes.NodeNG, library_name: str) -> bool: bool: True if the node belongs to the specified library, False otherwise. 
""" return node and node.root().name == library_name # Checks if the root module matches the library name + + +def infer_module_from_node_chain(start_node: nodes.NodeNG, module_name: str) -> bool: + """ + Traverses the chain of attributes and checks if the root module of the node chain + matches the specified module name (e.g., 'numpy' or 'pandas'). + + Args: + start_node (nodes.NodeNG): The starting node (either Attribute or Call). + module_name (str): The module name to check against (e.g., 'numpy', 'pandas'). + + Returns: + bool: True if the root module matches the specified module_name, False otherwise. + """ + current_node = start_node + + # Traverse backward through the chain, handling Attribute and Name node types + while isinstance(current_node, (nodes.Attribute, nodes.Name)): + if isinstance(current_node, nodes.Attribute): + # Infer the current expression (e.g., np.some) + inferred_object = safe_infer(current_node.expr) + if inferred_object is None: + return False + current_node = current_node.expr # Step backwards + elif isinstance(current_node, nodes.Name): + # Base case: a Name node is likely a module or variable (e.g., 'np') + inferred_root = safe_infer(current_node) + if inferred_root: + # Check if the inferred object's name matches the module_name + if inferred_root.qname() == module_name: + return True + else: + return False + else: + return False # If inference of the Name node fails + + return False # Return False if we couldn't infer a valid module + + +def infer_specific_module_from_call(node: nodes.Call, module_name: str) -> bool: + """ + Infers if the function call belongs to the specified module (e.g., 'numpy', 'pandas'). + + Args: + node (nodes.Call): The Call node representing the method call. + module_name (str): The module name to check against (e.g., 'numpy', 'pandas'). + + Returns: + bool: True if the root module matches the specified module_name, False otherwise. 
+ """ + return infer_module_from_node_chain(node.func, module_name) + + +def infer_specific_module_from_attribute(node: nodes.Attribute, module_name: str) -> bool: + """ + Infers if the attribute access belongs to the specified module (e.g., 'numpy', 'pandas'). + + Args: + node (nodes.Attribute): The Attribute node representing the method or attribute access. + module_name (str): The module name to check against (e.g., 'numpy', 'pandas'). + + Returns: + bool: True if the root module matches the specified module_name, False otherwise. + """ + return infer_module_from_node_chain(node, module_name) diff --git a/tests/checkers/test_numpy/test_numpy_dot.py b/tests/checkers/test_numpy/test_numpy_dot.py index 6c7626b..ccda7f8 100644 --- a/tests/checkers/test_numpy/test_numpy_dot.py +++ b/tests/checkers/test_numpy/test_numpy_dot.py @@ -10,7 +10,7 @@ class TestNumpyDotChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumpyDotChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_warning_for_dot(self, mock_version): mock_version.return_value = "1.7.0" import_np, node = astroid.extract_node( diff --git a/tests/checkers/test_numpy/test_numpy_nan_comparison.py b/tests/checkers/test_numpy/test_numpy_nan_comparison.py index 00562af..58bfa4a 100644 --- a/tests/checkers/test_numpy/test_numpy_nan_comparison.py +++ b/tests/checkers/test_numpy/test_numpy_nan_comparison.py @@ -10,7 +10,7 @@ class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumpyNaNComparisonChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_singleton_nan_compare(self, mock_version): mock_version.return_value = "2.1.1" import_node, singleton_node, chained_node, great_than_node = astroid.extract_node( diff --git a/tests/checkers/test_numpy/test_numpy_parameter.py b/tests/checkers/test_numpy/test_numpy_parameter.py index 
3919c14..e5d251e 100644 --- a/tests/checkers/test_numpy/test_numpy_parameter.py +++ b/tests/checkers/test_numpy/test_numpy_parameter.py @@ -10,7 +10,7 @@ class TestNumPyParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = NumPyParameterChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_array_missing_object(self, mock_version): mock_version.return_value = "2.1.1" import_node, call_node = astroid.extract_node( @@ -34,7 +34,7 @@ def test_array_missing_object(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(call_node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_zeros_without_shape(self, mock_version): mock_version.return_value = "2.1.1" import_node, node = astroid.extract_node( @@ -58,7 +58,7 @@ def test_zeros_without_shape(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(zeros_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_random_rand_without_shape(self, mock_version): mock_version.return_value = "2.1.1" import_node, node = astroid.extract_node( @@ -82,7 +82,7 @@ def test_random_rand_without_shape(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(rand_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_dot_without_b(self, mock_version): mock_version.return_value = "2.1.1" import_node, node = astroid.extract_node( @@ -106,7 +106,7 @@ def test_dot_without_b(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(dot_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_percentile_without_q(self, mock_version): 
mock_version.return_value = "2.1.1" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py index c1c3b0b..0560173 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_bool.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_bool.py @@ -10,7 +10,7 @@ class TestDataFrameBoolChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasDataFrameBoolChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_dataframe_bool_usage(self, mock_version): mock_version.return_value = "2.2.2" import_node, call_node = astroid.extract_node( @@ -31,7 +31,7 @@ def test_dataframe_bool_usage(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(call_node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_no_bool_usage(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py b/tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py index 22e295b..e65126d 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_column_selection.py @@ -10,7 +10,7 @@ class TestPandasColumnSelectionChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasColumnSelectionChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_incorrect_column_selection(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py index 
db2bb73..4856fe1 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_empty_column.py @@ -10,7 +10,7 @@ class TestPandasEmptyColumnChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasEmptyColumnChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_correct_empty_column_initialization(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -24,7 +24,7 @@ def test_correct_empty_column_initialization(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_subscript(node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_incorrect_empty_column_initialization_with_zero(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -48,7 +48,7 @@ def test_incorrect_empty_column_initialization_with_zero(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_subscript(subscript_node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_incorrect_empty_column_initialization_with_empty_string(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py index 145931d..5464624 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_iterrows.py @@ -10,7 +10,7 @@ class TestPandasIterrowsChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasIterrowsChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def 
test_iterrows_used(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py index 2c13d6f..3ab77ce 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_naming.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_naming.py @@ -10,7 +10,7 @@ class TestPandasDataFrameNamingChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasDataFrameNamingChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_correct_dataframe_naming(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -23,7 +23,7 @@ def test_correct_dataframe_naming(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_assign(node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_incorrect_dataframe_naming(self, mock_version): mock_version.return_value = "2.2.2" import_node, pandas_dataframe_node = astroid.extract_node( @@ -43,7 +43,7 @@ def test_incorrect_dataframe_naming(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_assign(pandas_dataframe_node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_incorrect_dataframe_name_length(self, mock_version): mock_version.return_value = "2.2.2" import_node, pandas_dataframe_node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_dataframe_values.py b/tests/checkers/test_pandas/test_pandas_dataframe_values.py index 2fbcebb..5373c87 100644 --- a/tests/checkers/test_pandas/test_pandas_dataframe_values.py +++ b/tests/checkers/test_pandas/test_pandas_dataframe_values.py @@ -10,7 +10,7 @@ class 
TestPandasValuesChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasValuesChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_values_usage_with_correct_naming(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_inplace.py b/tests/checkers/test_pandas/test_pandas_inplace.py index 54a0527..0a05cfc 100644 --- a/tests/checkers/test_pandas/test_pandas_inplace.py +++ b/tests/checkers/test_pandas/test_pandas_inplace.py @@ -10,7 +10,7 @@ class TestPandasInplaceChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasInplaceChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_inplace_used_in_drop(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -34,7 +34,7 @@ def test_inplace_used_in_drop(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_inplace_used_in_fillna(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -58,7 +58,7 @@ def test_inplace_used_in_fillna(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_inplace_used_in_sort_values(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -82,7 +82,7 @@ def test_inplace_used_in_sort_values(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(node) - @patch("pylint_ml.util.library_base_checker.version") + 
@patch("pylint_ml.checkers.library_base_checker.version") def test_no_inplace(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -102,7 +102,7 @@ def test_no_inplace(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(inplace_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_inplace_used_in_unsupported_method(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_parameter.py b/tests/checkers/test_pandas/test_pandas_parameter.py index 7f87e3b..7399b1e 100644 --- a/tests/checkers/test_pandas/test_pandas_parameter.py +++ b/tests/checkers/test_pandas/test_pandas_parameter.py @@ -10,7 +10,7 @@ class TestPandasParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasParameterChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_dataframe_missing_data(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -34,7 +34,7 @@ def test_dataframe_missing_data(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(dataframe_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_merge_without_required_params(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -60,7 +60,7 @@ def test_merge_without_required_params(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(merge_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_read_csv_without_filepath(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = 
astroid.extract_node( @@ -84,7 +84,7 @@ def test_read_csv_without_filepath(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(read_csv_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_to_csv_without_path(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -109,7 +109,7 @@ def test_to_csv_without_path(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(to_csv_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_groupby_without_by(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -134,7 +134,7 @@ def test_groupby_without_by(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(groupby_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_fillna_without_value(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -159,7 +159,7 @@ def test_fillna_without_value(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(fillna_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_sort_values_without_by(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -184,7 +184,7 @@ def test_sort_values_without_by(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(sort_values_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_merge_with_missing_validate(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( 
@@ -208,7 +208,7 @@ def test_merge_with_missing_validate(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(merge_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_merge_with_wrong_naming_and_missing_params(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -229,7 +229,7 @@ def test_merge_with_wrong_naming_and_missing_params(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(merge_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_merge_with_all_params_and_correct_naming(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_series_bool.py b/tests/checkers/test_pandas/test_pandas_series_bool.py index 5da72a6..7efdc64 100644 --- a/tests/checkers/test_pandas/test_pandas_series_bool.py +++ b/tests/checkers/test_pandas/test_pandas_series_bool.py @@ -10,7 +10,7 @@ class TestSeriesBoolChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasSeriesBoolChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_series_bool_usage(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -31,7 +31,7 @@ def test_series_bool_usage(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_call(node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_no_bool_usage(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_pandas/test_pandas_series_naming.py b/tests/checkers/test_pandas/test_pandas_series_naming.py index 
4510d7b..58c9cbc 100644 --- a/tests/checkers/test_pandas/test_pandas_series_naming.py +++ b/tests/checkers/test_pandas/test_pandas_series_naming.py @@ -10,7 +10,7 @@ class TestPandasSeriesNamingChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PandasSeriesNamingChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_series_correct_naming(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -23,7 +23,7 @@ def test_series_correct_naming(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_assign(node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_series_incorrect_naming(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( @@ -43,7 +43,7 @@ def test_series_incorrect_naming(self, mock_version): self.checker.visit_import(import_node) self.checker.visit_assign(node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_series_invalid_length_naming(self, mock_version): mock_version.return_value = "2.2.2" import_node, node = astroid.extract_node( diff --git a/tests/checkers/test_scipy/test_scipy_parameter.py b/tests/checkers/test_scipy/test_scipy_parameter.py index 6d9da75..45bbde7 100644 --- a/tests/checkers/test_scipy/test_scipy_parameter.py +++ b/tests/checkers/test_scipy/test_scipy_parameter.py @@ -10,7 +10,7 @@ class TestScipyParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = ScipyParameterChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_minimize_params(self, mock_version): mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( @@ -33,7 +33,7 @@ def test_minimize_params(self, mock_version): 
self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(minimize_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_curve_fit_params(self, mock_version): mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( @@ -56,7 +56,7 @@ def test_curve_fit_params(self, mock_version): self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(curve_fit_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_quad_params(self, mock_version): mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( @@ -79,7 +79,7 @@ def test_quad_params(self, mock_version): self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(quad_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_solve_ivp_params(self, mock_version): mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( @@ -102,7 +102,7 @@ def test_solve_ivp_params(self, mock_version): self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(solve_ivp_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_ttest_ind_params(self, mock_version): mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( @@ -125,7 +125,7 @@ def test_ttest_ind_params(self, mock_version): self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(ttest_ind_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_euclidean_params(self, mock_version): mock_version.return_value = "1.7.0" importfrom_node, node = astroid.extract_node( diff --git a/tests/checkers/test_sklearn/test_sklearn_parameter.py 
b/tests/checkers/test_sklearn/test_sklearn_parameter.py index f7be58c..9a67075 100644 --- a/tests/checkers/test_sklearn/test_sklearn_parameter.py +++ b/tests/checkers/test_sklearn/test_sklearn_parameter.py @@ -10,7 +10,7 @@ class TestSklearnParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = SklearnParameterChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_random_forest_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -33,7 +33,7 @@ def test_random_forest_params(self, mock_version): ): self.checker.visit_call(forest_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_random_forest_with_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -48,7 +48,7 @@ def test_random_forest_with_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(forest_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_svc_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -71,7 +71,7 @@ def test_svc_params(self, mock_version): ): self.checker.visit_call(svc_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_svc_with_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -86,7 +86,7 @@ def test_svc_with_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(svc_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_kmeans_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -109,7 +109,7 @@ def test_kmeans_params(self, 
mock_version): ): self.checker.visit_call(kmeans_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_kmeans_with_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( diff --git a/tests/checkers/test_tensorflow/test_tensor_parameter.py b/tests/checkers/test_tensorflow/test_tensor_parameter.py index 99c26db..a27aa08 100644 --- a/tests/checkers/test_tensorflow/test_tensor_parameter.py +++ b/tests/checkers/test_tensorflow/test_tensor_parameter.py @@ -10,7 +10,7 @@ class TestTensorParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = TensorFlowParameterChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_sequential_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -33,7 +33,7 @@ def test_sequential_params(self, mock_version): ): self.checker.visit_call(sequential_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_sequential_with_layers(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -51,7 +51,7 @@ def test_sequential_with_layers(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(sequential_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_compile_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -73,7 +73,7 @@ def test_compile_params(self, mock_version): ): self.checker.visit_call(node) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_compile_with_all_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -89,7 +89,7 @@ def 
test_compile_with_all_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(compile_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_fit_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -114,7 +114,7 @@ def test_fit_params(self, mock_version): ): self.checker.visit_call(fit_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_fit_with_all_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -131,7 +131,7 @@ def test_fit_with_all_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(fit_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_conv2d_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -154,7 +154,7 @@ def test_conv2d_params(self, mock_version): ): self.checker.visit_call(conv2d_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_conv2d_with_all_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -169,7 +169,7 @@ def test_conv2d_with_all_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(conv2d_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_dense_params(self, mock_version): mock_version.return_value = "1.5.2" node = astroid.extract_node( @@ -192,7 +192,7 @@ def test_dense_params(self, mock_version): ): self.checker.visit_call(dense_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_dense_with_all_params(self, mock_version): 
mock_version.return_value = "1.5.2" node = astroid.extract_node( diff --git a/tests/checkers/test_torch/test_torch_parameter.py b/tests/checkers/test_torch/test_torch_parameter.py index 5e8ffb1..4337d7f 100644 --- a/tests/checkers/test_torch/test_torch_parameter.py +++ b/tests/checkers/test_torch/test_torch_parameter.py @@ -10,7 +10,7 @@ class TestTorchParameterChecker(pylint.testutils.CheckerTestCase): CHECKER_CLASS = PyTorchParameterChecker - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_sgd_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -33,7 +33,7 @@ def test_sgd_params(self, mock_version): ): self.checker.visit_call(sgd_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_sgd_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -48,7 +48,7 @@ def test_sgd_with_all_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(sgd_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_adam_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -71,7 +71,7 @@ def test_adam_params(self, mock_version): ): self.checker.visit_call(adam_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_adam_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -86,7 +86,7 @@ def test_adam_with_all_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(adam_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_conv2d_params(self, mock_version): 
mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -109,7 +109,7 @@ def test_conv2d_params(self, mock_version): ): self.checker.visit_call(conv2d_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_conv2d_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -124,7 +124,7 @@ def test_conv2d_with_all_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(conv2d_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_linear_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -147,7 +147,7 @@ def test_linear_params(self, mock_version): ): self.checker.visit_call(linear_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_linear_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -162,7 +162,7 @@ def test_linear_with_all_params(self, mock_version): with self.assertNoMessages(): self.checker.visit_call(linear_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_lstm_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( @@ -185,7 +185,7 @@ def test_lstm_params(self, mock_version): ): self.checker.visit_call(lstm_call) - @patch("pylint_ml.util.library_base_checker.version") + @patch("pylint_ml.checkers.library_base_checker.version") def test_lstm_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" node = astroid.extract_node( From 8674ab9627ea397ca053d81f64838f87fbe56ef6 Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Wed, 25 Sep 2024 10:26:33 +0200 Subject: [PATCH 11/19] Finalize update of numpy checkers --- 
pylint_ml/checkers/numpy/numpy_parameter.py | 40 ++++++++----------- .../test_numpy/test_numpy_parameter.py | 2 +- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index c717824..5dddc75 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -10,7 +10,7 @@ from pylint_ml.checkers.config import NUMPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.utils import get_full_method_name, infer_specific_module_from_call class NumPyParameterChecker(LibraryBaseChecker): @@ -36,12 +36,12 @@ class NumPyParameterChecker(LibraryBaseChecker): "eye": ["N"], "identity": ["n"], # Random Sampling - "random.rand": ["d0"], - "random.randn": ["d0"], - "random.randint": ["low", "high"], - "random.choice": ["a"], - "random.uniform": ["low", "high"], - "random.normal": ["loc", "scale"], + "rand": ["d0"], + "randn": ["d0"], + "randint": ["low", "high"], + "choice": ["a"], + "uniform": ["low", "high"], + "normal": ["loc", "scale"], # Mathematical Functions "sum": ["a"], "mean": ["a"], @@ -62,9 +62,9 @@ class NumPyParameterChecker(LibraryBaseChecker): # Linear Algebra "dot": ["a", "b"], "matmul": ["a", "b"], - "linalg.inv": ["a"], - "linalg.eig": ["a"], - "linalg.solve": ["a", "b"], + "inv": ["a"], + "eig": ["a"], + "solve": ["a", "b"], # Statistical Functions "percentile": ["a", "q"], "quantile": ["a", "q"], @@ -77,25 +77,19 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return - method_name = get_full_method_name(node=node) - extracted_method = method_name[len("np.") :] - - infer_node = safe_infer(node=node) - infer_object = safe_infer(node.func.expr) - print(node.func.expr) - print(infer_object) - print("------") - print(infer_node) - 
- if method_name.startswith("np.") and extracted_method in self.REQUIRED_PARAMS: + if ( + infer_specific_module_from_call(node=node, module_name=NUMPY) + and isinstance(node.func, nodes.Attribute) + and node.func.attrname in self.REQUIRED_PARAMS + ): provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [ - param for param in self.REQUIRED_PARAMS[extracted_method] if param not in provided_keywords + param for param in self.REQUIRED_PARAMS[node.func.attrname] if param not in provided_keywords ] if missing_params: self.add_message( "numpy-parameter", node=node, confidence=HIGH, - args=(", ".join(missing_params), extracted_method), + args=(", ".join(missing_params), node.func.attrname), ) diff --git a/tests/checkers/test_numpy/test_numpy_parameter.py b/tests/checkers/test_numpy/test_numpy_parameter.py index e5d251e..969b573 100644 --- a/tests/checkers/test_numpy/test_numpy_parameter.py +++ b/tests/checkers/test_numpy/test_numpy_parameter.py @@ -75,7 +75,7 @@ def test_random_rand_without_shape(self, mock_version): msg_id="numpy-parameter", confidence=HIGH, node=rand_call, - args=("d0", "random.rand"), + args=("d0", "rand"), ), ignore_position=True, ): From 6b9ee0fb66110701f298b8fa6465189e15f1b335 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 08:26:43 +0000 Subject: [PATCH 12/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pylint_ml/checkers/numpy/numpy_parameter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index 5dddc75..2749cd1 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -5,12 +5,12 @@ """Check for proper usage of numpy functions with required parameters.""" from astroid import nodes -from 
pylint.checkers.utils import only_required_for_messages, safe_infer +from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH from pylint_ml.checkers.config import NUMPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -from pylint_ml.checkers.utils import get_full_method_name, infer_specific_module_from_call +from pylint_ml.checkers.utils import infer_specific_module_from_call class NumPyParameterChecker(LibraryBaseChecker): @@ -78,9 +78,9 @@ def visit_call(self, node: nodes.Call) -> None: return if ( - infer_specific_module_from_call(node=node, module_name=NUMPY) - and isinstance(node.func, nodes.Attribute) - and node.func.attrname in self.REQUIRED_PARAMS + infer_specific_module_from_call(node=node, module_name=NUMPY) + and isinstance(node.func, nodes.Attribute) + and node.func.attrname in self.REQUIRED_PARAMS ): provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [ From 800aa10aa6d2e5c3f3947d22514b05f73eaeed9d Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Wed, 25 Sep 2024 12:12:36 +0200 Subject: [PATCH 13/19] Utilize safe_infer for detecting pandas --- pylint_ml/checkers/numpy/numpy_parameter.py | 8 ++++---- pylint_ml/checkers/pandas/pandas_dataframe_bool.py | 10 +++++++--- .../pandas/pandas_dataframe_column_selection.py | 7 ++++++- pylint_ml/checkers/utils.py | 4 +++- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index 5dddc75..ff08818 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -77,19 +77,19 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=NUMPY, required_version=None): return + method_name = getattr(node.func, "attrname", None) if ( infer_specific_module_from_call(node=node, module_name=NUMPY) - and isinstance(node.func, 
nodes.Attribute) - and node.func.attrname in self.REQUIRED_PARAMS + and method_name in self.REQUIRED_PARAMS ): provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [ - param for param in self.REQUIRED_PARAMS[node.func.attrname] if param not in provided_keywords + param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords ] if missing_params: self.add_message( "numpy-parameter", node=node, confidence=HIGH, - args=(", ".join(missing_params), node.func.attrname), + args=(", ".join(missing_params), method_name), ) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py index d1b0119..64ffc2e 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py @@ -7,11 +7,12 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers.utils import only_required_for_messages +from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call class PandasDataFrameBoolChecker(LibraryBaseChecker): @@ -31,11 +32,14 @@ def visit_call(self, node: nodes.Call) -> None: if isinstance(node.func, nodes.Attribute): method_name = getattr(node.func, "attrname", None) - if method_name == "bool": # Check if the object calling .bool() has a name starting with 'df_' object_name = getattr(node.func.expr, "name", None) - if object_name and self._is_valid_dataframe_name(object_name): + if ( + infer_specific_module_from_call(node=node, module_name=PANDAS) + and object_name + and self._is_valid_dataframe_name(object_name) + ): self.add_message("pandas-dataframe-bool", node=node, confidence=HIGH) @staticmethod diff --git 
a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py index 447fce9..a65cf77 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py @@ -12,6 +12,7 @@ from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_attribute class PandasColumnSelectionChecker(LibraryBaseChecker): @@ -31,6 +32,10 @@ def visit_attribute(self, node: nodes.Attribute) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - if isinstance(node.expr, nodes.Name) and node.expr.name.startswith("df_"): + if ( + infer_specific_module_from_attribute(node=node, module_name=PANDAS) + and isinstance(node.expr, nodes.Name) + and node.expr.name.startswith("df_") + ): # Issue a warning for property-like access self.add_message("pandas-column-selection", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/utils.py b/pylint_ml/checkers/utils.py index 3c21819..041485f 100644 --- a/pylint_ml/checkers/utils.py +++ b/pylint_ml/checkers/utils.py @@ -61,9 +61,11 @@ def infer_module_from_node_chain(start_node: nodes.NodeNG, module_name: str) -> elif isinstance(current_node, nodes.Name): # Base case: a Name node is likely a module or variable (e.g., 'np') inferred_root = safe_infer(current_node) + print(inferred_root) if inferred_root: # Check if the inferred object's name matches the module_name - if inferred_root.qname() == module_name: + # TODO update solution to handle MODULE and INSTANCE + if module_name in inferred_root.qname() or inferred_root.qname() == module_name: return True else: return False From 2541932c606f60a6cf559710225e5140efa10f5c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 
2024 10:13:42 +0000 Subject: [PATCH 14/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pylint_ml/checkers/pandas/pandas_dataframe_bool.py | 8 ++++---- .../checkers/pandas/pandas_dataframe_column_selection.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py index 64ffc2e..693385d 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_bool.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_bool.py @@ -7,7 +7,7 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers.utils import only_required_for_messages, safe_infer +from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH from pylint_ml.checkers.config import PANDAS @@ -36,9 +36,9 @@ def visit_call(self, node: nodes.Call) -> None: # Check if the object calling .bool() has a name starting with 'df_' object_name = getattr(node.func.expr, "name", None) if ( - infer_specific_module_from_call(node=node, module_name=PANDAS) - and object_name - and self._is_valid_dataframe_name(object_name) + infer_specific_module_from_call(node=node, module_name=PANDAS) + and object_name + and self._is_valid_dataframe_name(object_name) ): self.add_message("pandas-dataframe-bool", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py index a65cf77..8a92de9 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_column_selection.py @@ -33,9 +33,9 @@ def visit_attribute(self, node: nodes.Attribute) -> None: return if ( - infer_specific_module_from_attribute(node=node, module_name=PANDAS) - and isinstance(node.expr, nodes.Name) - and node.expr.name.startswith("df_") + 
infer_specific_module_from_attribute(node=node, module_name=PANDAS) + and isinstance(node.expr, nodes.Name) + and node.expr.name.startswith("df_") ): # Issue a warning for property-like access self.add_message("pandas-column-selection", node=node, confidence=HIGH) From 7f3d0609f2742e82d57aeb8c25b8e4f4d18fc6d0 Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Thu, 26 Sep 2024 15:08:45 +0200 Subject: [PATCH 15/19] Update checker using safe_infer --- .../checkers/pandas/pandas_dataframe_empty_column.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py index 52e32dc..e34b8d3 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py @@ -7,7 +7,7 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers.utils import only_required_for_messages +from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH from pylint_ml.checkers.config import PANDAS @@ -30,7 +30,12 @@ def visit_subscript(self, node: nodes.Subscript) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - if isinstance(node.value, nodes.Name) and node.value.name.startswith("df_"): + if ( + isinstance(node.value, nodes.Name) + and node.value.name.startswith("df_") + and PANDAS in safe_infer(node.value).qname() + ): + print(node.value.name) if isinstance(node.slice, nodes.Const) and isinstance(node.parent, nodes.Assign): if isinstance(node.parent.value, nodes.Const): # Checking for filler values: 0 or empty string From 506251e25b95166691d3d00fac6b581b183f4f41 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:10:39 +0000 Subject: [PATCH 16/19] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py index e34b8d3..b8595f5 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_empty_column.py @@ -31,9 +31,9 @@ def visit_subscript(self, node: nodes.Subscript) -> None: return if ( - isinstance(node.value, nodes.Name) - and node.value.name.startswith("df_") - and PANDAS in safe_infer(node.value).qname() + isinstance(node.value, nodes.Name) + and node.value.name.startswith("df_") + and PANDAS in safe_infer(node.value).qname() ): print(node.value.name) if isinstance(node.slice, nodes.Const) and isinstance(node.parent, nodes.Assign): From ea055be9dfd52a00b1164b1410732879ca208a8a Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Thu, 26 Sep 2024 15:26:52 +0200 Subject: [PATCH 17/19] Update checker using safe_infer --- pylint_ml/checkers/numpy/numpy_nan_comparison.py | 2 +- pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pylint_ml/checkers/numpy/numpy_nan_comparison.py b/pylint_ml/checkers/numpy/numpy_nan_comparison.py index 0d6ee1f..7eedd4b 100644 --- a/pylint_ml/checkers/numpy/numpy_nan_comparison.py +++ b/pylint_ml/checkers/numpy/numpy_nan_comparison.py @@ -31,7 +31,7 @@ class NumpyNaNComparisonChecker(LibraryBaseChecker): @classmethod def __is_np_nan_call(cls, node: nodes.Attribute) -> bool: """Check if the node represents a call to np.nan.""" - return node.attrname in NUMPY_NAN and (infer_specific_module_from_attribute(node=node, module_name="numpy")) + return node.attrname in NUMPY_NAN and (infer_specific_module_from_attribute(node=node, module_name=NUMPY)) 
@only_required_for_messages("numpy-nan-compare") def visit_compare(self, node: nodes.Compare) -> None: diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py index 0f6d424..bb6234d 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py @@ -12,6 +12,7 @@ from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_attribute class PandasIterrowsChecker(LibraryBaseChecker): @@ -30,7 +31,10 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - if isinstance(node.func, nodes.Attribute): + if ( + isinstance(node.func, nodes.Attribute) + and infer_specific_module_from_attribute(node=node.func, module_name=PANDAS) + ): method_name = getattr(node.func, "attrname", None) if method_name == "iterrows": object_name = getattr(node.func.expr, "name", None) From 4f32ce6bd9945589cb331f2eedf5704008a64b99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:27:06 +0000 Subject: [PATCH 18/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py index bb6234d..e9871bc 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_iterrows.py @@ -31,9 +31,8 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - if ( - 
isinstance(node.func, nodes.Attribute) - and infer_specific_module_from_attribute(node=node.func, module_name=PANDAS) + if isinstance(node.func, nodes.Attribute) and infer_specific_module_from_attribute( + node=node.func, module_name=PANDAS ): method_name = getattr(node.func, "attrname", None) if method_name == "iterrows": From ed076c7c4a1dd4666592f965fa89616a098b1c96 Mon Sep 17 00:00:00 2001 From: Peter Hamfelt Date: Sat, 28 Sep 2024 00:04:06 +0200 Subject: [PATCH 19/19] Update checkers utilizing lookup and safe_infer --- pylint_ml/checkers/config.py | 2 +- pylint_ml/checkers/numpy/numpy_parameter.py | 15 +--- .../pandas/pandas_dataframe_naming.py | 7 +- .../pandas/pandas_dataframe_values.py | 8 +- pylint_ml/checkers/pandas/pandas_inplace.py | 5 +- pylint_ml/checkers/pandas/pandas_parameter.py | 23 ++---- .../checkers/pandas/pandas_series_bool.py | 3 +- .../checkers/pandas/pandas_series_naming.py | 3 +- pylint_ml/checkers/scipy/scipy_parameter.py | 59 +++++++++------ .../checkers/tensorflow/tensor_parameter.py | 10 ++- pylint_ml/checkers/torch/torch_parameter.py | 12 +-- pylint_ml/checkers/utils.py | 9 ++- .../test_pandas/test_pandas_parameter.py | 3 + .../test_sklearn/test_sklearn_parameter.py | 42 ++++++----- .../test_tensorflow/test_tensor_parameter.py | 73 ++++++++++--------- .../test_torch/test_torch_parameter.py | 49 ++++++++----- 16 files changed, 181 insertions(+), 142 deletions(-) diff --git a/pylint_ml/checkers/config.py b/pylint_ml/checkers/config.py index f145c23..4ea6082 100644 --- a/pylint_ml/checkers/config.py +++ b/pylint_ml/checkers/config.py @@ -11,6 +11,6 @@ SKLEARN = "sklearn" -PYTORCH = "torch" +TORCH = "torch" MATPLOTLIB = "matplotlib" diff --git a/pylint_ml/checkers/numpy/numpy_parameter.py b/pylint_ml/checkers/numpy/numpy_parameter.py index 955cc4b..180fa84 100644 --- a/pylint_ml/checkers/numpy/numpy_parameter.py +++ b/pylint_ml/checkers/numpy/numpy_parameter.py @@ -78,20 +78,9 @@ def visit_call(self, node: nodes.Call) -> None: return 
method_name = getattr(node.func, "attrname", None) - if ( -<<<<<<< HEAD - infer_specific_module_from_call(node=node, module_name=NUMPY) - and method_name in self.REQUIRED_PARAMS -======= - infer_specific_module_from_call(node=node, module_name=NUMPY) - and isinstance(node.func, nodes.Attribute) - and node.func.attrname in self.REQUIRED_PARAMS ->>>>>>> 6b9ee0fb66110701f298b8fa6465189e15f1b335 - ): + if infer_specific_module_from_call(node=node, module_name=NUMPY) and method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - missing_params = [ - param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords - ] + missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( "numpy-parameter", diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py index 15af29e..922e1a4 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_naming.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_naming.py @@ -12,6 +12,7 @@ from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call class PandasDataFrameNamingChecker(LibraryBaseChecker): @@ -33,7 +34,11 @@ def visit_assign(self, node: nodes.Assign) -> None: func_name = getattr(node.value.func, "attrname", None) module_name = getattr(node.value.func.expr, "name", None) - if func_name == "DataFrame" and module_name == "pd": + if ( + func_name == "DataFrame" + and module_name == "pd" + and infer_specific_module_from_call(node=node.value, module_name=PANDAS) + ): for target in node.targets: if isinstance(target, nodes.AssignName): var_name = target.name diff --git a/pylint_ml/checkers/pandas/pandas_dataframe_values.py b/pylint_ml/checkers/pandas/pandas_dataframe_values.py index 
1ea4133..5651131 100644 --- a/pylint_ml/checkers/pandas/pandas_dataframe_values.py +++ b/pylint_ml/checkers/pandas/pandas_dataframe_values.py @@ -7,7 +7,7 @@ from __future__ import annotations from astroid import nodes -from pylint.checkers.utils import only_required_for_messages +from pylint.checkers.utils import only_required_for_messages, safe_infer from pylint.interfaces import HIGH from pylint_ml.checkers.config import PANDAS @@ -31,5 +31,9 @@ def visit_attribute(self, node: nodes.Attribute) -> None: return if isinstance(node.expr, nodes.Name): - if node.attrname == "values" and node.expr.name.startswith("df_"): + if ( + node.attrname == "values" + and node.expr.name.startswith("df_") + and PANDAS in safe_infer(node.expr).qname() + ): self.add_message("pandas-dataframe-values", node=node, confidence=HIGH) diff --git a/pylint_ml/checkers/pandas/pandas_inplace.py b/pylint_ml/checkers/pandas/pandas_inplace.py index 0d66689..7a13142 100644 --- a/pylint_ml/checkers/pandas/pandas_inplace.py +++ b/pylint_ml/checkers/pandas/pandas_inplace.py @@ -12,6 +12,7 @@ from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_attribute class PandasInplaceChecker(LibraryBaseChecker): @@ -45,7 +46,9 @@ def visit_call(self, node: nodes.Call) -> None: return # Check if the call is to a method that supports 'inplace' - if isinstance(node.func, nodes.Attribute): + if isinstance(node.func, nodes.Attribute) and infer_specific_module_from_attribute( + node=node.func, module_name=PANDAS + ): method_name = node.func.attrname if method_name in self._inplace_methods: for keyword in node.keywords: diff --git a/pylint_ml/checkers/pandas/pandas_parameter.py b/pylint_ml/checkers/pandas/pandas_parameter.py index dabf5cb..36067d9 100644 --- a/pylint_ml/checkers/pandas/pandas_parameter.py +++ b/pylint_ml/checkers/pandas/pandas_parameter.py @@ -5,12 +5,12 @@ """Check for proper 
usage of Pandas functions with required parameters.""" from astroid import nodes -from pylint.checkers.utils import only_required_for_messages, safe_infer +from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.utils import infer_specific_module_from_call class PandasParameterChecker(LibraryBaseChecker): @@ -70,25 +70,14 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - method_name = get_full_method_name(node=node) - extracted_method = method_name[len("pd.") :] - - infer_node = safe_infer(node=node) - infer_object = safe_infer(node.func.expr) - print(node.func.expr) - print(infer_object) - print("------") - print(infer_node) - - if method_name.startswith("pd.") and extracted_method in self.REQUIRED_PARAMS: + method_name = getattr(node.func, "attrname", None) + if infer_specific_module_from_call(node=node, module_name=PANDAS) and method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - missing_params = [ - param for param in self.REQUIRED_PARAMS[extracted_method] if param not in provided_keywords - ] + missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: self.add_message( "pandas-parameter", node=node, confidence=HIGH, - args=(", ".join(missing_params), extracted_method), + args=(", ".join(missing_params), method_name), ) diff --git a/pylint_ml/checkers/pandas/pandas_series_bool.py b/pylint_ml/checkers/pandas/pandas_series_bool.py index aae0512..eaac96d 100644 --- a/pylint_ml/checkers/pandas/pandas_series_bool.py +++ b/pylint_ml/checkers/pandas/pandas_series_bool.py @@ -13,6 +13,7 @@ # Todo add version 
deprecated from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call class PandasSeriesBoolChecker(LibraryBaseChecker): @@ -30,7 +31,7 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - if isinstance(node.func, nodes.Attribute): + if isinstance(node.func, nodes.Attribute) and infer_specific_module_from_call(node=node, module_name=PANDAS): method_name = getattr(node.func, "attrname", None) if method_name == "bool": diff --git a/pylint_ml/checkers/pandas/pandas_series_naming.py b/pylint_ml/checkers/pandas/pandas_series_naming.py index 4e30aa5..010b83e 100644 --- a/pylint_ml/checkers/pandas/pandas_series_naming.py +++ b/pylint_ml/checkers/pandas/pandas_series_naming.py @@ -12,6 +12,7 @@ from pylint_ml.checkers.config import PANDAS from pylint_ml.checkers.library_base_checker import LibraryBaseChecker +from pylint_ml.checkers.utils import infer_specific_module_from_call class PandasSeriesNamingChecker(LibraryBaseChecker): @@ -29,7 +30,7 @@ def visit_assign(self, node: nodes.Assign) -> None: if not self.is_library_imported_and_version_valid(lib_name=PANDAS, required_version=None): return - if isinstance(node.value, nodes.Call): + if isinstance(node.value, nodes.Call) and infer_specific_module_from_call(node=node.value, module_name=PANDAS): func_name = getattr(node.value.func, "attrname", None) module_name = getattr(node.value.func.expr, "name", None) diff --git a/pylint_ml/checkers/scipy/scipy_parameter.py b/pylint_ml/checkers/scipy/scipy_parameter.py index 2d7010a..e754d25 100644 --- a/pylint_ml/checkers/scipy/scipy_parameter.py +++ b/pylint_ml/checkers/scipy/scipy_parameter.py @@ -5,12 +5,11 @@ """Check for proper usage of Scipy functions with required parameters.""" from astroid import nodes -from pylint.checkers.utils import 
only_required_for_messages, safe_infer +from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH from pylint_ml.checkers.config import SCIPY from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -from pylint_ml.checkers.utils import get_full_method_name class ScipyParameterChecker(LibraryBaseChecker): @@ -36,11 +35,10 @@ class ScipyParameterChecker(LibraryBaseChecker): # scipy.stats "ttest_ind": ["a", "b"], "ttest_rel": ["a", "b"], - "norm.pdf": ["x"], + "pdf": ["x"], # scipy.spatial - "distance.euclidean": ["u", "v"], # Full chain - "euclidean": ["u", "v"], # Direct import of euclidean - "KDTree.query": ["x"], + "euclidean": ["u", "v"], + "query": ["x"], } @only_required_for_messages("scipy-parameter") @@ -48,22 +46,37 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=SCIPY, required_version=None): return - method_name = get_full_method_name(node=node) + # Determine whether the function is a simple Name (method call) + if isinstance(node.func, nodes.Name): + method_name = node.func.name # For cases like minimize() + else: + return # Exit early - infer_node = safe_infer(node=node) - print("------") - print(infer_node) - infer_object = safe_infer(node.func.expr) - print(node.func.expr) - print(infer_object) + # Perform a lookup in the current scope for the function/method name + scope = node.scope() + name_lookup = scope.lookup(method_name) - if method_name in self.REQUIRED_PARAMS: - provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} - missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] - if missing_params: - self.add_message( - "scipy-parameter", - node=node, - confidence=HIGH, - args=(", ".join(missing_params), method_name), - ) + if name_lookup: + _, assignments = name_lookup + if assignments: + assignment = assignments[0] + if isinstance(assignment, nodes.ImportFrom): + # 
Check if the import is from scipy.optimize + # Correctly unpack the names from the tuple + imported_names = [name for name, _ in assignment.names] + + if SCIPY in assignment.modname and method_name in imported_names: + # Proceed with checking parameters + if method_name in self.REQUIRED_PARAMS: + provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} + missing_params = [ + param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords + ] + + if missing_params: + self.add_message( + "scipy-parameter", + node=node, + confidence=HIGH, + args=(", ".join(missing_params), method_name), + ) diff --git a/pylint_ml/checkers/tensorflow/tensor_parameter.py b/pylint_ml/checkers/tensorflow/tensor_parameter.py index e143354..cefdeaa 100644 --- a/pylint_ml/checkers/tensorflow/tensor_parameter.py +++ b/pylint_ml/checkers/tensorflow/tensor_parameter.py @@ -10,7 +10,7 @@ from pylint_ml.checkers.config import TENSORFLOW from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.utils import infer_specific_module_from_call class TensorFlowParameterChecker(LibraryBaseChecker): @@ -41,8 +41,12 @@ def visit_call(self, node: nodes.Call) -> None: if not self.is_library_imported_and_version_valid(lib_name=TENSORFLOW, required_version=None): return - method_name = get_full_method_name(node) - if method_name in self.REQUIRED_PARAMS: + # TODO UPDATE SOLUTION + + # method_name = get_full_method_name(node) + # if method_name in self.REQUIRED_PARAMS: + method_name = getattr(node.func, "attrname", None) + if infer_specific_module_from_call(node=node, module_name=TENSORFLOW) and method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: diff --git 
a/pylint_ml/checkers/torch/torch_parameter.py b/pylint_ml/checkers/torch/torch_parameter.py index ac37021..4c631f7 100644 --- a/pylint_ml/checkers/torch/torch_parameter.py +++ b/pylint_ml/checkers/torch/torch_parameter.py @@ -8,9 +8,9 @@ from pylint.checkers.utils import only_required_for_messages from pylint.interfaces import HIGH -from pylint_ml.checkers.config import PYTORCH +from pylint_ml.checkers.config import TORCH from pylint_ml.checkers.library_base_checker import LibraryBaseChecker -from pylint_ml.checkers.utils import get_full_method_name +from pylint_ml.checkers.utils import infer_specific_module_from_call class PyTorchParameterChecker(LibraryBaseChecker): @@ -37,11 +37,13 @@ class PyTorchParameterChecker(LibraryBaseChecker): @only_required_for_messages("pytorch-parameter") def visit_call(self, node: nodes.Call) -> None: - if not self.is_library_imported_and_version_valid(lib_name=PYTORCH, required_version=None): + if not self.is_library_imported_and_version_valid(lib_name=TORCH, required_version=None): return - method_name = get_full_method_name(node) - if method_name in self.REQUIRED_PARAMS: + # TODO UPDATE SOLUTION + + method_name = getattr(node.func, "attrname", None) + if infer_specific_module_from_call(node=node, module_name=TORCH) and method_name in self.REQUIRED_PARAMS: provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None} missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords] if missing_params: diff --git a/pylint_ml/checkers/utils.py b/pylint_ml/checkers/utils.py index 041485f..a5c489f 100644 --- a/pylint_ml/checkers/utils.py +++ b/pylint_ml/checkers/utils.py @@ -52,19 +52,22 @@ def infer_module_from_node_chain(start_node: nodes.NodeNG, module_name: str) -> # Traverse backward through the chain, handling Attribute and Name node types while isinstance(current_node, (nodes.Attribute, nodes.Name)): + print(current_node) + if isinstance(current_node, nodes.Attribute): # Infer 
the current expression (e.g., np.some) inferred_object = safe_infer(current_node.expr) if inferred_object is None: - return False - current_node = current_node.expr # Step backwards + current_node = current_node.expr + else: + current_node = current_node.expr # Step backwards elif isinstance(current_node, nodes.Name): # Base case: a Name node is likely a module or variable (e.g., 'np') inferred_root = safe_infer(current_node) - print(inferred_root) if inferred_root: # Check if the inferred object's name matches the module_name # TODO update solution to handle MODULE and INSTANCE + if module_name in inferred_root.qname() or inferred_root.qname() == module_name: return True else: diff --git a/tests/checkers/test_pandas/test_pandas_parameter.py b/tests/checkers/test_pandas/test_pandas_parameter.py index 7399b1e..ed4769f 100644 --- a/tests/checkers/test_pandas/test_pandas_parameter.py +++ b/tests/checkers/test_pandas/test_pandas_parameter.py @@ -190,6 +190,7 @@ def test_merge_with_missing_validate(self, mock_version): import_node, node = astroid.extract_node( """ import pandas as pd #@ + df_1 = pd.DataFrame({'A': [1, 2]}) df_3 = df_1.merge(right=df_2, how='inner', on='col1') #@ """ ) @@ -214,6 +215,7 @@ def test_merge_with_wrong_naming_and_missing_params(self, mock_version): import_node, node = astroid.extract_node( """ import pandas as pd #@ + df_1 = pd.DataFrame({'A': [1, 2]}) merged_df = df_1.merge(right=df_2) #@ """ ) @@ -235,6 +237,7 @@ def test_merge_with_all_params_and_correct_naming(self, mock_version): import_node, node = astroid.extract_node( """ import pandas as pd #@ + df_1 = pd.DataFrame({'A': [1, 2]}) df_merged = df_1.merge(right=df_2, how='inner', on='col1', validate='1:1') #@ """ ) diff --git a/tests/checkers/test_sklearn/test_sklearn_parameter.py b/tests/checkers/test_sklearn/test_sklearn_parameter.py index 9a67075..77e744d 100644 --- a/tests/checkers/test_sklearn/test_sklearn_parameter.py +++ b/tests/checkers/test_sklearn/test_sklearn_parameter.py @@ 
-13,10 +13,10 @@ class TestSklearnParameterChecker(pylint.testutils.CheckerTestCase): @patch("pylint_ml.checkers.library_base_checker.version") def test_random_forest_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from sklearn.ensemble import RandomForestClassifier - clf = RandomForestClassifier() #@ + from sklearn.ensemble import RandomForestClassifier #@ + clf = RandomForestClassifier() #@ """ ) @@ -31,30 +31,32 @@ def test_random_forest_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(forest_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_random_forest_with_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from sklearn.ensemble import RandomForestClassifier - clf = RandomForestClassifier(n_estimators=100) #@ + from sklearn.ensemble import RandomForestClassifier #@ + clf = RandomForestClassifier(n_estimators=100) #@ """ ) forest_call = node.value with self.assertNoMessages(): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(forest_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_svc_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from sklearn.svm import SVC - clf = SVC() #@ + from sklearn.svm import SVC #@ + clf = SVC() #@ """ ) @@ -69,30 +71,32 @@ def test_svc_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(svc_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_svc_with_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ 
- from sklearn.svm import SVC - clf = SVC(C=1.0, kernel='linear') #@ + from sklearn.svm import SVC #@ + clf = SVC(C=1.0, kernel='linear') #@ """ ) svc_call = node.value with self.assertNoMessages(): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(svc_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_kmeans_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from sklearn.cluster import KMeans - kmeans = KMeans() #@ + from sklearn.cluster import KMeans #@ + kmeans = KMeans() #@ """ ) @@ -107,19 +111,21 @@ def test_kmeans_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(kmeans_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_kmeans_with_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + importfrom_node, node = astroid.extract_node( """ - from sklearn.cluster import KMeans - kmeans = KMeans(n_clusters=8) #@ + from sklearn.cluster import KMeans #@ + kmeans = KMeans(n_clusters=8) #@ """ ) kmeans_call = node.value with self.assertNoMessages(): + self.checker.visit_importfrom(importfrom_node) self.checker.visit_call(kmeans_call) diff --git a/tests/checkers/test_tensorflow/test_tensor_parameter.py b/tests/checkers/test_tensorflow/test_tensor_parameter.py index a27aa08..b00db15 100644 --- a/tests/checkers/test_tensorflow/test_tensor_parameter.py +++ b/tests/checkers/test_tensorflow/test_tensor_parameter.py @@ -13,10 +13,10 @@ class TestTensorParameterChecker(pylint.testutils.CheckerTestCase): @patch("pylint_ml.checkers.library_base_checker.version") def test_sequential_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf - model = tf.keras.models.Sequential() #@ + import 
tensorflow as tf #@ + model = tf.keras.models.Sequential() #@ """ ) @@ -31,34 +31,33 @@ def test_sequential_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(sequential_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_sequential_with_layers(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf - model = tf.keras.Sequential(layers=[ - tf.keras.layers.Dense(units=64, activation='relu'), - tf.keras.layers.Dense(units=10) - ]) + import tensorflow as tf #@ + model = tf.keras.Sequential(layers=[tf.keras.layers.Dense(units=64, activation='relu'),tf.keras.layers.Dense(units=10)]) #@ """ ) sequential_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(sequential_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_compile_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf + import tensorflow as tf #@ model = tf.keras.models.Sequential() - model.compile() #@ + model.compile() #@ """ ) @@ -71,33 +70,35 @@ def test_compile_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(node) @patch("pylint_ml.checkers.library_base_checker.version") def test_compile_with_all_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf + import tensorflow as tf #@ model = tf.keras.models.Sequential() - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) #@ + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) #@ """ ) compile_call = node 
with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(compile_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_fit_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf + import tensorflow as tf #@ model = tf.keras.models.Sequential() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') - model.fit(epochs=10) #@ + model.fit(epochs=10) #@ """ ) @@ -112,32 +113,34 @@ def test_fit_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(fit_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_fit_with_all_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf + import tensorflow as tf #@ model = tf.keras.models.Sequential() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') - model.fit(x=train_data, y=train_labels, epochs=10) #@ + model.fit(x=train_data, y=train_labels, epochs=10) #@ """ ) fit_call = node with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(fit_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_conv2d_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf - layer = tf.keras.layers.Conv2D(kernel_size=(3, 3)) #@ + import tensorflow as tf #@ + layer = tf.keras.layers.Conv2D(kernel_size=(3, 3)) #@ """ ) @@ -152,30 +155,32 @@ def test_conv2d_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(conv2d_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_conv2d_with_all_params(self, 
mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf - layer = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3)) #@ + import tensorflow as tf #@ + layer = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3)) #@ """ ) conv2d_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(conv2d_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_dense_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf - layer = tf.keras.layers.Dense() #@ + import tensorflow as tf #@ + layer = tf.keras.layers.Dense() #@ """ ) @@ -190,19 +195,21 @@ def test_dense_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(dense_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_dense_with_all_params(self, mock_version): mock_version.return_value = "1.5.2" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import tensorflow as tf - layer = tf.keras.layers.Dense(units=64) #@ + import tensorflow as tf #@ + layer = tf.keras.layers.Dense(units=64) #@ """ ) dense_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(dense_call) diff --git a/tests/checkers/test_torch/test_torch_parameter.py b/tests/checkers/test_torch/test_torch_parameter.py index 4337d7f..73f613b 100644 --- a/tests/checkers/test_torch/test_torch_parameter.py +++ b/tests/checkers/test_torch/test_torch_parameter.py @@ -13,9 +13,9 @@ class TestTorchParameterChecker(pylint.testutils.CheckerTestCase): @patch("pylint_ml.checkers.library_base_checker.version") def test_sgd_params(self, mock_version): mock_version.return_value = "2.4.1" - node = astroid.extract_node( + 
import_node, node = astroid.extract_node( """ - import torch.optim as optim + import torch.optim as optim #@ optimizer = optim.SGD(model.parameters(), momentum=0.9) #@ """ ) @@ -31,14 +31,15 @@ def test_sgd_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(sgd_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_sgd_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.optim as optim + import torch.optim as optim #@ optimizer = optim.SGD(lr=0.01) #@ """ ) @@ -46,14 +47,15 @@ def test_sgd_with_all_params(self, mock_version): sgd_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(sgd_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_adam_params(self, mock_version): mock_version.return_value = "2.4.1" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.optim as optim + import torch.optim as optim #@ optimizer = optim.Adam(model.parameters()) #@ """ ) @@ -69,14 +71,15 @@ def test_adam_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(adam_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_adam_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.optim as optim + import torch.optim as optim #@ optimizer = optim.Adam(lr=0.001) #@ """ ) @@ -84,14 +87,15 @@ def test_adam_with_all_params(self, mock_version): adam_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(adam_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_conv2d_params(self, mock_version): 
mock_version.return_value = "2.4.1" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.Conv2d(in_channels=3, kernel_size=3) #@ """ ) @@ -107,14 +111,15 @@ def test_conv2d_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(conv2d_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_conv2d_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3) #@ """ ) @@ -122,14 +127,15 @@ def test_conv2d_with_all_params(self, mock_version): conv2d_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(conv2d_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_linear_params(self, mock_version): mock_version.return_value = "2.4.1" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.Linear(in_features=128) #@ """ ) @@ -145,14 +151,15 @@ def test_linear_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(linear_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_linear_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.Linear(in_features=128, out_features=64) #@ """ ) @@ -165,9 +172,9 @@ def test_linear_with_all_params(self, mock_version): @patch("pylint_ml.checkers.library_base_checker.version") def test_lstm_params(self, mock_version): mock_version.return_value = "2.4.1" - node = 
astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.LSTM(input_size=128) #@ """ ) @@ -183,14 +190,15 @@ def test_lstm_params(self, mock_version): ), ignore_position=True, ): + self.checker.visit_import(import_node) self.checker.visit_call(lstm_call) @patch("pylint_ml.checkers.library_base_checker.version") def test_lstm_with_all_params(self, mock_version): mock_version.return_value = "2.4.1" - node = astroid.extract_node( + import_node, node = astroid.extract_node( """ - import torch.nn as nn + import torch.nn as nn #@ layer = nn.LSTM(input_size=128, hidden_size=64) #@ """ ) @@ -198,4 +206,5 @@ def test_lstm_with_all_params(self, mock_version): lstm_call = node.value with self.assertNoMessages(): + self.checker.visit_import(import_node) self.checker.visit_call(lstm_call)