数据检查#
EvalML 提供了数据检查功能,帮助您构建性能最佳的模型。这些实用函数有助于处理过拟合、异常数据和缺失数据等问题。这些数据检查位于 evalml/data_checks
下。下面我们将介绍 EvalML 中每种可用数据检查以及 DefaultDataChecks
数据检查集合的示例。
缺失数据#
缺失数据或包含 NaN
值的行对机器学习管道提出了许多挑战。在最坏的情况下,许多算法根本无法处理缺失数据!EvalML 管道包含填充(imputation)组件,以确保不会发生这种情况。填充通过现有值来近似缺失值。然而,如果一列包含大量缺失值,那么该列的很大一部分将由一小部分值近似。这可能会导致该列对于机器学习管道而言不包含有用信息。使用 NullDataCheck
,EvalML 将通过返回超过缺失值阈值的列来提醒您这个潜在问题。
[1]:
import numpy as np
import pandas as pd
from evalml.data_checks import NullDataCheck
X = pd.DataFrame(
[[1, 2, 3], [0, 4, np.nan], [1, 4, np.nan], [9, 4, np.nan], [8, 6, np.nan]]
)
null_check = NullDataCheck(pct_null_col_threshold=0.8, pct_null_row_threshold=0.8)
messages = null_check.validate(X)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Column(s) '2' are 80.0% or more null
异常数据#
EvalML 提供了几种数据检查来检测异常数据
NoVarianceDataCheck
ClassImbalanceDataCheck
TargetLeakageDataCheck
InvalidTargetDataCheck
IDColumnsDataCheck
OutliersDataCheck
HighVarianceCVDataCheck
MulticollinearityDataCheck
UniquenessDataCheck
TargetDistributionDataCheck
DateTimeFormatDataCheck
TimeSeriesParametersDataCheck
TimeSeriesSplittingDataCheck
零方差#
方差为零的数据表明所有值都相同。如果特征的方差为零,它很可能不是一个有用的特征。类似地,如果目标的方差为零,很可能存在问题。NoVarianceDataCheck
检查目标或任何特征是否只有一个唯一值,并提醒您注意任何此类列。
[2]:
from evalml.data_checks import NoVarianceDataCheck
X = pd.DataFrame({"no var col": [0, 0, 0], "good col": [0, 4, 1]})
y = pd.Series([1, 0, 1])
no_variance_data_check = NoVarianceDataCheck()
messages = no_variance_data_check.validate(X, y)
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
Warning: 'no var col' has 1 unique value.
请注意,您可以将 NaN
设置为算作一个唯一值,但如果给定列中只有一个唯一的非 NaN
值,NoVarianceDataCheck
仍会返回警告。
[3]:
from evalml.data_checks import NoVarianceDataCheck
X = pd.DataFrame(
{
"no var col": [0, 0, 0],
"no var col with nan": [1, np.nan, 1],
"good col": [0, 4, 1],
}
)
y = pd.Series([1, 0, 1])
no_variance_data_check = NoVarianceDataCheck(count_nan_as_value=True)
messages = no_variance_data_check.validate(X, y)
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
Warning: 'no var col' has 1 unique value.
Warning: 'no var col with nan' has two unique values including nulls. Consider encoding the nulls for this column to be useful for machine learning.
类别不平衡#
对于分类问题,每个类别中的样本分布可能有所不同。对于小的变化,这是正常且预期的。然而,当每个类别标签的样本数量相对于某个特定类别(或多个类别)极度偏颇或倾斜时,机器学习模型可能难以很好地预测。此外,给定类别的样本数量较低可能意味着用于训练数据生成的一个或多个交叉验证(CV)折叠(fold)可能只有很少或没有该类别的样本。这可能导致模型仅预测多数类别,并最终导致性能较差的模型。
ClassImbalanceDataCheck
检查目标标签在一定数量的 CV 折叠中是否不平衡超过指定阈值。对于样本数量少于指定 CV 折叠数两倍的任何类别,它会返回 DataCheckError
消息(因为这表明给定折叠中该类别的样本很少或没有的可能性很大),对于低于设定阈值百分比的任何类别,则返回 DataCheckWarning
消息。
[4]:
from evalml.data_checks import ClassImbalanceDataCheck
X = pd.DataFrame([[1, 2, 0, 1], [4, 1, 9, 0], [4, 4, 8, 3], [9, 2, 7, 1]])
y = pd.Series([0, 1, 1, 1, 1])
class_imbalance_check = ClassImbalanceDataCheck(threshold=0.25, num_cv_folds=4)
messages = class_imbalance_check.validate(X, y)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: The following labels fall below 25% of the target: [0]
Warning: The following labels in the target have severe class imbalance because they fall under 25% of the target and have less than 100 samples: [0]
Error: The number of instances of these targets is less than 2 * the number of cross folds = 8 instances: [0, 1]
目标泄漏#
目标泄漏(也称为数据泄漏)可能发生在您在训练模型的数据集中包含了在预测时本不应该可用的信息。这会导致模型在评分时表现得异常好,但在生产环境中性能不佳。TargetLeakageDataCheck
通过计算每个特征与目标之间的皮尔逊相关系数来检查可能“泄漏”信息的特征,并在特征与目标高度相关时警告用户。目前,仅考虑数值特征。
[5]:
from evalml.data_checks import TargetLeakageDataCheck
X = pd.DataFrame(
{
"leak": [10, 42, 31, 51, 61] * 5,
"x": [42, 54, 12, 64, 12] * 5,
"y": [12, 5, 13, 74, 24] * 5,
}
)
y = pd.Series([10, 42, 31, 51, 40] * 5)
target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.8)
messages = target_leakage_check.validate(X, y)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Columns 'leak', 'x', 'y' are 80.0% or more correlated with the target
无效目标数据#
InvalidTargetDataCheck
检查目标数据是否包含任何缺失或无效值。具体来说:
如果任何目标值缺失,将返回
DataCheckError
消息。如果指定的任务类型是二元分类问题,但目标中的唯一值多于或少于两个,将返回
DataCheckError
消息。如果二元分类目标类别是数值,且不等于 {0, 1},将返回
DataCheckError
消息,因为将其传递给管道可能会导致不可预测的行为。
[6]:
from evalml.data_checks import InvalidTargetDataCheck
X = pd.DataFrame({})
y = pd.Series([0, 1, None, None])
invalid_target_check = InvalidTargetDataCheck("binary", "Log Loss Binary")
messages = invalid_target_check.validate(X, y)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Input target and features have different lengths
Warning: Input target and features have mismatched indices. Details will include the first 10 mismatched indices.
Error: 2 row(s) (50.0%) of target values are null
ID列#
数据集中的 ID 列对机器学习管道几乎没有益处,因为管道无法从唯一标识符中提取有用信息。因此,如果存在这些列,IDColumnsDataCheck
会提醒您。在给定的示例中,'user_number' 和 'revenue_id' 列都被识别为可能是应删除的唯一标识符。
[7]:
from evalml.data_checks import IDColumnsDataCheck
X = pd.DataFrame(
[[0, 53, 6325, 5], [1, 90, 6325, 10], [2, 90, 18, 20]],
columns=["user_number", "cost", "revenue", "revenue_id"],
)
id_col_check = IDColumnsDataCheck(id_threshold=0.9)
messages = id_col_check.validate(X)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Columns 'user_number', 'revenue_id' are 90.0% or more likely to be an ID column
然而,主键列可能有用。主键列通常是数据集中的第一列,具有所有唯一值,并且命名为 ID
或以 _id
结尾的名称。尽管它们在建模过程中被忽略,但可以用作在建模过程之前或之后进行查询的标识符。如果发现 DataFrame 的第一列是主键,IDColumnsDataCheck
也会提醒您。在给定的示例中,user_id
被识别为主键,而 revenue_id
被识别为常规的唯一标识符。
[8]:
from evalml.data_checks import IDColumnsDataCheck
X = pd.DataFrame(
[[0, 53, 6325, 5], [1, 90, 6325, 10], [2, 90, 18, 20]],
columns=["user_id", "cost", "revenue", "revenue_id"],
)
id_col_check = IDColumnsDataCheck(id_threshold=0.9)
messages = id_col_check.validate(X)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: The first column 'user_id' is likely to be the primary key
Warning: Columns 'revenue_id' are 90.0% or more likely to be an ID column
多重共线性#
MulticollinearityDataCheck
数据检查用于检测是否存在任何可能存在多重共线性的特征集合。多重共线性特征会影响模型的性能,但更重要的是,它可能会极大地影响模型的解释。EvalML 使用互信息来确定共线性。
[9]:
from evalml.data_checks import MulticollinearityDataCheck
y = pd.Series([1, 0, 2, 3, 4] * 5)
X = pd.DataFrame(
{
"col_1": y,
"col_2": y * 3,
"col_3": ~y,
"col_4": y / 2,
"col_5": y + 1,
"not_collinear": [0, 1, 0, 0, 0] * 5,
}
)
multi_check = MulticollinearityDataCheck(threshold=0.95)
messages = multi_check.validate(X)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Columns are likely to be correlated: [('col_1', 'col_2'), ('col_1', 'col_3'), ('col_1', 'col_4'), ('col_1', 'col_5'), ('col_2', 'col_3'), ('col_2', 'col_4'), ('col_2', 'col_5'), ('col_3', 'col_4'), ('col_3', 'col_5'), ('col_4', 'col_5')]
唯一性#
UniquenessDataCheck
用于检测值过于唯一或不够唯一的列。对于回归类问题,会检查唯一性的下限。对于多类别问题,会检查上限。
[10]:
import pandas as pd
from evalml.data_checks import UniquenessDataCheck
X = pd.DataFrame(
{
"most_unique": [float(x) for x in range(10)], # [0,1,2,3,4,5,6,7,8,9]
"more_unique": [x % 5 for x in range(10)], # [0,1,2,3,4,0,1,2,3,4]
"unique": [x % 3 for x in range(10)], # [0,1,2,0,1,2,0,1,2,0]
"less_unique": [x % 2 for x in range(10)], # [0,1,0,1,0,1,0,1,0,1]
"not_unique": [float(1) for x in range(10)],
}
) # [1,1,1,1,1,1,1,1,1,1]
uniqueness_check = UniquenessDataCheck(problem_type="regression", threshold=0.5)
messages = uniqueness_check.validate(X)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Input columns 'not_unique' for regression problem type are not unique enough.
稀疏性#
SparsityDataCheck
用于识别包含值稀疏性的特征。
[11]:
from evalml.data_checks import SparsityDataCheck
X = pd.DataFrame(
{
"most_sparse": [float(x) for x in range(10)], # [0,1,2,3,4,5,6,7,8,9]
"more_sparse": [x % 5 for x in range(10)], # [0,1,2,3,4,0,1,2,3,4]
"sparse": [x % 3 for x in range(10)], # [0,1,2,0,1,2,0,1,2,0]
"less_sparse": [x % 2 for x in range(10)], # [0,1,0,1,0,1,0,1,0,1]
"not_sparse": [float(1) for x in range(10)],
}
) # [1,1,1,1,1,1,1,1,1,1]
sparsity_check = SparsityDataCheck(
problem_type="multiclass", threshold=0.4, unique_count_threshold=3
)
messages = sparsity_check.validate(X)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Input columns ('most_sparse', 'more_sparse', 'sparse') for multiclass problem type are too sparse.
离群值#
离群值是与同一样本中其他观测值显著不同的观测值。如果不从训练集中移除离群值,许多机器学习管道的性能会受到影响,因为它们不能代表数据。OutliersDataCheck()
使用 IQR 来通知您样本是否可以被视为离群值。
下面我们生成一个包含一些离群值的随机数据集。
[12]:
data = np.tile(np.arange(10) * 0.01, (100, 10))
X = pd.DataFrame(data=data)
# generate some outliers in columns 3, 25, 55, and 72
X.iloc[0, 3] = -10000
X.iloc[3, 25] = 10000
X.iloc[5, 55] = 10000
X.iloc[10, 72] = -10000
然后我们利用 OutliersDataCheck()
来重新发现这些离群值。
[13]:
from evalml.data_checks import OutliersDataCheck
outliers_check = OutliersDataCheck()
messages = outliers_check.validate(X)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Column(s) '3', '25', '55', '72' are likely to have outlier data.
目标分布#
目标数据可以呈现各种分布,例如高斯分布或对数正态分布。当我们使用机器学习模型时,我们将数据输入到一个估计器中,该估计器从提供的训练数据中学习。有时数据可能分布非常分散,带有长尾或离群值,这可能导致对数正态分布。这会影响机器学习模型的性能。
为了帮助估计器更好地理解数据中特征与目标之间的潜在关系,我们可以使用 TargetDistributionDataCheck
来识别此类分布。
[14]:
from scipy.stats import lognorm
from evalml.data_checks import TargetDistributionDataCheck
data = np.tile(np.arange(10) * 0.01, (100, 10))
X = pd.DataFrame(data=data)
y = pd.Series(lognorm.rvs(s=0.4, loc=1, scale=1, size=100))
target_dist_check = TargetDistributionDataCheck()
messages = target_dist_check.validate(X, y)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Warning: Target may have a lognormal distribution.
日期时间格式#
日期时间信息是时间序列问题的必要组成部分,但有时我们处理的数据可能包含使时间序列模型无法正常工作的缺陷。例如,为了识别日期时间信息中的频率,数据点之间必须有相等的间隔,即 2021年1月1日、2021年1月3日、2021年1月5日等,它们之间间隔两天。如果日期时间数据中存在随机跳跃,即 2021年1月1日、2021年1月3日、2021年1月12日,则无法推断出频率。时间序列模型的另一个常见问题是它们无法处理未正确排序的日期时间信息。日期时间值如果不是单调递增(按升序排序),就会遇到此问题,并且无法推断其频率。
为了方便验证您使用的日期时间列是否间隔和排序正确,我们可以利用 DatetimeFormatDataCheck
。初始化数据检查时,传入包含日期时间信息的列名(如果它在您的 X 或 y 索引中,则传入“index”)。
[15]:
from evalml.data_checks import DateTimeFormatDataCheck
X = pd.DataFrame(
pd.date_range("January 1, 2021", periods=8, freq="2D"), columns=["dates"]
)
y = pd.Series([1, 2, 4, 2, 1, 2, 3, 1])
# Replaces the last entry with January 16th instead of January 15th
# so that the data is no longer evenly spaced.
X.iloc[7] = "January 16, 2021"
datetime_format_check = DateTimeFormatDataCheck(datetime_column="dates")
messages = datetime_format_check.validate(X, y)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
print("--------------------------------")
# Reverses the order of the index datetime values to be decreasing.
X = X[::-1]
messages = datetime_format_check.validate(X, y)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Error: Column 'dates' has datetime values that do not align with the inferred frequency.
Error: A frequency was detected in column 'dates', but there are faulty datetime values that need to be addressed.
--------------------------------
Error: Datetime values must be sorted in ascending order.
Error: No frequency could be detected in column 'dates', possibly due to uneven intervals or too many duplicate/missing values.
时间序列参数#
为了支持 AutoML 中的时间序列任务类型,必须满足某些条件。- 参数 gap
、max_delay
、forecast_horizon
和 time_index
必须传递到 problem_configuration
中。- gap
、max_delay
、forecast_horizon
的值必须适合数据的大小。
对于上述第 2 点,这意味着窗口大小(由 gap
+ max_delay
+ forecast_horizon
定义)必须小于数据中的观测值数量除以拆分数加 1。例如,有 100 个观测值和 3 个拆分,则拆分大小为 25。这意味着窗口大小必须小于 25。
[16]:
from evalml.data_checks import TimeSeriesParametersDataCheck
X = pd.DataFrame(pd.date_range("1/1/21", periods=100), columns=["dates"])
y = pd.Series([i % 2 for i in range(100)])
problem_config = {
"gap": 1,
"max_delay": 23,
"forecast_horizon": 1,
"time_index": "dates",
}
# With 3 splits, the split size will be 25 (100/3+1)
# Since gap + max_delay + forecast_horizon is 25, this will
# throw an error for window size.
ts_params_data_check = TimeSeriesParametersDataCheck(
problem_configuration=problem_config, n_splits=3
)
messages = ts_params_data_check.validate(X, y)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
时间序列拆分#
由于时间序列数据的特性,拆分不能涉及混洗,必须按顺序进行。这意味着将数据拆分为 n_splits
+ 1 个不同的部分,并在每次迭代时将训练数据的大小增加拆分大小,同时保持测试大小等于拆分大小。
对于数据中的每个拆分,训练和验证部分必须包含目标数据,这些目标数据包含整个目标集中发现的每个类别的样本,适用于时间序列二元分类和时间序列多类别分类问题。原因在于,如果许多分类机器学习模型在训练时使用的数据不包含某个类别的实例,但模型被期望能够预测该类别,它们就会遇到问题。例如,对于 3 个拆分和拆分大小为 25,这意味着每个训练/验证拆分:(0:25)/(25:50)、(0:50)/(50:75)、(0:75)/(75:100)必须在训练集和验证集中至少包含所有唯一目标类别的至少一个实例。- 在时间序列二元分类问题中,至少包含两个类别的实例。- 在时间序列多类别分类问题中,至少包含所有类别的实例。
[17]:
from evalml.data_checks import TimeSeriesSplittingDataCheck
X = None
y = pd.Series([0 if i < 50 else i % 2 for i in range(100)])
ts_splitting_check = TimeSeriesSplittingDataCheck("time series binary", 3)
messages = ts_splitting_check.validate(X, y)
errors = [message for message in messages if message["level"] == "error"]
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
for error in errors:
print("Error:", error["message"])
Error: Time Series Binary and Time Series Multiclass problem types require every training and validation split to have at least one instance of all the target classes. The following splits are invalid: [1, 2]
数据检查消息#
每个数据检查的 validate
方法返回一个 DataCheckMessage
对象列表,指示发现的警告或错误;警告存储为 DataCheckWarning
对象,错误存储为 DataCheckError
对象。您可以通过检查返回的消息类型来过滤数据检查返回的消息。下面,NoVarianceDataCheck
返回一个包含 DataCheckWarning
和 DataCheckError
消息的列表。我们可以通过检查每条消息的类型来确定哪个是哪个。
[18]:
from evalml.data_checks import NoVarianceDataCheck, DataCheckWarning
X = pd.DataFrame(
{
"no var col": [0, 0, 0],
"no var col with nan": [1, np.nan, 1],
"good col": [0, 4, 1],
}
)
y = pd.Series([1, 0, 1])
no_variance_data_check = NoVarianceDataCheck(count_nan_as_value=True)
messages = no_variance_data_check.validate(X, y)
warnings = [message for message in messages if message["level"] == "warning"]
for warning in warnings:
print("Warning:", warning["message"])
Warning: 'no var col' has 1 unique value.
Warning: 'no var col with nan' has two unique values including nulls. Consider encoding the nulls for this column to be useful for machine learning.
编写您自己的数据检查#
如果您希望编写自己的数据检查,可以通过继承 DataCheck
类并实现 validate(self, X, y)
类方法来实现。下面,我们创建了一个新的 DataCheck
,名为 ZeroVarianceDataCheck
,它类似于 EvalML
中定义的 NoVarianceDataCheck
。validate(self, X, y)
方法应返回一个字典,其中以“warnings”和“errors”作为键,分别映射到警告列表和错误列表。
[19]:
from evalml.data_checks import DataCheck
class ZeroVarianceDataCheck(DataCheck):
def validate(self, X, y):
messages = []
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
warning_msg = "Column '{}' has zero variance"
messages.extend(
[
DataCheckError(warning_msg.format(column), self.name)
for column in X.columns
if len(X[column].unique()) == 1
]
)
return messages
定义数据检查集合#
为了方便起见,EvalML 提供了 DataChecks
类来表示数据检查的集合。我们将介绍 DefaultDataChecks
(API 参考),这是一个用于检查一些最常见数据问题的集合。
默认数据检查#
DefaultDataChecks
是一个数据检查集合,用于检查一些最常见的数据问题。它们包括:
NullDataCheck
IDColumnsDataCheck
TargetLeakageDataCheck
InvalidTargetDataCheck
TargetDistributionDataCheck
(适用于回归任务类型)ClassImbalanceDataCheck
(适用于分类任务类型)NoVarianceDataCheck
DateTimeFormatDataCheck
(适用于时间序列任务类型)TimeSeriesParametersDataCheck
(适用于时间序列任务类型)TimeSeriesSplittingDataCheck
(适用于时间序列分类任务类型)
编写您自己的数据检查集合#
如果您希望创建自己的数据检查集合,可以通过继承 DataChecks
类并设置 self.data_checks
属性为 DataCheck
类或对象的列表,或者将该数据检查列表传递给 DataChecks
类的构造函数。下面,我们使用这两种不同的方法创建了两个相同的数据检查集合。
[20]:
# Create a subclass of `DataChecks`
from evalml.data_checks import (
DataChecks,
NullDataCheck,
InvalidTargetDataCheck,
NoVarianceDataCheck,
ClassImbalanceDataCheck,
TargetLeakageDataCheck,
)
from evalml.problem_types import ProblemTypes, handle_problem_types
class MyCustomDataChecks(DataChecks):
data_checks = [
NullDataCheck,
InvalidTargetDataCheck,
NoVarianceDataCheck,
TargetLeakageDataCheck,
]
def __init__(self, problem_type, objective):
"""
A collection of basic data checks.
Args:
problem_type (str): The problem type that is being validated. Can be regression, binary, or multiclass.
"""
if handle_problem_types(problem_type) == ProblemTypes.REGRESSION:
super().__init__(
self.data_checks,
data_check_params={
"InvalidTargetDataCheck": {
"problem_type": problem_type,
"objective": objective,
}
},
)
else:
super().__init__(
self.data_checks + [ClassImbalanceDataCheck],
data_check_params={
"InvalidTargetDataCheck": {
"problem_type": problem_type,
"objective": objective,
}
},
)
custom_data_checks = MyCustomDataChecks(
problem_type=ProblemTypes.REGRESSION, objective="R2"
)
for data_check in custom_data_checks.data_checks:
print(data_check.name)
NullDataCheck
InvalidTargetDataCheck
NoVarianceDataCheck
TargetLeakageDataCheck
[21]:
# Pass list of data checks to the `data_checks` parameter of DataChecks
same_custom_data_checks = DataChecks(
data_checks=[
NullDataCheck,
InvalidTargetDataCheck,
NoVarianceDataCheck,
TargetLeakageDataCheck,
],
data_check_params={
"InvalidTargetDataCheck": {
"problem_type": ProblemTypes.REGRESSION,
"objective": "R2",
}
},
)
for data_check in custom_data_checks.data_checks:
print(data_check.name)
NullDataCheck
InvalidTargetDataCheck
NoVarianceDataCheck
TargetLeakageDataCheck