Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions src/pcstable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@
class PCStable:
graph = nx.Graph()

def __init__(self, data_path, method = 'chisq', independence_threshold = 0.01) -> None:
def __init__(self, data, method = 'chisq', independence_threshold = 0.01) -> None:
if data == None:
raise ValueError('data cannot be None, must be an instance of ConditionalIndependence or a path to training csv file')
if type(data) == str:
self.ci = ci(data, method)
elif type(data) == ci:
self.ci = data
# initialise empty fully connected graph
self.ci = ci(data_path, method)
self.variables = self.ci.rand_vars
self.graph = PCStable.fully_connected_graph(self.variables, self.graph)
self.immoralNodes = []
Expand Down Expand Up @@ -63,6 +68,24 @@ def create_valid_path_set(self, paths) -> list[Edge]:
edges.append(Edge(v_i, v_j, parents, conditional_independence_test=self.ci, threshold = self.independence_threshold))
return edges

def copy(self):
'''
Method to copy the PCStable object for temporary use
'''
copy = PCStable(
data = self.ci,
method = self.ci.test,
independence_threshold = self.independence_threshold
)

copy.variables = self.ci.rand_vars
copy.graph = PCStable.fully_connected_graph(self.variables, self.graph)
copy.immoralNodes = self.immoralNodes
copy.markovChains = self.markovChains
copy.all_nodes = self.all_nodes

return copy

# def independence_test(self):
# # TODO: SOME_CONDITION while the independence test continues to keep running

Expand Down