From c77dc4b849f6a7063d51c86d4ade25d7d4a57bad Mon Sep 17 00:00:00 2001
From: Andreas Makus <makus.andreas@gmail.com>
Date: Tue, 12 Jul 2016 17:05:30 +0200
Subject: [PATCH] finished class method that adds KB graphs to the
 MDOproblem instance; added helper functions that create edge tuples and add
 node attributes; created new module MDOgraph with super class Graph for the
 other graph classes (MCG, FPG, PSG); added graph-contraction method to the
 super class; added Graph subclass MCG; added method that generates the graph
 object for the MCG
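
A rough usage sketch of the new classes (module paths and the KB folder name
are taken from this patch; the contraction level value is only an example):

    from pyKADMOS.MDOproblem import MdoProblemInit
    from pyKADMOS.MDOgraph import MCG

    kb = MdoProblemInit('KB_CPACS')              # load, validate and graph the knowledge base
    mcg = MCG(kb).get_graph(contractionLevel=2)  # compose the (contracted) function graphs into the MCG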

---
 pyKADMOS/KB_CPACS/EMWET-info.json   |    4 +-
 pyKADMOS/KB_CPACS/Q3D_FLC-info.json |    4 +-
 pyKADMOS/MDOgraph.py                |   81 ++
 pyKADMOS/MDOproblem.py              | 1289 +++++++++++++++++++++++----
 pyKADMOS/testRun.py                 |   14 +-
 5 files changed, 1226 insertions(+), 166 deletions(-)
 create mode 100644 pyKADMOS/MDOgraph.py

diff --git a/pyKADMOS/KB_CPACS/EMWET-info.json b/pyKADMOS/KB_CPACS/EMWET-info.json
index dd820133a..afffdd9d0 100644
--- a/pyKADMOS/KB_CPACS/EMWET-info.json
+++ b/pyKADMOS/KB_CPACS/EMWET-info.json
@@ -3,8 +3,6 @@
                   "version":1.0,
                   "creator":"Ali Elham",
                   "description": "provide description here."},
- "execution_info":{  "run time (s)":20,
-                      "fidelity level":"L1",
-                      "precision":0.05}
+ "execution_info":{ "runTime":20, "fidelity":"L1", "precision":0.05}
 }
 
diff --git a/pyKADMOS/KB_CPACS/Q3D_FLC-info.json b/pyKADMOS/KB_CPACS/Q3D_FLC-info.json
index 509c67504..27e6afe77 100644
--- a/pyKADMOS/KB_CPACS/Q3D_FLC-info.json
+++ b/pyKADMOS/KB_CPACS/Q3D_FLC-info.json
@@ -3,8 +3,6 @@
                   "version":1.0,
                   "creator":"Mengmeng Zhang",
                   "description": "In this case Q3D only performs the inviscid VLM analysis in order to provide the loads per wing strip in an aeroDataSetForLoads associated with the flightLoadCase. These loads can then be used for further analysis with other tools. A tool that fits well to this analysis is TUD's wing weight estimation tool EMWET."},
- "execution_info":{  "run time (s)":20,
-                      "fidelity level":"L1",
-                      "precision":0.05}
+ "execution_info":{ "runTime":20, "fidelity":"L1", "precision":0.05}
 }
 
diff --git a/pyKADMOS/MDOgraph.py b/pyKADMOS/MDOgraph.py
new file mode 100644
index 000000000..e53f7c743
--- /dev/null
+++ b/pyKADMOS/MDOgraph.py
@@ -0,0 +1,81 @@
+import networkx as nx
+from pyKADMOS.MDOproblem import MdoProblemInit
+
+class Graph(object):
+    """
+    Super class for graph subclasses (MCG, FPG, PSG), contains common functions.
+    """
+
+    def __init__(self, knowledgeBase):
+        self.knowledgeBase = knowledgeBase
+
+        assert isinstance(knowledgeBase, MdoProblemInit), 'The argument used to instantiate a graph must be an "MdoProblemInit" instance.'
+
+    def get_contracted_graph(self, graph, contractionLevel):
+        """
+        Function to contract the nodes of a graph to a given xpath level.
+
+        :param graph: graph to be contracted
+        :param contractionLevel: int from 0 (highest level) to X (lowest level existing in XML schema)
+        :return: graph with contracted nodes
+        """
+        # check input
+        assert (contractionLevel >= 0) and isinstance(contractionLevel, (int, long)), "Contraction level should be a non-negative integer."
+
+        # create clean copy of graph
+        contracted_graph = nx.DiGraph(graph)
+
+        # iterate over each graph node and contract nodes where required.
+        for node, data in contracted_graph.nodes_iter(data=True):
+            if data['level'] > contractionLevel:
+
+                # Find the ancestor node at the required contraction level; split xpath at separator
+                split_xpath = node.split('/')[1:] # remove first entry (empty since the string starts with '/')
+
+                # Create xpath of required node
+                required_node = '/' + '/'.join(split_xpath[0:contractionLevel + 1])
+
+                # Add node if not in contracted graph
+                if required_node not in contracted_graph:
+                    contracted_graph.add_node(required_node,
+                                              shape = 'd',
+                                              category = 'variable group',
+                                              label = split_xpath[contractionLevel],
+                                              level = contractionLevel)
+
+                # Contract the node into its ancestor at the required level
+                contracted_graph = nx.contracted_nodes(contracted_graph, required_node, node, self_loops=True)
+
+        #TODO: include rules to manage a higher contraction level
+        return contracted_graph
+
+class MCG(Graph):
+
+    def __init__(self, knowledgeBase):
+        super(self.__class__, self).__init__(knowledgeBase)
+
+    def get_graph(self, contractionLevel = None):
+        """
+        Function to create Maximal Connectivity Graph (Pate, 2014) by composing a list of graphs.
+
+        :return: maximal connectivity graph (MCG)
+        """
+        functionGraphs = self.knowledgeBase.functionGraphs
+        
+        # if contraction level provided, reduce function graphs to that level
+        if contractionLevel is not None:
+            conGraphs = [] # initialize contracted graph list
+            for graph in functionGraphs:
+                conGraph = self.get_contracted_graph(graph, contractionLevel)
+                conGraphs.append(conGraph)
+            functionGraphs = conGraphs # replace function graphs with contracted function graphs
+
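+        # nx.compose merges the node and edge sets of the graphs, so xpath nodes shared between
+        # functions become the connection points of the MCG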
+        MCG = nx.DiGraph()  # initialize MCG
+        for g in functionGraphs: # compose all (possibly contracted) function graphs
+            MCG = nx.compose(MCG, g)
+
+        return MCG
+
+
+
+
diff --git a/pyKADMOS/MDOproblem.py b/pyKADMOS/MDOproblem.py
index 1b2fc7744..4bdc17de7 100644
--- a/pyKADMOS/MDOproblem.py
+++ b/pyKADMOS/MDOproblem.py
@@ -20,27 +20,31 @@ class MDOproblem:
         """
         Standard class __init__ function. Includes input checks.
 
+
+
         :param knowledge_base: name of the folder that contains the knowledge base
         """
-        self.knowledge_base = knowledge_base
+        # TODO: ADD DETAILED CLASS DESCRIPTION
+
+        self.knowledgeBase = knowledge_base
 
         # Hardcoded Values
-        self.FIXED = {} # uppercase because hardcoded
-        self.FIXED["ALLOWED_FUNCTION_TYPES"] = ["Tool", "Objective", "Constraint"]
-        self.FIXED["GEN_INFO"] = ["name", "type", "version", "creator", "description"]
-        self.FIXED["EXEC_INFO"] = ["runTime", "fidelity", "precision"]
+        self.FIXEDVAL = {} # uppercase because hardcoded
+        self.FIXEDVAL["ALLOWED_FUNCTION_TYPES"] = ["Tool", "Objective", "Constraint"]
+        self.FIXEDVAL["GEN_INFO"] = ["name", "type", "version", "creator", "description"]
+        self.FIXEDVAL["EXEC_INFO"] = ["runTime", "fidelity", "precision"]
 
         print "\n INPUT CHECKS \n ------------"
 
         # Check if knowledge base folder exists
         if os.path.exists(knowledge_base):
-            print "Knowledge base '%s' found." % self.knowledge_base
+            print "Knowledge base '%s' found." % self.knowledgeBase
         else:
-            raise IOError("Specified knowledge base '%s' does not exist." % self.knowledge_base)
+            raise IOError("Specified knowledge base '%s' does not exist." % self.knowledgeBase)
 
         # Read files in the KB
         print "Reading files in the knowledge base."
-        KB_files = [f for f in os.listdir(self.knowledge_base) if isfile(join(self.knowledge_base, f))]
+        KB_files = [f for f in os.listdir(self.knowledgeBase) if isfile(join(self.knowledgeBase, f))]
 
         # Get the data schema in knowledge base, save in instance
         self._get_data_schema(KB_files)
@@ -48,40 +52,45 @@ class MDOproblem:
         # Get input and output files, save in instance
         self._get_in_out_files(KB_files)
 
-        if self.knowledge_base == "KB_CPACS":
+        if self.knowledgeBase == "KB_CPACS":
             # Get Read-Write File, save in instance
             self._get_base_file(KB_files)
 
             ignoreNodes = ['toolspecific']
             self._check_base_against_schema(ignoreNodes)
 
+            # validate all input and output files versus the base file
             self._validate_in_out_files()
 
-        self._get_function_data()
+            # get all data from the input and output files, and save in instance
+            self._get_function_data()
+
+            # get graphs for all functions and save them in instance
+            self._get_kb_graphs()
 
     def _get_data_schema(self, kb_files):
         """
         This function retrieves the data schema (.xsd) file from the KB folder and stores filename in instance.
         :param: kb_files
-        :return:  self.data_schema
+        :return:  self.dataSchema
         """
 
         # Determine name of XML Schema file
         xsd_schema_found = False
         for file_name in kb_files:
             if file_name[-4:] == '.xsd' and not xsd_schema_found:
-                self.data_schema = file_name
+                self.dataSchema = file_name
                 xsd_schema_found = True
             elif file_name[-4:] == '.xsd' and xsd_schema_found:
                 raise IOError('Multiple XML Schemas (.xsd files) found in the knowledge base (%s). '
                               'Only one .xsd file is allowed per knowledge base.'
-                              % self.knowledge_base)
+                              % self.knowledgeBase)
         if not xsd_schema_found:
             raise IOError('No XML Schemas (.xsd files) found in the knowledge base (%s). '
                           'A single .xsd file is required per knowledge base.'
-                          % self.knowledge_base)
+                          % self.knowledgeBase)
         else:
-            print "XML Schema '%s' found." % self.data_schema
+            print "XML Schema '%s' found." % self.dataSchema
 
         return
 
@@ -89,25 +98,25 @@ class MDOproblem:
         """
         This function saves all files in class instance. It ensures that all names
         :param: kb_files
-        :return: self.function_files
+        :return: self.functionFiles
         """
 
         # Read input and output XML files and info json files
-        self.function_files = dict(input=[], output=[], info=[])
-        self.function_files['input'] = [file for file in kb_files if file[-10:] == '-input.xml']
-        self.function_files['output'] = [file for file in kb_files if file[-11:] == '-output.xml']
-        self.function_files['info'] = [file for file in kb_files if file[-10:] == '-info.json']
-        print "Input files found: %s" % self.function_files['input']
-        print "Output files found: %s" % self.function_files['output']
-        print "Info files found: %s" % self.function_files['info']
-
-        assert len(self.function_files['input']) == len(self.function_files['output']), 'Amount of function input and output XML files does not match.'
-        assert len(self.function_files['input']) == len(self.function_files['info']), 'Amount of function input XML files and info json files does not match.'
+        self.functionFiles = dict(input=[], output=[], info=[])
+        self.functionFiles['input'] = [file for file in kb_files if file[-10:] == '-input.xml']
+        self.functionFiles['output'] = [file for file in kb_files if file[-11:] == '-output.xml']
+        self.functionFiles['info'] = [file for file in kb_files if file[-10:] == '-info.json']
+        print "Input files found: %s" % self.functionFiles['input']
+        print "Output files found: %s" % self.functionFiles['output']
+        print "Info files found: %s" % self.functionFiles['info']
+
+        assert len(self.functionFiles['input']) == len(self.functionFiles['output']), 'Amount of function input and output XML files does not match.'
+        assert len(self.functionFiles['input']) == len(self.functionFiles['info']), 'Amount of function input XML files and info json files does not match.'
         print 'Amount of function input, output, and info files accepted.'
 
          # check info files for name and type TODO: move this into a test function
-        for file in self.function_files["info"]:
-            with open(self.knowledge_base + '/' + file) as info:
+        for file in self.functionFiles["info"]:
+            with open(self.knowledgeBase + '/' + file) as info:
                 infoData = json.load(info)
 
         # name assertion
@@ -121,12 +130,12 @@ class MDOproblem:
             # typeCond = False
             # if not isinstance(infoData["general_info"]["type"], basestring):
             #     raise TypeError("Function type in {} must be a string.".format(file))
-            # for funcType in self.FIXED["ALLOWED_FUNCTION_TYPES"]:
+            # for funcType in self.FIXEDVAL["ALLOWED_FUNCTION_TYPES"]:
             #     if infoData["general_info"]["type"].lower() == funcType.lower():
             #         typeCond = True
             #         break
             # if not typeCond:
-            #     raise TypeError("Function type in {} must be one of these: ".format(file) + ", ".join(self.FIXED["ALLOWED_FUNCTION_TYPES"]))
+            #     raise TypeError("Function type in {} must be one of these: ".format(file) + ", ".join(self.FIXEDVAL["ALLOWED_FUNCTION_TYPES"]))
 
         # Additional checks
         # TODO: Add checks on input given (naming conventions, required data, etcetera).
@@ -174,11 +183,11 @@ class MDOproblem:
         """
 
         # Parse the XML Schema
-        xmlschema_doc = etree.parse(self.knowledge_base + '/' + self.data_schema)
+        xmlschema_doc = etree.parse(self.knowledgeBase + '/' + self.dataSchema)
         xmlschema = etree.XMLSchema(xmlschema_doc)
 
         # Parse the Read-Write File
-        tree = etree.parse(self.knowledge_base + '/' + self.baseFile)
+        tree = etree.parse(self.knowledgeBase + '/' + self.baseFile)
 
         if ignoreNodes is not None:
 
@@ -211,20 +220,23 @@ class MDOproblem:
     def _validate_in_out_files(self):
         """
         Class method that validates all present input and output XML files by comparing each child node for
-        equvalence in base file.
+        "equivalence" in base file. The base file is scanned for the input and output child nodes of similar tag,
+        and matching child nodes' ancestors are compared for tag and UID attribute. Only if the child node with
+        matching ancestors is found in the base file is the node valid. Only when all nodes are validated
+        is the file valid.
 
         :return: IOError
         """
 
         leafNodesMissing = {}
-        baseTree = etree.parse(self.knowledge_base + '/' + self.baseFile)
+        baseTree = etree.parse(self.knowledgeBase + '/' + self.baseFile)
 
         print "Validating input and output XML files..."
 
         fileType = ['input', 'output']  #
         for type in fileType:
-            for xml_file in self.function_files[type]:
-                fileTree = etree.parse(self.knowledge_base + '/' + xml_file)
+            for xml_file in self.functionFiles[type]:
+                fileTree = etree.parse(self.knowledgeBase + '/' + xml_file)
                 for el in fileTree.iter():
                     if not el.getchildren():
                         nodeIsEquivalent = False
@@ -233,7 +245,7 @@ class MDOproblem:
                         if foundNodes:
                             for elem in foundNodes:
                                 # compare all ancestor tags and attributes
-                                nodeIsEquivalent = self._ensureEquivalentAncestors(el, elem)
+                                nodeIsEquivalent = self._ensure_equivalent_ancestors(el, elem)
                                 if nodeIsEquivalent:
                                     break
 
@@ -243,10 +255,11 @@ class MDOproblem:
                                 leafNodesMissing[xml_file] = []
                             leafNodesMissing[xml_file].append(fileTree.getpath(el))
 
-        self._printNodes(leafNodesMissing)
+        self._print_nodes(leafNodesMissing)
         assert len(leafNodesMissing) == 0, "There are missing nodes in the base file!"
+        print "Validation of input and output XML file is complete. Everything in order."
 
-    def _ensureEquivalentAncestors(self, treeNode, baseNode):
+    def _ensure_equivalent_ancestors(self, treeNode, baseNode):
         """
         Class method that compares all ancestors for two given nodes for tag and attribute equivalence.
         Nodes must ElementTree objects.
@@ -274,18 +287,18 @@ class MDOproblem:
         eq = True
         return eq
 
-    def _printNodes(self, obj):
+    def _print_nodes(self, obj):
         if type(obj) == dict:
             for k, v in obj.items():
                 if hasattr(v, '__iter__'):
                     print k
-                    self._printNodes(v)
+                    self._print_nodes(v)
                 else:
                     print '%s : %s' % (k, v)
         elif type(obj) == list:
             for v in obj:
                 if hasattr(v, '__iter__'):
-                    self._printNodes(v)
+                    self._print_nodes(v)
                 else:
                     print v
         else:
@@ -322,34 +335,31 @@ class MDOproblem:
 
         self.functionData = []
 
-        for file in self.function_files["info"]:
+        for file in self.functionFiles["info"]:
             # initiate a dict for each function
             funcDict = {'info':{   'generalInfo':{},
                                     'executionInfo': {}
-                                }
-                        }
-
-            # TODO: enforce naming conventions for general_info and execution_info; now only matching names included!
+                                }}
 
-            with open(self.knowledge_base + '/' + file) as info:
+            with open(self.knowledgeBase + '/' + file) as info:
                 infoData = json.load(info)
 
-            for inf in self.FIXED["GEN_INFO"]:
-                try:
-                    infoVal = infoData["general_info"].get(inf)
-                    if infoVal is not None:
-                        funcDict['info']['generalInfo'][inf] = infoVal
+            # add function info from file to funcDict
+            for inf in self.FIXEDVAL["GEN_INFO"]:
+                if inf in ['name', 'type']:
+                    # make sure that function name and type is defined, is string
+                    assert isinstance(infoData["general_info"].get(inf), basestring), "Function name and type must be defined in the info-file, and must be a string!"
+                # add info if given
+                try: funcDict['info']['generalInfo'][inf] = infoData["general_info"].get(inf)
                 except KeyError: continue
 
-            for inf in self.FIXED["EXEC_INFO"]:
-                try:
-                    infoVal =infoData["execution_info"].get(inf)
-                    if infoVal is not None:
-                        funcDict['info']['executionInfo'][inf] = infoVal
+            for inf in self.FIXEDVAL["EXEC_INFO"]:
+                try: funcDict['info']['executionInfo'][inf] = infoData["execution_info"].get(inf)
                 except KeyError: continue
 
+            # TODO: enforce naming conventions for general_info and execution_info; important for later use! All tests should be done here
 
-            # get input file
+            # get input and output data
             funcName = file[:-10] # slice -info.json
             funcDict['input'] = self._get_in_out_data(funcName, 'input')
             funcDict['output'] = self._get_in_out_data(funcName, 'output')
@@ -368,7 +378,7 @@ class MDOproblem:
         """
         dic = {"allXpaths":[], "leafNodes": []}
 
-        tree = etree.parse(self.knowledge_base + '/' + funcName + "-" + inOut + ".xml")
+        tree = etree.parse(self.knowledgeBase + '/' + funcName + "-" + inOut + ".xml")
 
         for el in tree.iter():
             d1, d2 = {}, {}
@@ -376,7 +386,7 @@ class MDOproblem:
             d1['xpath'] = path
             d1['tag'] = el.tag
             d1['attributes'] = el.attrib
-            if not el.getchildren():
+            if not el.getchildren(): # leaf node (no children)
                 d2['xpath'] = path
                 d2['tag'] = el.tag
                 d2['attributes'] = el.attrib
@@ -391,6 +401,118 @@ class MDOproblem:
 
         return dic
 
+    def _get_kb_graphs(self):
+        """
+        This class method generates all graphs for all present functions in the knowledge base.
+
+        :return: self.functionGraphs
+        """
+        funcList = [self.functionData[i]['info']["generalInfo"]['name'] for i in range(len(self.functionData))]
+        graphList = []
+
+        for func in funcList:
+            graphList.append(self._get_function_graph(func))
+
+        self.functionGraphs = graphList
+        return
+
+    def _get_function_graph(self, funcName, inOut=None):
+        """
+        This function builds a directed graph (object) for the specified function using the "networkx" package. If inOut
+        argument is specified, only the input or output of the function will be included in the graph, otherwise both.
+
+        :param: funcName: function name for which the graph is generated; must be present in knowledge base.
+        :param: inOut: default = None; if specified, must be "input" or "output" string. Specification of this argument enables the generation of the function graph with only input or output variables.
+        :return: functionGraph
+        """
+        assert isinstance(funcName, basestring), 'Provided function name must be a string!'
+
+        # assert funcName exists and get index of function in self.functionData list
+        funcIndex = None
+        for idx, funcDict in enumerate(self.functionData):
+            if funcDict['info']['generalInfo']['name'] == funcName:
+                funcIndex = idx #funcIndex is index of the function in list
+                break
+        assert funcIndex is not None, "The provided function name can not be found in knowledge base."
+
+        # assert inOut, if defined, is string and either input or output
+        if inOut is not None:
+            assert isinstance(inOut, basestring), "inOut argument must be a string if specified."
+            assert inOut.lower() in ["input", "output"], "inOut argument must be either 'input' or 'output'."
+
+        # initiate directed graph and list of edges
+        DG, edges = nx.DiGraph(), []
+
+        # add edges to list, then to graph
+        if inOut is not None:
+            edges += self._create_edge_tuples(funcIndex, inOut)
+        else:
+            for io in ['input', 'output']:
+                edges += self._create_edge_tuples(funcIndex, io)
+        DG.add_edges_from(edges)
+
+        # add node attributes to graph
+        self._add_node_attribs(funcIndex, DG)
+
+        return DG
+
+    def _create_edge_tuples(self, funcIndex, inOut):
+        """
+        This function creates a list of edge tuples in order to generate a graph.
+
+        :param funcIndex: index of function in list of tool dicts in self.functionData
+        :param inOut: 'input' or 'output'; determines the direction of the created edges
+        :return: graphEdges: list of edges to build graph
+        """
+        graphEdges = []
+        fdata = self.functionData[funcIndex]
+        funcName = fdata['info']['generalInfo']['name']
+
+        for leafNode in fdata[inOut.lower()]['leafNodes']:
+            if inOut == 'input':
+                tpl = (leafNode['xpath'], funcName)  # var --> tool
+            else:
+                tpl = (funcName, leafNode['xpath'])  # tool --> var
+            graphEdges.append(tpl)
+        # to attach additional information to the edges, adjust this loop (see the networkx documentation)
+
+        return graphEdges
+
+    def _add_node_attribs(self, funcIndex, G):
+        """
+        Function that adds node attributes to the nodes of the graph.
+
+        :param funcIndex: index of function in list of tool dicts in self.functionData
+        :param G: graph without attributes
+        :return: graph with node attributes added
+        """
+
+        fdata = self.functionData[funcIndex]
+        funcName = fdata['info']['generalInfo']['name']
+
+        for node in G.nodes_iter():
+
+            if node == funcName:
+                G.node[node]['shape'] = 's'  # square for functions
+                G.node[node]['category'] = 'function' # TODO: this can be adjusted. Maybe differentiate between tool, constraint, objective etc? This info can then be used later on!
+                G.node[node]['label'] = funcName
+                G.node[node]['level'] = None
+                try: # add available execution info to function node
+                    for inf in fdata['info']['executionInfo']:
+                        G.node[node][inf] = fdata['info']['executionInfo'][inf]
+                except KeyError: pass # if not present, continue
+
+            else:
+                G.node[node]['shape'] = 'o'  # circle for variables
+                G.node[node]['category'] = 'variable'
+                G.node[node]['label'] = node.split('/')[-1]
+                G.node[node]['level'] = node.count('/') - 1
+                G.node[node]['execution time'] = 1 # TODO: Why is execution time included in variable info?
+
+        return G
+
+    # TODO: >>>>>>>>>>>>>>>>>> Cut the class here ?? <<<<<<<<<<<<<<<<<<<<<<
+
     def get_function_names(self):
         # TODO: This function has to be re-written since the function names are not present in the in-out xml files!
         """
@@ -543,94 +665,6 @@ class MDOproblem:
             idx += 1
         return common_nodes
 
-    def _get_function_graph(self, functionName, inOut=None):
-        """
-        This function builds a directed graph (object) for the specified function using the "networkx" package.
-
-        :param: functionName: function name for which the graph is generated; must be present in knowledge base.
-        :param: inOut: default = None; if specified, must be "input" or "output" string. Specification of this argument enables the generation of the function graph with only inout or output variables.
-        :return: functionGraph
-        """
-        assert isinstance(functionName, basestring), 'Provided function name must be a string!'
-
-        # assert funcName exists and get index of function in self.functionData list
-        funcIndex = None
-        for idx, funcDict in enumerate(self.functionData):
-            if funcDict['info']['generalInfo']['name'] == functionName:
-                funcIndex = idx #funcIndex is index of the function in list
-                break
-        assert funcIndex is not None, "The provided function name can not be found in knowledge base."
-
-        # assert inOut, if defined, is string and either input or output
-        if inOut is not None:
-            assert isinstance(inOut, basestring), "inOut argument must be a string if specified."
-            assert inOut.lower() in ["input", "output"], "inOut argument must be either 'input' or 'output'."
-
-        # initiate directed graph and list of edges
-        DG, edges = nx.DiGraph(), []
-
-        # add edges to list, then to graph
-        if inOut is not None:
-            edges += self._create_edge_tuples(funcIndex, inOut, functionName)
-        else:
-            for io in ['input', 'output']:
-                edges += self._create_edge_tuples(funcIndex, io, functionName)
-        DG.add_edges_from(edges)
-
-        # add node attributes to graph
-        self._add_node_attribs(DG, funcIndex)
-
-        return DG
-
-    def _create_edge_tuples(self, funcIndex, inOut, functionName):
-        """
-        This function creates a list of edge tuples in order to generate a graph.
-
-        :param funcIndex: index of function in list of tool dicts in self.functionData
-        :param inOut: specified whether input or output nodes, None adds all to graph
-        :return: graphEdges: list of edges to build graph
-        """
-        graphEdges = []
-        fdata = self.functionData[funcIndex]
-        for leafNode in fdata[inOut.lower()]['leafNodes']:
-            if inOut == 'input':
-                tpl = (leafNode['xpath'], functionName)  # (variable, tool)
-            else:
-                tpl = (functionName, leafNode['xpath'])  # (tool, variable)
-            graphEdges.append(tpl)
-        # TODO: adjust these loops to include additional information in edges! check doc on networkx!
-
-        return graphEdges
-
-    def _add_node_attribs(self, G, funcIndex):
-        """
-        Function that add node attributes to the nodes of the graph.
-
-        :param G: Considered graph
-        :param funcIndex: index of function in list of tool dicts in self.functionData
-        :return: Graph
-        """
-        # TODO >>>>>> CONITNUE HERE <<<<<<<<
-        for node, data in G.nodes_iter(data=True):
-            if node == function_input_analysis['properties'][0]['attributes']['tool_name']: # replace this, rest looks okay!
-                G.node[node]['shape'] = 's'  # square for functions
-                G.node[node]['category'] = 'function'
-                G.node[node]['label'] = function_input_analysis['properties'][0]['attributes']['tool_name']
-                G.node[node]['level'] = None
-                with open(self.knowledge_base + '/' + json_file_info) as data_file:
-                    if 'execution time' in G.node[node]:
-                        G.node[node]['execution time'] = int(1000 * json.load(data_file)['executing_info']
-                        ['run time (s)'])
-                    else:
-                        G.node[node]['execution time'] = 1
-            else:
-                G.node[node]['shape'] = 'o'  # circle for variables
-                G.node[node]['category'] = 'variable'
-                G.node[node]['label'] = node.split('/')[-1]
-                G.node[node]['level'] = node.count('/') - 1
-                G.node[node]['execution time'] = 1
-        return G
-
     def get_function_graph(self, function_name):
         """
         Function to automatically create the digraph of the function element.
@@ -685,8 +719,6 @@ class MDOproblem:
             graph_list.append(self.get_function_graph(function_name))
         return graph_list
 
-    # TODO: >>>>>>>>>>>>>>>>>> Cut the class here <<<<<<<<<<<<<<<<<<<<<<
-
     def get_MCG(self):
         """
         Function to create Maximal Connectivity Graph (Pate, 2014) by composing a list of graphs.
@@ -1254,8 +1286,957 @@ class MDOproblem:
             PSG_data.add_edge(node,'Optimizer')
 
         return {'data flow':PSG_data,'process flow':PSG_process}
+    
+class MdoProblemInit(object):
+    """
+    Class that can be used to formally specify an MDO problem and analyze it based on graph theoretical analyses.
+    """
+
+    # TODO: ADD DETAILED CLASS DESCRIPTION
+    def __init__(self, knowledge_base):
+        """
+        Standard class __init__ function. Includes input checks.
+
+        :param knowledge_base: name of the folder that contains the knowledge base
+        """
+
+        self.knowledgeBase = knowledge_base
+
+        # Hardcoded Values stored in dict
+        self.FIXEDVAL = {} # uppercase because hardcoded
+        self.FIXEDVAL["ALLOWED_FUNCTION_TYPES"] = ["Tool", "Objective", "Constraint"] # allowed function types
+        self.FIXEDVAL["GEN_INFO"] = ["name", "type", "version", "creator", "description"] # prescribed general info
+        self.FIXEDVAL["EXEC_INFO"] = ["runTime", "fidelity", "precision"] # prescribed execution info
+        self.FIXEDVAL["IGNORE_VALID"] = ['toolspecific'] # nodes to ignore in basefile validation
+
+        print "\n INPUT CHECKS \n ------------"
+
+        # Check if knowledge base folder exists
+        if os.path.exists(knowledge_base):
+            print "Knowledge base '%s' found." % self.knowledgeBase
+        else:
+            raise IOError("Specified knowledge base '%s' does not exist." % self.knowledgeBase)
+
+        # Read files in the KB
+        print "Reading files in the knowledge base."
+        KB_files = [f for f in os.listdir(self.knowledgeBase) if isfile(join(self.knowledgeBase, f))]
+
+        # Get the data schema in knowledge base, save in instance
+        self._get_data_schema(KB_files)
+
+        # Get input and output files, save in instance
+        self._get_in_out_files(KB_files)
+
+        # Get Read-Write File, save in instance
+        self._get_base_file(KB_files)
+
+        # validate base file against provided xml schema
+        self._check_base_against_schema(self.FIXEDVAL["IGNORE_VALID"])
+
+        # validate all input and output files versus the base file
+        self._validate_in_out_files()
+
+        # get all data from the input and output files, and save in instance
+        self._get_function_data()
+
+        # get graphs for all functions and save them in instance
+        self._get_kb_graphs()
+
+
+    def _get_data_schema(self, kb_files):
+        """
+        This function retrieves the data schema (.xsd) file from the KB folder and stores filename in instance.
+        :param: kb_files
+        :return:  self.dataSchema
+        """
+
+        # Determine name of XML Schema file
+        xsd_schema_found = False
+        for file_name in kb_files:
+            if file_name[-4:] == '.xsd' and not xsd_schema_found:
+                self.dataSchema = file_name
+                xsd_schema_found = True
+            elif file_name[-4:] == '.xsd' and xsd_schema_found:
+                raise IOError('Multiple XML Schemas (.xsd files) found in the knowledge base (%s). '
+                              'Only one .xsd file is allowed per knowledge base.'
+                              % self.knowledgeBase)
+        if not xsd_schema_found:
+            raise IOError('No XML Schemas (.xsd files) found in the knowledge base (%s). '
+                          'A single .xsd file is required per knowledge base.'
+                          % self.knowledgeBase)
+        else:
+            print "XML Schema '%s' found." % self.dataSchema
+
+        return
+
+    def _get_in_out_files(self, kb_files):
+        """
+        This function stores the input, output, and info file names in the class instance and checks their consistency.
+        :param: kb_files
+        :return: self.functionFiles
+        """
+
+        # Read input and output XML files and info json files
+        self.functionFiles = dict(input=[], output=[], info=[])
+        self.functionFiles['input'] = [file for file in kb_files if file[-10:] == '-input.xml']
+        self.functionFiles['output'] = [file for file in kb_files if file[-11:] == '-output.xml']
+        self.functionFiles['info'] = [file for file in kb_files if file[-10:] == '-info.json']
+        print "Input files found: %s" % self.functionFiles['input']
+        print "Output files found: %s" % self.functionFiles['output']
+        print "Info files found: %s" % self.functionFiles['info']
+
+        assert len(self.functionFiles['input']) == len(self.functionFiles['output']), 'Amount of function input and output XML files does not match.'
+        assert len(self.functionFiles['input']) == len(self.functionFiles['info']), 'Amount of function input XML files and info json files does not match.'
+        print 'Amount of function input, output, and info files accepted.'
+
+         # check info files for name and type TODO: move this into a test function
+        for file in self.functionFiles["info"]:
+            with open(self.knowledgeBase + '/' + file) as info:
+                infoData = json.load(info)
+
+        # name assertion
+        if not isinstance(infoData["general_info"]["name"], basestring):
+            raise TypeError("Function name in {} must be a string.".format(file))
+        if len(infoData["general_info"]["name"]) < 1:
+            raise ValueError("Function name in {} must be non-empty string.".format(file))
 
+        # TODO: incorporate tool type for sellar and simple problems!
+            # # type assertion
+            # typeCond = False
+            # if not isinstance(infoData["general_info"]["type"], basestring):
+            #     raise TypeError("Function type in {} must be a string.".format(file))
+            # for funcType in self.FIXEDVAL["ALLOWED_FUNCTION_TYPES"]:
+            #     if infoData["general_info"]["type"].lower() == funcType.lower():
+            #         typeCond = True
+            #         break
+            # if not typeCond:
+            #     raise TypeError("Function type in {} must be one of these: ".format(file) + ", ".join(self.FIXEDVAL["ALLOWED_FUNCTION_TYPES"]))
+
+        # Additional checks
+        # TODO: Add checks on input given (naming conventions, required data, etcetera).
+        # TODO: Add check on unique tool names in input and output XMLs.
+        # TODO: Add check on completeness of files for each function name.
+
+        return
+
+    def _get_base_file(self, kb_files):
+        """
+        This function finds the CPACS base (read-write) file and saves it to instance
+        :return: self.baseFile
+        """
+
+        # define tool file pattern for name matching, save basefile to instance
+        basePattern = r"-base\.xml$"  # escape the dot so it only matches a literal '.'
+        self.baseFile = None
+        foundBaseFile = False
+        for file in kb_files:
+            matchObj = re.search(basePattern, file)
+            if matchObj and foundBaseFile == False:
+                self.baseFile = file
+                foundBaseFile = True
+                print "Base file {} found.".format(file)
+            elif matchObj and foundBaseFile == True:
+                raise IOError("Multiple '-base.xml' files found! Please ensure only one file present.")
+
+        assert self.baseFile is not None, "No '-base.xml' found! Please provide a '-base.xml' file."
+
+        return
+
+    def _check_base_against_schema(self, ignoreNodes=None):
+        """
+        Check the read-write XML file in the knowledge base against the XML Schema.
+        Argument is list/tuple of nodes to ignore in validation. Root node can not be ignored.
+
+        :param: ignoreNodes: iterable of nodes to be ignored in validation (must be list or tuple)
+        :rtype: Error
+        """
+        # TODO: This function must be re-written for a different purpose:
+        """
+        The problem is that the provided schema requires a certain structure of the XML, which makes it impossible to
+        minimize the amount of nodes required in the input/output XMLs. This function should only check the
+        read-write-XML, and a separate function will be written to "validate" the in-out XMLs.
+        """
+
+        # Parse the XML Schema
+        xmlschema_doc = etree.parse(self.knowledgeBase + '/' + self.dataSchema)
+        xmlschema = etree.XMLSchema(xmlschema_doc)
+
+        # Parse the Read-Write File
+        tree = etree.parse(self.knowledgeBase + '/' + self.baseFile)
+
+        if ignoreNodes is not None:
+
+            # making sure that input is iterable list or tuple, not basestring
+            if not isinstance(ignoreNodes, (list, tuple)):
+                raise IOError('Argument "ignoreNodes" not list or tuple.')
+
+            # Remove nodes that should not be validated
+            root = tree.getroot()
+            for ignoreNode in ignoreNodes:
+                for elem in root.iter():
+                    if (elem.tag == ignoreNode) & (elem.tag != root.tag): # make sure the root node cannot be removed
+                        parent = elem.getparent()
+                        parent.remove(elem)
+
+        # Validate XML file against the given schema
+        # TODO: Need to make sure that valid file is presented >> VERY IMPORTANT!!!
+        # xmlschema.assertValid(tree)
+        # print '\n XML files successfully validated against schema. \n'
+
+        # TODO: REMOVE THIS ONCE A VALID FILE IS PROVIDED!
+        baseFileValid = xmlschema.validate(tree)
+        if baseFileValid:
+            print 'The base file is valid!'
+        else:
+            print 'Could not validate base file!'
+
+        return
+
+    def _validate_in_out_files(self):
+        """
+        Class method that validates all present input and output XML files by comparing each leaf node for
+        equivalence in the base file.
+
+        :return: IOError
+        """
+
+        leafNodesMissing = {}
+        baseTree = etree.parse(self.knowledgeBase + '/' + self.baseFile)
+
+        print "Validating input and output XML files...",
+
+        fileType = ['input', 'output']  #
+        for type in fileType:
+            for xml_file in self.functionFiles[type]:
+                fileTree = etree.parse(self.knowledgeBase + '/' + xml_file)
+                for el in fileTree.iter():
+                    if not el.getchildren():
+                        nodeIsEquivalent = False
+                        findNode = './/' + str(el.tag)
+                        foundNodes = baseTree.findall(findNode)
+                        if foundNodes:
+                            for elem in foundNodes:
+                                # compare all ancestor tags and attributes
+                                nodeIsEquivalent = self._ensure_equivalent_ancestors(el, elem)
+                                if nodeIsEquivalent:
+                                    break
+
+                        # add missing nodes to dict by xml file
+                        if not nodeIsEquivalent:
+                            if not leafNodesMissing.get(xml_file):
+                                leafNodesMissing[xml_file] = []
+                            leafNodesMissing[xml_file].append(fileTree.getpath(el))
+
+        self._print_nodes(leafNodesMissing)
+        assert len(leafNodesMissing) == 0, "There are missing nodes in the base file!"
+        print "Complete."
+
+    def _ensure_equivalent_ancestors(self, treeNode, baseNode):
+        """
+        Class method that compares all ancestors for two given nodes for tag and attribute equivalence.
+        Nodes must be ElementTree objects.
+
+        :param treeNode
+        :param baseNode
+        :return True/False
+        """
+
+        eq = False
+
+        # check if ancestor count is the same
+        treeAnc = [i for i in baseNode.iterancestors()]
+        baseAnc = [i for i in treeNode.iterancestors()]
+        if len(treeAnc) != len(baseAnc):
+            return eq
+
+        # check if node tags and attributes of ancestors match; only 'uID' attribute is matched
+        for i in range(len(treeAnc)):
+            tagC = (baseAnc[i].tag == treeAnc[i].tag)
+            attC = (baseAnc[i].attrib.get('uID') == treeAnc[i].attrib.get('uID'))
+            if not tagC or not attC:
+                return eq
+
+        eq = True
+        return eq
+
+    def _print_nodes(self, obj):
+        if type(obj) == dict:
+            for k, v in obj.items():
+                if hasattr(v, '__iter__'):
+                    print k
+                    self._print_nodes(v)
+                else:
+                    print '%s : %s' % (k, v)
+        elif type(obj) == list:
+            for v in obj:
+                if hasattr(v, '__iter__'):
+                    self._print_nodes(v)
+                else:
+                    print v
+        else:
+            print obj
+            # TODO: include where this is from
+
+    def _get_function_data(self):
+        """"
+        This method adds a new attribute functionData to the class instance that contains all information in the knowledge base.
+        functionData =
+        [
+            {
+                "info": {
+                                "generalInfo": {"name": str, "type": str, "version": float, "creator": str, "description": str},
+                                "executionInfo": {"runTime": int, "fidelity": int, "precision": float}
+                        }
+                        ,
+                "input": 	{
+                                "allXpaths": 	[ {"xpath": str, "tag": str, "attributes": dict}, ... ], # list of all xpaths
+                                "leafNodes": 	[ {"xpath": str, "tag": str, "attributes": dict, "value": str, "level": int}, ...] # list of all leafNodes
+
+                            },
+                "output": 	{
+                                "allXpaths": 	[ {"xpath": str, "tag": str, "attributes": dict}, ... ], # list of all xpaths
+                                "leafNodes": 	[ {"xpath": str, "tag": str, "value": str, "level": int}, ...] # list of all leafNodes
+
+                            }
+            }, # tool1
+            ...
+        ]
+        :param
+        :return self.functionData
+        """
+
+        self.functionData = []
+
+        for file in self.functionFiles["info"]:
+            # initiate a dict for each function
+            funcDict = {'info':{   'generalInfo':{},
+                                    'executionInfo': {}
+                                }}
+
+            with open(self.knowledgeBase + '/' + file) as info:
+                infoData = json.load(info)
+
+            # add function info from file to funcDict
+            for inf in self.FIXEDVAL["GEN_INFO"]:
+                if inf in ['name', 'type']:
+                    # make sure that function name and type is defined, is string
+                    assert isinstance(infoData["general_info"].get(inf), basestring), "Function name and type must be defined in the info-file, and must be a string!"
+                # add info if given
+                try: funcDict['info']['generalInfo'][inf] = infoData["general_info"].get(inf)
+                except KeyError: continue
+
+            for inf in self.FIXEDVAL["EXEC_INFO"]:
+                try: funcDict['info']['executionInfo'][inf] = infoData["execution_info"].get(inf)
+                except KeyError: continue
+
+            # TODO: enforce naming conventions for general_info and execution_info; important for later use! All tests should be done here
+
+            # get input and output data
+            funcName = file[:-10] # slice -info.json
+            funcDict['input'] = self._get_in_out_data(funcName, 'input')
+            funcDict['output'] = self._get_in_out_data(funcName, 'output')
+
+            # add function dictionary to list of function data
+            self.functionData.append(funcDict)
+
+        return
+
+    def _get_in_out_data(self, funcName, inOut):
+        """
+        This function writes the data in the input and output files to a dictionary.
+        :param funcName:
+        :param inOut: must be "input" or "output"
+        :return: dict
+        """
+        dic = {"allXpaths":[], "leafNodes": []}
+
+        tree = etree.parse(self.knowledgeBase + '/' + funcName + "-" + inOut + ".xml")
+
+        for el in tree.iter():
+            d1, d2 = {}, {}
+            path = tree.getpath(el)
+            d1['xpath'] = path
+            d1['tag'] = el.tag
+            d1['attributes'] = el.attrib
+            if not el.getchildren(): # leaf node (no children)
+                d2['xpath'] = path
+                d2['tag'] = el.tag
+                d2['attributes'] = el.attrib
+                if el.text is not None:
+                    d2['value'] = el.text.strip()
+                else:
+                    d2['value'] = el.text # adding None if empty
+                d2['level'] = path.count('/') -1
+                dic['leafNodes'].append(d2)
+
+            dic['allXpaths'].append(d1)
+
+        return dic
+
+    # TODO:  >>>>>>>>>>>>>>>>>> Cut the class here, create graph class that contains all graphs?? <<<<<<<<<<<<<<<<<<<<<<
+
+    def _get_kb_graphs(self):
+        """
+        This class method generates all graphs for all present functions in the knowledge base.
+
+        :return: self.functionGraphs
+        """
+        funcList = [self.functionData[i]['info']["generalInfo"]['name'] for i in range(len(self.functionData))]
+        graphList = []
+
+        for func in funcList:
+            graphList.append(self._get_function_graph(func))
+
+        self.functionGraphs = graphList
+        return
+
+    def _get_function_graph(self, funcName, inOut=None):
+        """
+        This function builds a directed graph (object) for the specified function using the "networkx" package. If inOut
+        argument is specified, only the input or output of the function will be included in the graph, otherwise both.
+
+        :param: funcName: function name for which the graph is generated; must be present in knowledge base.
+        :param: inOut: default = None; if specified, must be "input" or "output" string. Specification of this argument enables the generation of the function graph with only input or output variables.
+        :return: functionGraph
+        """
+        assert isinstance(funcName, basestring), 'Provided function name must be a string!'
+
+        # assert funcName exists and get index of function in self.functionData list
+        funcIndex = None
+        for idx, funcDict in enumerate(self.functionData):
+            if funcDict['info']['generalInfo']['name'] == funcName:
+                funcIndex = idx #funcIndex is index of the function in list
+                break
+        assert funcIndex is not None, "The provided function name can not be found in knowledge base."
+
+        # assert inOut, if defined, is string and either input or output
+        if inOut is not None:
+            assert isinstance(inOut, basestring), "inOut argument must be a string if specified."
+            assert inOut.lower() in ["input", "output"], "inOut argument must be either 'input' or 'output'."
+
+        # initiate directed graph and list of edges
+        DG, edges = nx.DiGraph(), []
+
+        # add edges to list, then to graph
+        if inOut is not None:
+            edges += self._create_edge_tuples(funcIndex, inOut)
+        else:
+            for io in ['input', 'output']:
+                edges += self._create_edge_tuples(funcIndex, io)
+        DG.add_edges_from(edges)
+
+        # add node attributes to graph
+        self._add_node_attribs(funcIndex, DG)
+
+        return DG
+
+    def _create_edge_tuples(self, funcIndex, inOut):
+        """
+        This function creates a list of edge tuples in order to generate a graph.
+
+        :param funcIndex: index of function in list of tool dicts in self.functionData
+        :param inOut: 'input' or 'output'; determines the direction of the created edges
+        :return: graphEdges: list of edges to build graph
+        """
+        graphEdges = []
+        fdata = self.functionData[funcIndex]
+        funcName = fdata['info']['generalInfo']['name']
+
+        for leafNode in fdata[inOut.lower()]['leafNodes']:
+            if inOut == 'input':
+                tpl = (leafNode['xpath'], funcName)  # var --> tool
+            else:
+                tpl = (funcName, leafNode['xpath'])  # tool --> var
+            graphEdges.append(tpl)
+        # to attach additional information to the edges, adjust this loop (see the networkx documentation)
+
+        return graphEdges
+
+    def _add_node_attribs(self, funcIndex, G):
+        """
+        Function that adds node attributes to the nodes of the graph.
+
+        :param funcIndex: index of function in list of tool dicts in self.functionData
+        :param G: graph without attributes
+        :return: graph with node attributes added
+        """
+
+        fdata = self.functionData[funcIndex]
+        funcName = fdata['info']['generalInfo']['name']
+
+        for node in G.nodes_iter():
+
+            if node == funcName:
+                G.node[node]['shape'] = 's'  # square for functions
+                G.node[node]['category'] = 'function' # TODO: this can be adjusted. Maybe differentiate between tool, constraint, objective etc? This info can then be used later on!
+                G.node[node]['label'] = funcName
+                G.node[node]['level'] = None
+                try: # add available execution info to function node
+                    for inf in fdata['info']['executionInfo']:
+                        G.node[node][inf] = fdata['info']['executionInfo'][inf]
+                except KeyError: pass # if not present, continue
+
+            else:
+                G.node[node]['shape'] = 'o'  # circle for variables
+                G.node[node]['category'] = 'variable'
+                G.node[node]['label'] = node.split('/')[-1]
+                G.node[node]['level'] = node.count('/') - 1
+                G.node[node]['execution time'] = 1 # TODO: Why is execution time included in variable info?
+
+        return G
+
+
+    def get_MCG(self):
+        """
+        Function to create Maximal Connectivity Graph (Pate, 2014) by composing a list of graphs.
+
+        :return: maximal connectivity graph (MCG)
+        """
+        list_of_graphs = self.get_function_graphs()
+        MCG = list_of_graphs[0]
+        for i in range(1, len(list_of_graphs)):
+            MCG = nx.compose(MCG, list_of_graphs[i])
+        return MCG
+
+    def get_extended_MCG(self):
+        """
+        Function to create extended maximal connectivity graph (additional variable nodes for shared variables).
+
+        :return: extended maximal connectivity graph (MCG)
+        """
+
+        # Get required inputs
+        function_names = self.get_function_names()
+        common_nodes = self.get_common_nodes()
+        tool_function_graphs = self.get_function_graphs()
+
+        # Start with an empty DiGraph
+        ext_MCG = nx.DiGraph()
+
+        for idx, tool_name in enumerate(function_names):
+            ext_MCG = nx.union(ext_MCG, tool_function_graphs[idx], rename=(None, function_names[idx] + ':'))
+
+            for common_node in common_nodes:
+                for key in common_node['functions']:
+                    # check if tool name is the new tool in the graph
+                    if common_node['functions'][key]['tool_name'] == tool_name:
+                        if ext_MCG.has_node(common_node['xpath']) and \
+                                        common_node['functions'][key]['in_or_output'] == 'input' and \
+                                not any(common_node['functions'][key]['in_or_output'] == 'output' for key in
+                                        common_node['functions']):  # check if node already exists and if it's input
+                            ext_MCG.add_edge(common_node['xpath'],
+                                             common_node['functions'][key]['tool_name'] + ':' + common_node['xpath'])
+                        elif common_node['functions'][key]['in_or_output'] == 'output':
+                            # if tool has a common node output that is input for another tool,
+                            # check if input tool node is in graph and connect
+                            for tl in function_names:
+                                if ext_MCG.has_node(tl + ':' + common_node['xpath']) and not tl == tool_name:
+                                    ext_MCG.add_edge(tool_name + ':' + common_node['xpath'],
+                                                     tl + ':' + common_node['xpath'])
+                        elif common_node['functions'][key]['in_or_output'] == 'input' and \
+                                any(common_node['functions'][key]['in_or_output'] == 'output' for key in
+                                    common_node['functions']):
+                            # if tool has a common node input that is output of another tool,
+                            # check if output tool node is in graph and connect
+                            for tl in function_names:
+                                if ext_MCG.has_node(tl + ':' + common_node['xpath']) and not tl == tool_name:
+                                    ext_MCG.add_edge(tl + ':' + common_node['xpath'],
+                                                     tool_name + ':' + common_node['xpath'])
+                        elif common_node['functions'][key]['in_or_output'] == 'input' and \
+                                not any(common_node['functions'][key]['in_or_output'] == 'output' for key in
+                                        common_node['functions']):
+                            ext_MCG.add_node(common_node['xpath'],
+                                             label=common_node['tag'],
+                                             shape='o',
+                                             category='variable',
+                                             level=common_node['xpath'].count('/') - 1)
+                            ext_MCG.add_edge(common_node['xpath'],
+                                             common_node['functions'][key]['tool_name'] + ':' + common_node['xpath'])
+        return ext_MCG
+
+    def get_contracted_graph(self, graph, contraction_level):
+        """
+        Function to contract the nodes of a graph to a given xpath level.
+
+        :param graph: input graph
+        :param contraction_level: int from 0 (highest level) to X (lowest level existing in XML schema)
+        :return: graph with contracted nodes
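+
+        Example (illustrative sketch; ``problem`` is a hypothetical instance of this class)::
+
+            mcg = problem.get_MCG()
+            # contract all variable nodes deeper than xpath level 3 into 'variable group' nodes
+            contracted = problem.get_contracted_graph(mcg, 3)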
+        """
+
+        # Input checks
+        function_data = self.analyze_function_files()
+        max_con = len(function_data['sorted_xpaths']) - 1  # maximum contraction level value
+        assert (contraction_level <= max_con), \
+            "Contraction level {} is higher than maximum allowed value of {} for the given XML Schema." \
+                .format(contraction_level, max_con)
+        assert (contraction_level >= 0), "Contraction level should be a non-negative integer."
+
+        # Iterate over each graph node and contract nodes where required.
+        contracted_graph = nx.compose(nx.DiGraph(), graph)
+
+        for node, data in contracted_graph.nodes_iter(data=True):
+            if data['category'] == 'function':
+                pass
+            elif data['level'] <= contraction_level:
+                pass
+            elif data['level'] > contraction_level:
+                # Find higher level brother at required level
+                # Split node at separator character
+                split_xpath = node.split('/')
+                split_xpath = split_xpath[1:]  # remove first entry (this is empty since the string starts with '/')
+
+                # Create xpath of required node
+                required_node = '/' + '/'.join(split_xpath[0:contraction_level + 1])
+
+                # Check if the existing contracted graph has the required node
+                if not contracted_graph.has_node(required_node):
+                    contracted_graph.add_node(required_node,
+                                              shape='d',
+                                              category='variable group',
+                                              label=split_xpath[contraction_level],
+                                              level=contraction_level)
+
+                # Contract node with its higher level brother
+                contracted_graph = nx.contracted_nodes(contracted_graph,
+                                                       required_node,
+                                                       node, self_loops=True)
+        return contracted_graph
+
+    def get_FPG_based_on_sink(self, sink):
+        """
+        Function to get the Fundamental Problem Graph based on the required output variable.
+
+        :param sink: node name of desired variable
+        :return: Fundamental problem graph (FPG) object
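+
+        Example (illustrative sketch; ``problem`` and the sink xpath are hypothetical)::
+
+            FPG = problem.get_FPG_based_on_sink('/pathTo/desired/outputVariable')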
+        """
+        MCG = self.get_MCG()
+        MCG.graph['sinks'] = sink
+        ancestors = nx.ancestors(MCG, sink)
+        ancestors.add(sink)
+        return MCG.subgraph(ancestors)
+
+    def get_FPG_based_on_sinks(self, list_of_sinks):
+        """
+        Function to get the Fundamental Problem Graph based on a list of sinks / required output variables.
+
+        :param list_of_sinks: list with strings that specify the desired output
+        :return: Fundamental Problem Graph (FPG) object
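+
+        Example (illustrative sketch; ``problem`` and the sink xpaths are hypothetical)::
+
+            FPG = problem.get_FPG_based_on_sinks(['/pathTo/output1', '/pathTo/output2'])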
+        """
+        MCG = self.get_MCG()
+        FPG = nx.DiGraph(sinks=list_of_sinks)
+        for sink in list_of_sinks:
+            ancestors = nx.ancestors(MCG, sink)
+            ancestors.add(sink)
+            FPG_sink = MCG.subgraph(ancestors)
+            FPG = nx.compose(FPG, FPG_sink)
+        return FPG
+
+    def get_partitioned_graph(self, G, n_parts, tpwgts=None, recursive=False, contig=False, output='DiGraph'):
+        """
+        Partition a graph using the Metis algorithm (http://glaros.dtc.umn.edu/gkhome/metis/metis/overview). Note that
+        partitioning can only be performed on undirected graphs. Therefore every graph input is translated into an
+        undirected graph.
+
+        :param G: graph object
+        :param n_parts: number of partitions requested (the algorithm might provide fewer)
+        :param tpwgts: list of target partition weights
+        :param recursive: Metis option
+        :param contig: Metis option
+        :param output: set whether expected output is a DiGraph or normal Graph
+        :return: graph ('DiGraph' or 'Graph', as requested) with the partition id and partition color stored as node attributes
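+
+        Example (illustrative sketch; ``problem`` is a hypothetical instance, metis bindings must be available)::
+
+            MCG = problem.get_MCG()
+            partitioned = problem.get_partitioned_graph(MCG, 3)
+            # every node of the returned graph now carries 'part_id' and 'part_color' attributes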
+        """
+        G_und = nx.Graph(G)  # make graph undirected for partitioning
+        color_list = MDOvisualization.color_list()
+        if n_parts > len(color_list):
+            raise ValueError('Maximum number of partitions is {}. {} partitions have been specified.'
+                             .format(len(color_list), n_parts))
+        (edgecuts, parts) = metis.part_graph(G_und, n_parts, tpwgts=tpwgts, recursive=recursive, contig=contig)
+
+        # Store partition colors
+        colors = color_list[0:n_parts]
+        if output == 'DiGraph':
+            G_out = nx.DiGraph(G)
+        elif output == 'Graph':
+            G_out = nx.Graph(G)
+        else:
+            raise ValueError("Invalid graph output ({}) specified. Only 'DiGraph' or 'Graph' are allowed inputs."
+                             .format(output))
+        for i, node in enumerate(G_und.nodes_iter()):
+            G_out.node[node]['part_color'] = colors[parts[i]]
+            G_out.node[node]['part_id'] = parts[i]
+        return G_out
+
+    def get_PSG_for_MDF(self, FPG, MDA_type, analysis_order):
+        """
+        Create the PSG graph for the MDF method.
+
+        :param FPG: fundamental problem graph with required node properties
+        :type FPG: DiGraph
+        :param MDA_type: type of multidisciplinary analysis to be implemented ('Gauss-Seidel' or 'Jacobi'; currently only 'Gauss-Seidel' is handled)
+        :type MDA_type: str
+        :param analysis_order: list with the order of the analyses in the MDA
+        :type analysis_order: list
+        :return: dictionary with the PSG process flow and the PSG data flow
+        :rtype: dict
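+
+        Example (illustrative sketch; ``problem``, the FPG preparation and the tool names are hypothetical)::
+
+            FPG = problem.get_FPG_based_on_sinks(list_of_sinks)
+            # ...mark objective, constraints and design variables via the 'PSG role' node attribute...
+            PSG = problem.get_PSG_for_MDF(FPG, 'Gauss-Seidel', ['ToolA', 'ToolB'])
+            PSG_data, PSG_process = PSG['data flow'], PSG['process flow']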
+        """
+        # ------------------#
+        #  FPG+input check  #
+        # ------------------#
+        # Make clean copy of FPG to avoid attribute updates
+        FPG = nx.DiGraph(FPG)
+        # Find all function nodes
+        function_nodes = set(find_all_nodes(FPG, attr_cond=['category', '==', 'function']))
+
+        # Select design variables, parameters, constraint and objective functions
+        des_var_nodes = find_all_nodes(FPG, attr_cond=['PSG role', '==', 'design variable'])
+        assert len(des_var_nodes) > 0, "No design variables are specified. Use the 'PSG role' attribute for this."
+        parameter_nodes = list(set(find_all_nodes(FPG, category='variable', subcategory='all inputs'))
+                               .difference(set(des_var_nodes)))
+        constraint_nodes = find_all_nodes(FPG, attr_cond=['PSG role', '==', 'constraint'])
+        assert len(constraint_nodes) > 0, \
+            "No constraint variables are specified. Use the 'PSG role' attribute for this."
+
+        objective_node = find_all_nodes(FPG, attr_cond=['PSG role', '==', 'objective'])
+        assert len(objective_node) == 1, "%d objective nodes are specified. Only one objective node is allowed. " \
+                                         "Use the 'PSG role' attribute for this." % len(objective_node)
+        optimizer_nodes_in = objective_node + constraint_nodes
+        constraint_functions = list()
+        for idx, node in enumerate(objective_node + constraint_nodes):
+            assert FPG.node[node]['indegree'] == 1, "Invalid indegree of %d, while it should be 1." \
+                                                    % FPG.node[node]['indegree']
+            assert FPG.node[node]['outdegree'] == 0, "Invalid outdegree of %d, while it should be 0." \
+                                                     % FPG.node[node]['outdegree']
+            if idx == 0:
+                objective_function = FPG.in_edges(node)[0][0]
+            elif not (FPG.in_edges(node)[0][0] in set(constraint_functions)):
+                constraint_functions.append(FPG.in_edges(node)[0][0])
+        optimizer_functions = [objective_function] + constraint_functions
+
+        # Select analysis order functions
+        for node in analysis_order:
+            assert node in function_nodes, \
+                "One of the names ('%s') in the analysis_order input is invalid." % node
+
+        # Remove the objective function and constraint functions from the set
+        MDA_analysis_nodes = function_nodes.difference(set(optimizer_functions))
+
+        # Check if any functions are left between analysis nodes and analysis order
+        assert len(MDA_analysis_nodes.difference(set(analysis_order))) == 0, \
+            "There are undefined functions present in the FPG, namely %s! These should be added to the analysis order" \
+            " or become objective/constraint functions." % MDA_analysis_nodes.difference(set(analysis_order))
+
+        # ------------------#
+        # PSG process flow  #
+        # ------------------#
+        # Set up PSG process graph
+        PSG_process = nx.DiGraph()
+        PSG_process.graph['architecture'] = 'MDF'
+        PSG_process.graph['number_of_diagonal_blocks'] = 3 + len(analysis_order) + 1 + len(constraint_functions)
+        PSG_process.graph['number_of_MDA_analyses'] = len(analysis_order)
+        PSG_process.graph['number_of_OPT_functions'] = 1 + len(constraint_functions)
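+
+        # Note on the diagonal layout used below: the Initiator is placed at position 0, the
+        # Optimizer at 1, the MDA converger at 2, the MDA analyses at 3..2+n, and the
+        # objective/constraint functions after the analyses.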
+
+        # Add MDA block
+        PSG_process.add_node('MDA',
+                             category='architecture element',
+                             subcategory='MDA',
+                             shape='8',
+                             label='MDA',
+                             level=None,
+                             diagonal_position=2,
+                             process_step=2,
+                             converger_step=3+len(analysis_order))
+
+        # Connect MDA + analyses
+        if MDA_type == 'Gauss-Seidel':
+            from_node = 'MDA'
+            for idx, node in enumerate(analysis_order):
+                PSG_process.add_node(node, FPG.node[node], diagonal_position=3+idx)
+                PSG_process.node[node]['category'] = 'architecture element'
+                PSG_process.node[node]['subcategory'] = 'MDA analysis'
+                PSG_process.node[node]['process_step'] = idx+3
+                PSG_process.add_edge(from_node, node, process_step=idx+3)
+                from_node = node
+            PSG_process.add_edge(from_node, 'MDA', process_step=idx+4)
+
+        # Add optimization block
+        PSG_process.add_node('Optimizer',
+                             category='architecture element',
+                             subcategory='optimizer',
+                             shape='8',
+                             label='OPT',
+                             level=None,
+                             diagonal_position=1,
+                             process_step=1,
+                             converger_step=3+len(analysis_order)+2)
+
+        # Connect optimization with MDA
+        PSG_process.add_edge('Optimizer', 'MDA', process_step=2)
+
+        # Connect MDA with functions and functions with optimizer
+        for idx, node in enumerate(optimizer_functions):
+            PSG_process.add_node(node, FPG.node[node], diagonal_position=3+len(analysis_order)+idx)
+            PSG_process.node[node]['category'] = 'architecture element'
+            PSG_process.node[node]['subcategory'] = 'optimizer function'
+            PSG_process.node[node]['process_step'] = 4+len(analysis_order)
+            PSG_process.add_edge('MDA', node, process_step=4+len(analysis_order))
+            PSG_process.add_edge(node, 'Optimizer', process_step=5+len(analysis_order))
+
+        # Add Initiator block
+        PSG_process.add_node('Initiator',
+                             category='architecture element',
+                             subcategory='initiator',
+                             shape='8',
+                             label='INI',
+                             level=None,
+                             diagonal_position=0,
+                             process_step=0,
+                             converger_step=6+len(analysis_order))
+        # Connect initiator with optimizer
+        PSG_process.add_edge('Initiator', 'Optimizer', process_step=1)
+        PSG_process.add_edge('Optimizer', 'Initiator', process_step=6+len(analysis_order))
+
+        # ------------------#
+        #   PSG data flow   #
+        # ------------------#
+        # Set up PSG data graph (start from a clean copy of the FPG)
+        PSG_data = nx.compose(nx.DiGraph(), FPG)
+        PSG_data.graph['architecture'] = 'MDF'
+        PSG_data.graph['number_of_diagonal_blocks'] = 3 + len(analysis_order) + 1 + len(constraint_functions)
+        PSG_data.graph['number_of_MDA_analyses'] = len(analysis_order)
+        PSG_data.graph['number_of_OPT_functions'] = 1 + len(constraint_functions)
+
+        # Add MDA block
+        PSG_data.add_node('MDA',
+                          category='architecture element',
+                          subcategory='MDA',
+                          shape='8',
+                          label='MDA',
+                          level=None,
+                          diagonal_position=2)
+
+        # Add Initiator block
+        PSG_data.add_node('Initiator',
+                          category='architecture element',
+                          subcategory='initiator',
+                          shape='8',
+                          label='INI',
+                          level=None,
+                          diagonal_position=0)
+
+        # Loop over MDA analyses, add copy variables and adjust edges
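+        # Notation for the coupling-variable copies added below (matching the node labels):
+        #   x^c    copy of a coupling variable fed from the MDA block into an analysis,
+        #   x^{c0} initial guess of a coupling variable, provided by the Initiator,
+        #   x^*    final (converged) value of a coupling variable, returned to the Initiator.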
+        if MDA_type == 'Gauss-Seidel':
+            for idx, analysis in enumerate(analysis_order):
+                PSG_data.node[analysis]['category'] = 'architecture element'
+                PSG_data.node[analysis]['subcategory'] = 'MDA analysis'
+                PSG_data.node[analysis]['diagonal_position'] = 3+idx
+
+                # Check incoming edges
+                in_edges = FPG.in_edges(analysis)
+                for edge in in_edges:
+                    # Check whether the source node is a system input; couplings and problematic nodes need handling
+                    in_node = edge[0]
+                    if not set([in_node]).intersection(set(find_all_nodes(FPG, subcategory='all inputs'))):
+                        if set([in_node]).intersection(set(find_all_nodes(FPG, subcategory='all problematic nodes'))):
+                            raise IOError("A problematic node is still present in the FPG.")
+                        elif set([in_node]).intersection(set(find_all_nodes(FPG, subcategory='all couplings'))):
+                            # Check if the node is coupled to a future analysis
+                            coupled_functions = [e[0] for e in FPG.in_edges(in_node)]
+                            if set(coupled_functions).intersection(set(analysis_order[idx:])):
+                                # Add variable copy node between MDA and function
+                                new_node = '/PSG/coupling_variables/MDA/' + FPG.node[in_node]['label'] + '^c'
+                                PSG_data.add_node(new_node,
+                                                  category='architecture element',
+                                                  subcategory='MDA coupling variable',
+                                                  shape='o',
+                                                  label=FPG.node[in_node]['label'] + '^c',
+                                                  level=3)
+                                PSG_data.add_edge('MDA', new_node)
+                                PSG_data.add_edge(new_node, analysis)
+                                PSG_data.remove_edge(in_node, edge[1])
+                                # Add edge between the coupling variable and the MDA
+                                PSG_data.add_edge(in_node, 'MDA')
+                                # Add initial guess MDA coupling variable
+                                new_node = '/PSG/coupling_variables/MDA/' + FPG.node[in_node]['label'] + '^{c0}'
+                                PSG_data.add_node(new_node,
+                                                  category='architecture element',
+                                                  subcategory='initial guess MDA coupling variable',
+                                                  shape='o',
+                                                  label=FPG.node[in_node]['label'] + '^{c0}',
+                                                  level=3)
+                                PSG_data.add_edge(new_node, 'MDA')
+                                PSG_data.add_edge('Initiator', new_node)
+                # Check outgoing edges
+                out_edges = FPG.out_edges(analysis)
+                for edge in out_edges:
+                    # Check if edge is a coupling variable
+                    out_node = edge[1]
+                    if set([out_node]).intersection(set(find_all_nodes(FPG, subcategory='all couplings'))):
+                        # Add final coupling variable node and connect to analysis function and initiator
+                        new_node = '/PSG/coupling_variables/MDA/' + FPG.node[out_node]['label'] + '^*'
+                        PSG_data.add_node(new_node,
+                                          category='architecture element',
+                                          subcategory='final MDA coupling variable',
+                                          shape='o',
+                                          label=FPG.node[out_node]['label'] + '^*',
+                                          level=3)
+                        PSG_data.add_edge(analysis, new_node)
+                        PSG_data.add_edge(new_node, 'Initiator')
+
+        # Recategorize design variables and connect to optimizer and initiator
+        # Add optimization block
+        PSG_data.add_node('Optimizer',
+                          category='architecture element',
+                          subcategory='optimizer',
+                          shape='8',
+                          label='OPT',
+                          level=None,
+                          diagonal_position=1)
+
+        for node in des_var_nodes:
+            # Connect design variables to optimizer
+            PSG_data.add_edge('Optimizer', node)
+
+            # Add input variables x^(0) and connect to optimizer
+            new_node = '/PSG/design_variables/initial_guesses/' + FPG.node[node]['label'] + '^0'
+            PSG_data.add_node(new_node,
+                              category='architecture element',
+                              subcategory='initial guess design variable',
+                              shape='o',
+                              label=FPG.node[node]['label'] + '^0',
+                              level=3)
+            PSG_data.add_edge(new_node, 'Optimizer')
+            PSG_data.add_edge('Initiator', new_node)
+
+            # Add output variables x^* and connect to optimizer and initiator
+            new_node = '/PSG/design_variables/final_value/' + FPG.node[node]['label'] + '^*'
+            PSG_data.add_node(new_node,
+                              category='architecture element',
+                              subcategory='final design variable',
+                              shape='o',
+                              label=FPG.node[node]['label'] + '^*',
+                              level=3)
+            PSG_data.add_edge('Optimizer', new_node)
+            PSG_data.add_edge(new_node, 'Initiator')
+
+        # Add parameters and connect with INI function
+        for node in parameter_nodes:
+            # Connect parameter node to initiator
+            PSG_data.add_edge('Initiator', node)
+
+        # Adjust and connect optimizer functions output to optimizer
+        for idx, node in enumerate(optimizer_functions):
+            PSG_data.node[node]['diagonal_position'] = 3 + len(analysis_order) + idx
+            PSG_data.node[node]['category'] = 'architecture element'
+            PSG_data.node[node]['subcategory'] = 'optimizer function'
+        for node in optimizer_nodes_in:
+            PSG_data.add_edge(node, 'Optimizer')
+
+        return {'data flow': PSG_data, 'process flow': PSG_process}
 
-if __name__ == '__main__':
-    print 'Trial run successfull.'
-    print 'To test the MDOproblem class, run the sellarProblem.py file.'
diff --git a/pyKADMOS/testRun.py b/pyKADMOS/testRun.py
index d76f8d45e..8957112e5 100644
--- a/pyKADMOS/testRun.py
+++ b/pyKADMOS/testRun.py
@@ -1,11 +1,13 @@
 import pprint
 import networkx as nx
-from pyKADMOS.MDOproblem import MDOproblem
-
-exProb = MDOproblem('KB_CPACS')
-
-graph3 = exProb._get_function_graph('EMWET')
-
+from pyKADMOS.MDOproblem import MdoProblemInit
+from pyKADMOS.MDOgraph import Graph, MCG
+from pyKADMOS.MDOvisualization import plot_graph
+
+kb = MdoProblemInit('KB_CPACS')
+mcg = MCG(kb)
+mcGraph = mcg.get_graph(3)
+plot_graph(mcGraph, 1,  show_now=True)
 
 print "ALL GOOD!!!"
 
-- 
GitLab