Source code for climb.tool.impl.smart_testing_helpers.SMART

import ast
import hashlib
import itertools
import re
import warnings
from typing import Any, Dict, Optional

import pandas as pd
from openai import AzureOpenAI
from pydantic import BaseModel, PrivateAttr
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Import calculate_group_statistics function
from .utils import calculate_group_statistics, calculate_group_statistics_string


[docs] def generate_combinations_for_variable(var_values): # Single value combinations single_value_combos = [(val,) for val in var_values] # Pair value combinations pair_value_combos = list(itertools.combinations(var_values, 2)) return single_value_combos + pair_value_combos
[docs] def clean_query_string(query): # Replace or remove unwanted characters query = query.replace("\\", "") # Remove backslashes query = query.replace("'", "") # Remove single quotes if necessary return query
[docs] def convert_to_string_condition(query): # Regex pattern to extract the column name and the condition pattern = r"\((\w+)\s*==\s*([\d\.]+\s*-\s*[\d\.]+)\)" match = re.search(pattern, query) if match: column_name = match.group(1) condition = match.group(2) # Constructing the new query new_query = f"({column_name} == '{condition}')" return new_query else: # Returning the old query return query
[docs] class SMART(BaseModel): llm: AzureOpenAI config: dict verbose: bool = True _subgroups: Optional[Dict] = PrivateAttr(default=None) # Hypotheses are stored as a string _hypotheses: Optional[str] = PrivateAttr(default=None) _updated_hypotheses: Optional[str] = PrivateAttr(default=None) # Get context and context targets as strings context: Optional[str] = None context_target: Optional[str] = None optimal_queries: Optional[Dict] = None task: Optional[str] = None _selfrefine_steps: Dict[int, str] = PrivateAttr(default_factory=dict) _subgroup_cache: Dict[str, Dict] = PrivateAttr(default_factory=dict) _unique_values: Optional[Dict] = PrivateAttr(default=None)
[docs] class Config: arbitrary_types_allowed = True
def _get_llm_response(self, input_text, system_message=None, metadata_output=False, modelid=None): if self.verbose: print("----------INPUT TEXT --------------") print(input_text) if system_message is None: # LLM response with/without a system message response = self.llm.chat.completions.create( model=self.config["engine"], messages=[{"role": "user", "content": input_text}], temperature=self.config["temperature"], # seed=self.config['seed'], ) else: # Get the response from the LLM with a system message response = self.llm.chat.completions.create( model=self.config["engine"], messages=[{"role": "system", "content": system_message}, {"role": "user", "content": input_text}], temperature=self.config["temperature"], seed=self.config["seed"], ) message = response.choices[0].message.content if self.verbose: print("----------LLM RESPONSE TEXT--------------") print(message) if metadata_output: metadata = { "tools": response.choices[0].message.tool_calls, "function calls": response.choices[0].message.function_call, } return message, metadata else: return message def _generate_cache_key(self, X: pd.DataFrame) -> str: """ Generates a unique cache key based on the DataFrame columns. """ column_string = ",".join(sorted(X.columns)) return hashlib.sha256(column_string.encode()).hexdigest()
[docs] def clear_cache(self): """Clears the subgroup cache.""" self._subgroup_cache.clear() if self.verbose: print("Cache cleared.")
def _get_unique_values(self, X, unique_threshold: int = 30) -> Dict[str, Any]: """ Parses through the dataset and returns the unique values for each column. """ unique_values = {} for col in X.columns: if len(X[col].unique()) <= unique_threshold: unique_values[col] = list(X[col].unique()) else: if X[col].dtype in ["int64", "float64"]: unique_values[col] = {"min": X[col].min(), "mean": X[col].mean(), "max": X[col].max()} else: unique_values[col] = "Too many unique values" if self.verbose: warnings.warn(f"Column {col} has too many unique values.", UserWarning) return unique_values
[docs] def fit( self, X: pd.DataFrame, context: Optional[str] = None, context_target: Optional[str] = None, n: int = 5, evaluate_feasibility=False, ): """Finds subgroups by generating hypotheses, operationalizing them, and summarizing the findings""" cache_key = self._generate_cache_key(X) # Check if the result is already cached if cache_key in self._subgroup_cache: self._subgroups = self._subgroup_cache[cache_key] print("Cached subgroups loaded.") return self unique_values = self._get_unique_values(X) # Save the unique values self._unique_values = unique_values task = self._construct_task(unique_values, context, context_target, n) # Update the context and context target self.context = context self.context_target = context_target self.task = task # Evaluate feasibility if evaluate_feasibility: feasibility_response = self._feasibility_check(unique_values, context, context_target) if feasibility_response.lower().strip() == "yes": # pyright: ignore print("Group discovery is possible. Discovering subgroups...") elif feasibility_response.lower().strip() == "no": # pyright: ignore print("No groups discovered") self._subgroups = {} self._subgroup_cache[cache_key] = self._subgroups return self else: print(f"The response from the feasibility status is: {feasibility_response.lower().strip()}") # pyright: ignore # Assuming that the task is feasible, generating hypotheses hypotheses = self._get_llm_response(task) self._hypotheses = hypotheses # pyright: ignore # Operationalizing the hypotheses operationalization_prompt = self._construct_operationalization_prompt( hypotheses, unique_values, context, context_target ) operationalizations = self._get_llm_response(operationalization_prompt) # Summarizing the findings summarization_prompt = self._construct_summarization_prompt(operationalizations, unique_values) summary_dict = self._get_llm_response(summarization_prompt) # Set regex pattern pattern = r"\{.*?\}" try: summary_dict = re.findall(pattern, summary_dict, re.DOTALL)[0] # pyright: ignore self._subgroups = ast.literal_eval(summary_dict) except Exception: correction_prompt = f"""The following is a dictionary that contains the subgroups. Return ONLY the dictionary with no additional text before or after. {summary_dict}""" if self.verbose: print(correction_prompt) response_correction = self._get_llm_response(correction_prompt) summary_dict = re.findall(pattern, response_correction, re.DOTALL)[0] # pyright: ignore self._subgroups = ast.literal_eval(summary_dict) # Adjust the subgroup queries self._adjust_subgroup_queries(X) # Cache the subgroup findings self._subgroup_cache[cache_key] = self._subgroups # pyright: ignore return self
[docs] def find_subgroup_variables( self, X: pd.DataFrame, context: Optional[str] = None, context_target: Optional[str] = None, n: int = 30 ): """Finds subgroups by generating hypotheses, operationalizing them, and summarizing the findings""" cache_key = self._generate_cache_key(X) # Check if the result is already cached if cache_key in self._subgroup_cache: self._subgroups = self._subgroup_cache[cache_key] print("Cached subgroups loaded.") return self unique_values = self._get_unique_values(X) # Save the unique values self._unique_values = unique_values task = self._construct_task_hypotheses(unique_values, context, context_target, n) # Update the context and context target self.context = context self.context_target = context_target self.task = task hypotheses = self._get_llm_response(task) self._hypotheses = hypotheses # pyright: ignore # Operationalizing the hypotheses operationalization_prompt = self._construct_operationalization_subgroups( hypotheses, unique_values, context, context_target ) operationalizations = self._get_llm_response(operationalization_prompt) # Set regex pattern pattern = r"\{.*?\}" try: summary_dict = re.findall(pattern, operationalizations, re.DOTALL)[0] # pyright: ignore self._subgroups = ast.literal_eval(summary_dict) # Loop and ensure all of the subgroups are lists. If not, convert to lists. for key, value in self._subgroups.items(): # pyright: ignore if not isinstance(value, list): self._subgroups[key] = [value] # pyright: ignore except Exception: correction_prompt = f"""The following is a dictionary that contains the subgroups. Return ONLY the dictionary with no additional text before or after. {operationalizations}""" if self.verbose: print(correction_prompt) response_correction = self._get_llm_response(correction_prompt) summary_dict = re.findall(pattern, response_correction, re.DOTALL)[0] # pyright: ignore self._subgroups = ast.literal_eval(summary_dict) # Loop and ensure all of the subgroups are lists. If not, convert to lists. for key, value in self._subgroups.items(): # pyright: ignore if not isinstance(value, list): self._subgroups[key] = [value] # pyright: ignore # Cache findings if not cached self._subgroup_cache[cache_key] = self._subgroups # pyright: ignore return self
[docs] def predict(self, X: pd.DataFrame) -> pd.DataFrame: """ Predicts group membership for each observation in the DataFrame. :param X: DataFrame containing the observations. :return: DataFrame with additional boolean columns indicating group membership. """ # Check if all column names in the dictionary conditions are valid valid_columns = set(X.columns) for group, condition in self._subgroups.items(): # pyright: ignore # Extract column names from the condition columns_in_condition = [word for word in condition.split() if word in valid_columns] if not columns_in_condition: warnings.warn(f"No valid columns found in condition for group {group}: {condition}", UserWarning) # TODO: Check if values in the dictionary exist in the DataFrame # Create group columns for group, condition in self._subgroups.items(): # pyright: ignore try: indices_condition = X.query(condition).index bool_condition = X.index.isin(indices_condition) X[f"group_{group}"] = bool_condition except Exception as e: warnings.warn(f"Failed to apply condition for group {group}: {condition}. Error: {e}", UserWarning) return X
@property def subgroups(self): """Return the identified subgroups""" return self._subgroups @property def hypotheses(self): """Return the hypotheses""" return self._hypotheses def _self_refine(self, unique_values, context, context_target, previous_response, system_message, n=3): """Self-refines an answer multiple times""" for iter_ in range(n): selfrefine_task = f""" The context is: {context} and the target variable is {context_target} with the following columns: {", ".join(unique_values.keys())}. \nPrevious answer: {previous_response}. \n\nTASK: Critically evaluate the answer below and then re-write it. Make sure to follow the instructions provided before. \n\n """ previous_response = self._get_llm_response(selfrefine_task, system_message=system_message) self._selfrefine_steps[iter_] = previous_response # pyright: ignore return previous_response def _feasibility_check(self, unique_values, context, context_target, system_message=None, n_refine=1): """Checks if the task is feasible. Logic: (1) perform a feasibility query; (2) self-refine the answer; (3) convert to boolean""" feasibility_task = f"""Your task is to evaluate whether it is reasonable to to search for subgroups where a predictive model which perform suboptimally. Given the following context about the dataset: {context} and the target variable: {context_target}, and the following columns: {", ".join(unique_values.keys())}, is it reasonable to search for societally meaningful subgroups? Write the reasons why yes, then why no, and provide an overall summary. """ if system_message is None: system_message = """You are an expert at clearly evaluating whether there is a direct connection between the covariates and the outcome variable. Your goal is to determine whether such a connection exists in academic literature or other sources. Avoid making ridiculous connections that are unlikely to hold in reality. Be critical. Focus on avoiding false positives (i.e. relationships that might not exist) because it is costly to test these assumptions and we might overfit the results. Avoid speculative or weak connections. Prioritize false negatives (missing connections) than false positives (offering weak connections that might not hold)""" # Evaluating the feasibility of the response if self.verbose: print("Evaluating feasibility of the response...") feasibility_response = self._get_llm_response(feasibility_task, system_message=system_message) if self.verbose: print("Self-refining answer...") # Refining the answer feasibility_response = self._self_refine( unique_values, context, context_target, feasibility_response, system_message, n=n_refine ) # Convert to boolean boolean_task = f"""Your task is to return an answer 'yes' or 'no' on whether it is worthwile to inspect subgroups, based on the response provided below. Answer: {feasibility_response} \n\nTASK: Answer whether it is worthwile to inspect subgroups, based on the response provided above. Answer: 'yes' or 'no'.""" feasibility_boolean_response = self._get_llm_response(boolean_task) return feasibility_boolean_response def _construct_task_hypotheses(self, unique_values, context, context_target, n): """Constructs the task description for the LLM.""" task = f"""Your task is to propose possible hypotheses as to which subgroups within the dataset might have worse predictive performance than on average because of societal bias in the dataset, insufficient data, other relationships, or others. The subgroups might be based on any of the provided characteristics, as well as on any combination of such characteristics. Dataset information: {context}. {context_target} The dataset contains {len(unique_values)} columns. The columns are {", ".join(unique_values.keys())}. Task: Create {n} hypotheses as to which subgroups within the dataset the model will perform worse than on average because of societal biases or other reasons. Important: Your hypothesis can contain either one variable or two variables in the condition. Therefore, your goal is to find discrepancies in the model's performance, not the underlying data outcomes. Justify why you think that for each of the {n} hypotheses. You must use this format: Hypothesis: <>; Justification: <>, with the hypothesis and justification on the same line separated by a ';'. e.g. Hypothesis 1: <Hypothesis>; Justification: <Justification> Hypothesis 2: <Hypothesis>; Justification: <Justification> """ return task def _construct_task(self, unique_values, context, context_target, n): """Constructs the task description for the LLM.""" task = f"""Your task is to propose possible hypotheses as to which subgroups within the dataset might have worse predictive performance than on average because of societal bias in the dataset, insufficient data, other relationships, or others. The subgroups might be based on any of the provided characteristics, as well as on any combination of such characteristics. Dataset information: {context}. {context_target} The dataset contains {len(unique_values)} columns. The columns are {", ".join(unique_values.keys())}. The values are {str(unique_values.items())} Task: Create {n} hypotheses as to which subgroups within the dataset the model will perform worse than on average because of societal biases or other reasons. Therefore, your goal is to find discrepancies in the model's performance, not the underlying data outcomes. Justify why you think that. You must use this format of the output: Hypothesis: <>; Justification: <>, with the hypothesis, justification, and operationalization on the same line separated by a ';'. e.g. Hypothesis 1: <Hypothesis>; Justification: <Justification> Hypothesis 2: <Hypothesis>; Justification: <Justification> """ return task def _construct_operationalization_prompt(self, hypotheses, unique_values, context, context_target): """Constructs the operationalization prompt for the LLM.""" operationalization_prompt = f""" The following are hypotheses about which people within a dataset the model might underperform on. Propose specific ranges for each hypothesis. Hypotheses: {hypotheses}. Dataset information: {context}. {context_target} The dataset contains {len(unique_values)} columns. The columns are {", ".join(unique_values.keys())}. The values are {str(unique_values.items())} TASK: Propose specific variable ranges for each hypothesis such that they are clearly operationalizable and defined. **Use the exact column names with the correct casing as they appear in the dataset**. Ensure that each Operationalization is a single-line expression without line breaks. You must use this format: Hypothesis: <>; Operationalization: <>, with the hypothesis and operationalization on the same line separated by a ';'. e.g. Hypothesis 1: <Hypothesis>; Operationalization: <Operationalization> Hypothesis 2: <Hypothesis>; Operationalization: <Operationalization> """ return operationalization_prompt def _construct_operationalization_subgroups(self, hypotheses, unique_values, context, context_target): """Constructs the operationalization prompt for the LLM.""" operationalization_prompt = f""" The following are hypotheses about which people within a dataset the model might underperform on. Propose specific ranges for each hypothesis. Hypotheses: {hypotheses}. TASK: return a dictionary that contains an index number as the key and the column value as the value. If there are multiple columns in that hypothesis, return them in a list. There are the column names: {", ".join(unique_values.keys())}. """ return operationalization_prompt def _construct_revised_operationalization_prompt(self, new_context, unique_values, context, context_target): """Constructs the operationalization prompt for the LLM.""" operationalization_prompt = f""" You have access to the following information. Dataset information: {context}. {context_target} The dataset contains {len(unique_values)} columns. The columns are {", ".join(unique_values.keys())}. The values are {str(unique_values.items())} However, you are no longer working with the same data as just described. Rather, this is the context: {new_context}. These are the hypotheses: {self._updated_hypotheses}. TASK: Propose specific variable ranges for each hypothesis such that they are clearly operationalizable and defined. You must use this format: Hypothesis: <>; Operationalization: <>, with the hypothesis and operationalization on the same line separated by a ';'. e.g. Hypothesis 1: <Hypothesis>; Operationalization: <Operationalization> Hypothesis 2: <Hypothesis>; Operationalization: <Operationalization> """ return operationalization_prompt def _construct_summarization_prompt(self, operationalizations, unique_values): """Constructs the summarization prompt for the LLM.""" summarization_prompt = f""" The following are groups that are defined based on the dataset. Convert them into a Python dictionary format. Each group should be represented as a key-value pair in the dictionary, where the key is an index (0 to 4), and the value is a string representing the group using Python syntax and logical operators. For multiple conditions, use Python's logical 'and' ('&&') or 'or' ('||'). Ensure the format is a valid Python dictionary. Examples: - Single Condition: {{0: 'X > 45'}} - Multiple Conditions: {{1: '(X > 45) and (Y < 20)'}} - String conditions: {{2: 'X == '45 - 60'}} Groups to summarize: {operationalizations} Column names: {", ".join(unique_values.keys())} Column values: {str(unique_values.items())} """ return summarization_prompt def _adjust_subgroup_queries(self, X: pd.DataFrame, n_subgroups=1): """Adjusts the subgroup queries if they do not exist in the dataset.""" unique_values = self._unique_values for group, condition in list(self._subgroups.items())[:n_subgroups]: # pyright: ignore try: # Check if the condition yields any rows if len(X.query(condition)) == 0: # Call LLM to adjust the condition adjustment_prompt = self._construct_adjustment_prompt(condition, unique_values) adjusted_condition = self._get_llm_response(adjustment_prompt) # Update the condition self._subgroups[group] = adjusted_condition # pyright: ignore print("Primary condition: ", condition) print("Adjusted condition: ", adjusted_condition) except Exception as e: warnings.warn(f"Error evaluating condition for group {group}: {condition}. Error: {e}", UserWarning) # Call LLM to adjust the condition print("Adjusting condition...") unique_values_adj = {k: v for k, v in unique_values.items() if k in condition} # pyright: ignore adjustment_prompt = self._construct_adjustment_prompt(condition, unique_values_adj) adjusted_condition = self._get_llm_response(adjustment_prompt) # TODO - make this part more robust. try: X.query(adjusted_condition) # pyright: ignore except Exception as e: adjusted_condition = clean_query_string(adjusted_condition) # Check if the condition has any strings assuming it is a single condition if "and" not in adjusted_condition: adjusted_condition = convert_to_string_condition(adjusted_condition) print("Primary condition: ", condition) print("Adjusted condition: ", adjusted_condition) # Update the condition self._subgroups[group] = adjusted_condition # pyright: ignore def _construct_adjustment_prompt(self, condition, unique_values): """ Constructs the prompt for adjusting a subgroup condition. """ adjustment_prompt = f""" The following pandas query does not match any rows in the dataset: '{condition}'. This condition uses a column and a value to filter values, but the data types are incorrect. Adjust the condition using the datasets values and columns, such that the condition would work on the unique values in the dataset, and the condition would be as close as possible to the original one. Unique values/statistics of relevant columns: {str(unique_values)} TASK: Provide an adjusted dataframe query that uses the specific unique values in the dataset which would not throw an error and would closely match the original condition. Provide ONLY the query. Query: """ return adjustment_prompt
[docs] def extract_hypotheses_and_justifications(self): """ Extracts hypotheses and their justifications from the provided text and organizes them into a pandas DataFrame. Each hypothesis and its justification are in separate columns. The number of rows corresponds to the number of hypotheses. Returns: pd.DataFrame: A DataFrame with 'Hypothesis', 'Justification', and 'Operationalization' columns. """ text = self._hypotheses subgroups = self.subgroups if not text: print("No hypotheses found in '_hypotheses'. Ensure that 'fit' has been called.") return pd.DataFrame(columns=["Hypothesis", "Justification", "Operationalization"]) # Split the text into lines lines = text.split("\n") # Lists to store hypotheses, justifications, and operationalizations hypotheses = [] justifications = [] operationalizations = [] # Initialize a counter for operationalizations subgroup_keys = list(subgroups.keys()) # pyright: ignore subgroup_index = 0 # Process each line to extract hypothesis and justification for line_number, line in enumerate(lines, start=1): line = line.strip() if line.startswith("Hypothesis"): hypothesis_part, justification_part = line.split("; ", 1) # Extract text after 'Hypothesis: ' and 'Justification: ' if present hypothesis = hypothesis_part.split(": ", 1)[1] if ": " in hypothesis_part else hypothesis_part justification = ( justification_part.split(": ", 1)[1] if ": " in justification_part else justification_part ) hypotheses.append(hypothesis) justifications.append(justification) # Assign operationalization if available if subgroup_index < len(subgroup_keys): subgroup_key = subgroup_keys[subgroup_index] operationalization = subgroups[subgroup_key] # pyright: ignore operationalizations.append(operationalization) subgroup_index += 1 else: # If there are more hypotheses than subgroups, assign None or a default operationalizations.append(None) # Create a DataFrame df = pd.DataFrame( {"Hypothesis": hypotheses, "Justification": justifications, "Operationalization": operationalizations} ) # Optional: Check for alignment between hypotheses and operationalizations if len(df) != len(subgroup_keys): print("Number of hypotheses does not match number of subgroups.") return df
[docs] def generate_model_report( self, X_train, y_train, X_test, y_test, model, keys_calculate=[ "group_size", "support", "p_value_bootstrap", "num_criteria", "outcome_diff", "accuracy_diff", "odds_ratio_outcome", "odds_ratio_acc", "lift_outcome", "lift_acc", "weighted_relative_outcome", "weighted_relative_accuracy", ], ): """Currenty supported only for the subgroup_finder without the self-falsification mechanism""" table_summary = self.extract_hypotheses_and_justifications() table_summary_train = table_summary.copy() for oper in table_summary_train["Operationalization"]: for key in keys_calculate: table_summary_train.loc[table_summary_train["Operationalization"] == oper, key] = ( calculate_group_statistics(X_train, y_train, model, oper)[key] ) table_summary_test = table_summary.copy() for oper in table_summary_test["Operationalization"]: for key in keys_calculate: table_summary_test.loc[table_summary_test["Operationalization"] == oper, key] = ( calculate_group_statistics(X_test, y_test, model, oper)[key] ) input_text = f""" The following is the context: {self.context}. The following is the target context: {self.context_target}. The following is a table summarizing the information about the results on the training dataset: {table_summary_train}. The following is a table summarizing the information about the results on the test dataset: {table_summary_test}. TASK: Write recommendations to the user based on the results. Answer these questions: \n 1. When does the model fail and when it is reliable? \n 2. What should the end user be aware of before deploying the model? Keep your recommendations short, brief, and actionable. Avoid repeating information which has already been said. """ response_recommendations = self._get_llm_response(input_text) return response_recommendations
[docs] def revise_hypotheses(self, new_context: str) -> str: """ Revises the existing hypotheses based on a new context. :param new_context: A string representing the new context to consider for revising hypotheses. :return: A string containing the set of new hypotheses. """ if self._hypotheses is None: raise ValueError("No existing hypotheses to revise. Please run 'fit' method first.") # Constructing the new prompt revise_prompt = f""" The original context for the dataset was: {self.context} with the target variable {self.context_target}. Based on this, the following hypotheses were generated: {self._hypotheses}. The subgroups identified were: {str(self._subgroups)} Now, a new context has emerged: {new_context}. TASK: Considering both the original and the new context, revise the earlier hypotheses. Generate a new set of hypotheses instead of the old hypotheses that take into account any changes or additional information provided by the new context. Ensure that these hypotheses are relevant and applicable to the updated scenario. Assume access to the same data as before. """ # Getting the response from LLM new_hypotheses = self._get_llm_response(revise_prompt) # Revise the hypotheses if self.verbose: print("Revising hypotheses...") new_hypotheses = self._self_refine( self._unique_values, new_context, self.context_target, new_hypotheses, system_message=revise_prompt, n=2 ) # Updating the hypotheses attribute self._updated_hypotheses = new_hypotheses # pyright: ignore return new_hypotheses # pyright: ignore
[docs] def revise_fit(self, new_context, X): unique_values = self._unique_values # Operationalizing the hypotheses operationalization_prompt = self._construct_revised_operationalization_prompt( new_context, unique_values, self.context, self.context_target ) operationalizations = self._get_llm_response(operationalization_prompt) # Summarizing the findings summarization_prompt = self._construct_summarization_prompt(operationalizations, unique_values) summary_dict = self._get_llm_response(summarization_prompt) # Set regex pattern pattern = r"\{.*?\}" try: summary_dict = re.findall(pattern, summary_dict, re.DOTALL)[0] # pyright: ignore subgroups = ast.literal_eval(summary_dict) except Exception: correction_prompt = f"""The following is a dictionary that contains the subgroups. Return ONLY the dictionary with no additional text before or after. Return an empty dictionary if none exists. \n {summary_dict}""" if self.verbose: print(correction_prompt) response_correction = self._get_llm_response(correction_prompt) summary_dict = re.findall(pattern, response_correction, re.DOTALL)[0] # pyright: ignore subgroups = ast.literal_eval(summary_dict) return subgroups
[docs] def get_optimal_split_query( self, dataframe, features, outcome, min_group_size=10, test_for_min=True, max_group_size=float("inf") ): """ Generates a query string for splitting the dataframe into two subgroups where the difference in the outcome variable is maximized, based on up to three features. :param dataframe: A pandas DataFrame containing the data. :param features: A list of feature variable names (up to 3 features). :param outcome: The name of the outcome variable. :param min_group_size: The minimum size of each group. :return: A query string for the subgroup where the outcome is minimized. """ if not all(feature in dataframe.columns for feature in features): raise ValueError("All features must be present in the dataframe") if outcome not in dataframe.columns: raise ValueError("Outcome variable must be present in the dataframe") # Determine if the outcome variable is continuous or categorical if pd.api.types.is_numeric_dtype(dataframe[outcome]): tree_model = DecisionTreeRegressor(max_depth=len(features)) else: tree_model = DecisionTreeClassifier(max_depth=len(features)) # Fit the model tree_model.fit(dataframe[features], dataframe[outcome]) # Function to recursively traverse the tree and find the optimal split def traverse_tree(node=0, depth=0, conditions=[]): if ( tree_model.tree_.children_left[node] == tree_model.tree_.children_right[node] # pyright: ignore ): # Leaf node if not conditions: # Check for empty conditions return None, float("-inf") # Evaluate split left_indices = dataframe.query(" and ".join(conditions)).index right_indices = dataframe.index.difference(left_indices) if test_for_min: if len(left_indices) < min_group_size or len(right_indices) < min_group_size: return None, float("-inf") else: if len(left_indices) < max_group_size and len(left_indices) > min_group_size: left_mean = dataframe.loc[left_indices, outcome].mean() right_mean = dataframe.loc[right_indices, outcome].mean() discrepancy = abs(left_mean - right_mean) if left_mean >= right_mean: return " and ".join(conditions), discrepancy else: conditions = [ cond.replace("<=", ">") if "<=" in cond else cond.replace(">", "<=") for cond in conditions ] return " and ".join(conditions), discrepancy elif len(right_indices) < max_group_size and len(right_indices) > min_group_size: right_mean = dataframe.loc[right_indices, outcome].mean() left_mean = dataframe.loc[left_indices, outcome].mean() discrepancy = abs(left_mean - right_mean) if right_mean <= left_mean: conditions = [ cond.replace("<=", ">") if "<=" in cond else cond.replace(">", "<=") for cond in conditions ] return " and ".join(conditions), discrepancy else: return " and ".join(conditions), discrepancy else: return None, float("-inf") left_mean = dataframe.loc[left_indices, outcome].mean() right_mean = dataframe.loc[right_indices, outcome].mean() discrepancy = abs(left_mean - right_mean) if left_mean >= right_mean: return " and ".join(conditions), discrepancy else: conditions = [ cond.replace("<=", ">") if "<=" in cond else cond.replace(">", "<=") for cond in conditions ] return " and ".join(conditions), discrepancy # Not a leaf node, continue splitting feature = features[tree_model.tree_.feature[node]] # pyright: ignore threshold = tree_model.tree_.threshold[node] # pyright: ignore left_condition = f"{feature} <= {threshold}" right_condition = f"{feature} > {threshold}" # Traverse left and right left_query, left_discrepancy = traverse_tree( tree_model.tree_.children_left[node], # pyright: ignore depth + 1, conditions + [left_condition], ) right_query, right_discrepancy = traverse_tree( tree_model.tree_.children_right[node], # pyright: ignore depth + 1, conditions + [right_condition], ) if left_discrepancy == right_discrepancy: return left_query, left_discrepancy else: return ( (left_query, left_discrepancy) if left_discrepancy > right_discrepancy else (right_query, right_discrepancy) ) return traverse_tree()[0]
[docs] def get_optimal_queries_strings(self, X, y, model, min_group_size=10, n_groups=10, alpha=0.05): optimal_queries = {} # Fit the model ohe = OneHotEncoder(sparse=False) X_dummies = pd.DataFrame(ohe.fit_transform(X), columns=ohe.get_feature_names_out()) X_dummies.index = X.index model.fit(X_dummies, y) for group_id, variables in self.subgroups.items(): # pyright: ignore max_difference = float("-inf") optimal_query = None # Check if there is only one variable if len(variables) == 1: var = variables[0] # Generate all non-empty combinations of unique values for the single variable value_combinations = itertools.chain.from_iterable( itertools.combinations(X[var].unique(), r) for r in range(1, len(X[var].unique()) + 1) ) else: # Generate all relevant combinations for each variable variable_combinations = [generate_combinations_for_variable(X[var].unique()) for var in variables] # Generate combinations of these combinations across variables value_combinations = itertools.product(*variable_combinations) # Iterate over each combination for combo in value_combinations: query_parts = [] if len(variables) == 1: var = variables[0] if len(combo) == 1: query_parts.append(f"{var} == '{combo[0]}'") else: query_parts.append(f"{var} in {combo}") else: for var, vals in zip(variables, combo): if len(vals) == 1: query_parts.append(f"{var} == '{vals[0]}'") else: query_parts.append(f"{var} in {vals}") query = " and ".join(query_parts) # Filter the dataframe based on the query and calculate accuracy difference subgroup_X = X.query(query) if len(subgroup_X) >= min_group_size: subgroup_y = y[subgroup_X.index] subgroup_X_ohe = pd.DataFrame(ohe.transform(subgroup_X), columns=ohe.get_feature_names_out()) # pyright: ignore subgroup_X_ohe.index = subgroup_X.index accuracy_diff = self.calculate_accuracy_difference(X_dummies, y, model, subgroup_X_ohe, subgroup_y) if accuracy_diff > max_difference: max_difference = accuracy_diff optimal_query = query if optimal_query: optimal_queries[group_id] = optimal_query # Filter top subgroup_results = {} for group, condition in optimal_queries.items(): try: results_group = calculate_group_statistics_string(X, y, model, condition, ohe) significant_result = results_group["p_value_bootstrap"] < alpha subgroup_results[group] = {"results": results_group, "significant": significant_result} except Exception as e: print(f"Error with group {group}: {condition}", e) continue # Order subgroups by p-value subgroup_results = sorted( subgroup_results.items(), key=lambda x: x[1]["results"]["accuracy_diff"], reverse=True ) # Get the top n subgroups # Filter based on significant to only include when it is significant # subgroup_results = [subgroup for subgroup in subgroup_results if subgroup[1]['significant']] top_subgroups = subgroup_results[:n_groups] # Get only the queries of the top n subgroups top_queries = [query[1]["results"]["query"] for query in top_subgroups] # Convert to dictionary top_queries = {i: query for i, query in enumerate(top_queries)} return top_queries
[docs] def calculate_outcome_difference(self, y, full_y): """ Calculates the difference in the proportion of the most common outcome between the subgroup and the full dataset. :param y: The outcome variable for the subgroup. :param full_y: The outcome variable for the full dataset. :return: The difference in proportions. """ if y.empty or full_y.empty: return 0 # Get the most common outcome in the full dataset most_common_outcome = full_y.mode()[0] # Calculate the proportion of this outcome in both the subgroup and the full dataset subgroup_proportion = (y == most_common_outcome).mean() full_dataset_proportion = (full_y == most_common_outcome).mean() # Calculate the difference in proportions difference = abs(subgroup_proportion - full_dataset_proportion) return difference
[docs] def get_optimal_queries( self, X, y, model, outcome="y_failures", min_group_size=10, alpha=0.1, n_groups=10, test_for_min=True, max_group_size=float("inf"), ): """ Generates a list of query strings for splitting the dataframe into two subgroups where the difference in the outcome variable is maximized, based on up to three features. :param dataframe: A pandas DataFrame containing the data. :param features: A list of feature variable names (up to 3 features). :param outcome: The name of the outcome variable. :param min_group_size: The minimum size of each group. :param n_queries: The number of queries to generate. :return: A list of query strings for the subgroup where the outcome is maximized. """ dataframe = X.copy() # Calculate model failures y_pred = model.predict(X) y_failures = (y_pred != y).astype(int) dataframe[outcome] = y_failures optimal_queries = {} # Get groups group_variables = self.subgroups # Loop through the groups for group, condition in group_variables.items(): # pyright: ignore # Get the optimal query for each group optimal_query = self.get_optimal_split_query( dataframe, condition, outcome, min_group_size, test_for_min, max_group_size ) optimal_queries[group] = optimal_query self.optimal_queries = optimal_queries # Remove values which are None and which repeat themselves optimal_queries = {k: v for k, v in optimal_queries.items() if v is not None} # Loop over each query and if the query already exists, skip. Otherwise, add it to a new dictionary optimal_queries_unique = {} for group, query in optimal_queries.items(): if query not in optimal_queries_unique.values(): optimal_queries_unique[group] = query optimal_queries = optimal_queries_unique subgroup_results = {} for group, condition in optimal_queries.items(): try: results_group = calculate_group_statistics(X, y, model, condition) significant_result = results_group["p_value_bootstrap"] < alpha subgroup_results[group] = {"results": results_group, "significant": significant_result} except Exception as e: print(f"Error with group {group}: {condition}", e) continue # Order subgroups by p-value subgroup_results = sorted( subgroup_results.items(), key=lambda x: x[1]["results"]["accuracy_diff"], reverse=True ) # Get the top n subgroups top_subgroups = subgroup_results[:n_groups] # Get only the queries of the top n subgroups top_queries = [query[1]["results"]["query"] for query in top_subgroups] # Convert to dictionary top_queries = {i: query for i, query in enumerate(top_queries)} return top_queries
[docs] def calculate_accuracy_difference(self, X_tr, y, model, subgroup_X_tr, subgroup_y): """ Calculates the accuracy difference between the model's predictions on the full dataset and a specific subgroup. :param X_tr: Transformed predictor variables of the full dataset. :param y: Outcome variable of the full dataset. :param model: Trained model to make predictions. :param subgroup_X_tr: Transformed predictor variables of the subgroup. :param subgroup_y: Outcome variable of the subgroup. :return: The accuracy difference. """ try: # Calculate accuracy on the full dataset and the subgroup accuracy_dataset = accuracy_score(y, model.predict(pd.get_dummies(X_tr, drop_first=True))) accuracy_subgroup = accuracy_score(subgroup_y, model.predict(subgroup_X_tr)) except Exception as e: # Log the exception if needed print("Error in accuracy calculation:", e) return 0 # Compute the accuracy difference accuracy_diff = abs(accuracy_dataset - accuracy_subgroup) return accuracy_diff