#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Copyright(C) 2007 INL
Written by Damien Boucard <damien.boucard AT inl.fr>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, see <http://www.gnu.org/licenses/>.

---

descxml.py is a module for exporting or importing a desc.xml file.
"""

import re
from sys import exit, stderr
from warnings import warn

from IPy import IP
from nupyf.nupyf_etree import etree

from checkdesc.descmodels import Desc, Firewall, Interface, Address, \
    Network, InternetConnection, DirectConnection, RoutedConnection
from checkdesc.desc_warnings import UselessWarning, VersionWarning
from checkdesc.tools import compareVersion


class Loader:
    """Build a descmodels.Desc object tree from a desc.xml document."""

    # Version assumed when the root element has no "version" property.
    earlier_version = "1.2"
    # Accepted values of the "direct" property of <connection> tags.
    direct_set = ("0", "1")
    # Dotted quad syntax only; octet ranges are left to IPy.IP to reject.
    ipv4_regex_base = r'(\d{1,3}\.){3}\d{1,3}'
    ipv4_address_regex = re.compile('^%s$' % (ipv4_regex_base))
    # Dotted quad, "/", then either a dotted-quad netmask or a 1-2 digit
    # prefix length.
    ipv4_network_regex = re.compile("^%s/(%s|%s)$"
                                    % (ipv4_regex_base, ipv4_regex_base, r'\d{1,2}'))

    def __init__(self, file):
        """Parse the XML document.

        @param file: XML file-like object to import (passed to etree.parse).
        """
        self.doc = etree.parse(file)
        self.desc = None

    def build(self, check_version=None):
        """Convert the parsed document into a Desc instance.

        @param check_version: when given, the root "version" property must
            equal it; otherwise an error is printed on stderr and the process
            exits with status 1.
        @rtype: descmodels.Desc
        @raise ValueError: on missing or malformed tags/properties.
        """
        self.__firewall_dict = {}
        self.__interface_dict = {}
        xml_network = self.doc.getroot()
        desc_version = xml_network.get("version")
        if check_version and desc_version != check_version:
            stderr.write("Error: Version %r is required (current version: %r)\n"
                         % (check_version, desc_version))
            exit(1)
        if desc_version is None:
            desc_version = self.earlier_version
            warn(VersionWarning('Omitted version property; set to "%s" by default.'
                                % (self.earlier_version)), stacklevel=2)
        # NOTE(review): desc_version is validated but never stored on the
        # returned Desc; confirm Desc() carries the intended version.
        self.desc = Desc()
        self.desc.firewalls = self.build_firewalls(xml_network)
        self.desc.networks = self.build_networks(xml_network)
        return self.desc

    def build_firewalls(self, xml_network):
        """Build one Firewall per <fw> tag found under <fws>.

        Each firewall is also recorded in an internal dict keyed by its id.

        @return: list of Firewall instances.
        @raise ValueError: if <fws> is missing or a mandatory property is
            omitted.
        """
        firewalls = []
        xml_fws = xml_network.find('fws')
        # Compare with None: an Element without children is falsy, and an
        # empty <fws/> is present, not missing.
        if xml_fws is None:
            raise ValueError('Unable to find <fws>')
        for xml_fw in xml_fws.findall('fw'):
            firewall_args = {"queue": None}
            firewall_kw = {}
            for prop, value in xml_fw.items():
                if prop in ("queue",):
                    firewall_args[prop] = value
                elif prop in ("name", "type", "id"):
                    firewall_kw[prop] = value
                else:
                    warn(UselessWarning('Unexpected property in "fw" tag: %s="%s"'
                                        % (prop, value)), stacklevel=2)
            self._check_mandatory_properties("fw", firewall_args)
            firewall = Firewall(firewall_args["queue"], **firewall_kw)
            firewall.interfaces = self.build_interfaces(xml_fw, firewall)
            self.__firewall_dict[firewall.id] = firewall
            firewalls.append(firewall)
        return firewalls

    def build_interfaces(self, xml_fw, firewall):
        """Build one Interface per <interface> tag under <interfaces>.

        Each interface is recorded under the key "<firewall id>,<interface id>"
        so that <connection> tags can reference it later.

        @return: list of Interface instances.
        @raise ValueError: if <interfaces> is missing or a mandatory property
            is omitted.
        """
        interfaces = []
        xml_interfaces = xml_fw.find('interfaces')
        # "is None" rather than truthiness: an empty <interfaces/> is valid.
        if xml_interfaces is None:
            raise ValueError('Unable to find <interfaces>')
        for xml_interface in xml_interfaces.findall('interface'):
            interface_args = {"name": None}
            interface_kw = {}
            for prop, value in xml_interface.items():
                if prop in ("name",):
                    interface_args[prop] = value
                elif prop in ("id",):
                    interface_kw[prop] = value
                elif prop == "is_vlan":
                    interface_kw[prop] = (value == "1")
                else:
                    warn(UselessWarning('Unexpected property in "interface" tag: %s="%s"'
                                        % (prop, value)), stacklevel=2)
            self._check_mandatory_properties("interface", interface_args)
            interface = Interface(firewall, interface_args["name"], **interface_kw)
            interface.extend(self.build_addresses(xml_interface))
            interface_key = "%s,%s" % (firewall.id, interface.id)
            self.__interface_dict[interface_key] = interface
            interfaces.append(interface)
        return interfaces

    def build_addresses(self, xml_interface):
        """Build one Address per <address> child of an <interface> tag.

        @return: list of Address instances.
        @raise ValueError: if "addr" is omitted or is not a valid IPv4 address.
        """
        addresses = []
        for xml_address in xml_interface.findall('address'):
            address_args = {"addr": None}
            address_kw = {}
            for prop, value in xml_address.items():
                if prop == "addr":
                    address_args["addr"] = self._get_ipv4_address("address", "addr", value)
                elif prop in ("id", "type"):
                    address_kw[prop] = value
                else:
                    warn(UselessWarning('Unexpected property in "address" tag: %s="%s"'
                                        % (prop, value)), stacklevel=2)
            self._check_mandatory_properties("address", address_args)
            address = Address(address_args["addr"], **address_kw)
            addresses.append(address)
        return addresses

    def build_networks(self, xml_network):
        """Build one Network per <net> tag under <nets>, with its connections.

        @return: list of Network instances.
        @raise ValueError: if <nets> is missing, a mandatory property is
            omitted or an address is malformed.
        """
        networks = []
        xml_nets = xml_network.find('nets')
        # "is None" rather than truthiness: an empty <nets/> is valid.
        if xml_nets is None:
            raise ValueError('Unable to find <nets>')
        for xml_net in xml_nets.findall("net"):
            network_args = {"name": None, "type": None, "addr": None}
            network_kw = {}
            for prop, value in xml_net.items():
                if prop in ("name", "type"):
                    network_args[prop] = value
                elif prop == "remote":
                    network_kw[prop] = self._get_ipv4_address("net", prop, value)
                elif prop == "addr":
                    network_args[prop] = self._get_ipv4_network("net", prop, value)
                elif prop in ("id", "local_id"):
                    network_kw[prop] = value
                elif prop == "enabled":
                    network_kw[prop] = (value == "1")
                else:
                    warn(UselessWarning('Unexpected property in "net" tag: %s="%s"'
                                        % (prop, value)), stacklevel=2)
            self._check_mandatory_properties("net", network_args)
            network = Network(network_args["name"], network_args["type"],
                              network_args["addr"], **network_kw)
            self.build_connections(xml_net, network)
            networks.append(network)
        return networks

    def build_connections(self, xml_net, network):
        """Create the Connection objects described by <connection> children.

        The connection kind depends on the "direct" property:
          - direct="1" with "dftgateway": InternetConnection
          - direct="1" without it: DirectConnection
          - direct="0": RoutedConnection through network.remote if set,
            otherwise through the mandatory "gateway" property.

        The connections are only constructed here (no return value); any
        registration with the network/interface is up to the model classes.

        @raise OmittedPropertyError: if fwid, iface, direct or a needed
            gateway is missing.
        @raise KeyError: if fwid/iface reference no known interface.
        @raise ValueError: on an invalid "direct" value or malformed address.
        """
        for xml_connection in xml_net.findall('connection'):
            connection_dict = dict(xml_connection.items())

            # Resolve the interface referenced by the fwid/iface properties.
            if "fwid" not in connection_dict:
                raise OmittedPropertyError("connection", "fwid")
            if "iface" not in connection_dict:
                raise OmittedPropertyError("connection", "iface")
            fwid = connection_dict.pop("fwid")
            iface = connection_dict.pop("iface")
            interface_key = "%s,%s" % (fwid, iface)
            if interface_key not in self.__interface_dict:
                raise KeyError("No interface with ID #%s in firewall #%s"
                               % (iface, fwid))
            interface = self.__interface_dict[interface_key]

            # pop() with a default: a missing "direct" must surface as
            # OmittedPropertyError, not as a bare KeyError from "del".
            direct = connection_dict.pop("direct", None)
            if direct is None:
                raise OmittedPropertyError("connection", "direct")
            if direct not in self.direct_set:
                raise ValueError('Invalid direct value in "connection" tag: "%s"; must be one of these values : "%s"'
                                 % (direct, '", "'.join(self.direct_set)))
            is_direct = (direct != "0")

            # Construct the appropriate Connection instance.
            if is_direct:
                if "dftgateway" in connection_dict:
                    dftgateway = self._get_ipv4_address(
                        "connection", "dftgateway", connection_dict.pop("dftgateway"))
                    InternetConnection(network, interface, dftgateway)
                else:
                    DirectConnection(network, interface)
            elif network.remote:
                RoutedConnection(network, interface, network.remote)
            elif "gateway" in connection_dict:
                gateway = self._get_ipv4_address(
                    "connection", "gateway", connection_dict.pop("gateway"))
                RoutedConnection(network, interface, gateway)
            else:
                raise OmittedPropertyError("connection", "gateway")

            # Whatever was not consumed above is superfluous.
            for prop, value in connection_dict.items():
                warn(UselessWarning('Unexpected property in "connection" tag: %s="%s"'
                                    % (prop, value)), stacklevel=2)

    def _get_ipv4_address(self, tag, property, value):
        """Parse *value* as a single IPv4 address written as a dotted quad.

        IPy is tried first (as in _get_ipv4_network) so that both the
        "invalid value" and the "must be an IP address" messages are
        reachable.

        @return: IPy.IP instance.
        @raise ValueError: if IPy rejects the value or it is not a plain
            dotted quad.
        """
        text = str(value)
        try:
            ipy_instance = IP(text)
        except ValueError as e:
            if self.ipv4_address_regex.match(text):
                raise ValueError('Invalid %s value in "%s" tag: "%s".'
                                 % (property, tag, e))
            raise ValueError('Invalid %s value in "%s" tag: "%s"; must be an IP address.'
                             % (property, tag, value))
        if not self.ipv4_address_regex.match(text):
            raise ValueError('%s property from "%s" tag has not a clean IPv4 syntax: "%s".'
                             % (property, tag, value))
        return ipy_instance

    def _get_ipv4_network(self, tag, property, value):
        """Parse *value* as an IPv4 network ("a.b.c.d/len" or "a.b.c.d/mask").

        @return: IPy.IP instance.
        @raise ValueError: if IPy rejects the value or it does not match the
            expected address/netmask syntax.
        """
        text = str(value)
        try:
            ipy_instance = IP(text)
        except ValueError as e:
            if self.ipv4_network_regex.match(text):
                raise ValueError('Invalid %s value in "%s" tag: "%s".'
                                 % (property, tag, e))
            raise ValueError('Invalid %s value in "%s" tag: "%s"; must be an IP address with its netmask.'
                             % (property, tag, value))
        if not self.ipv4_network_regex.match(text):
            raise ValueError('%s property from "%s" tag has not a clean IPv4 syntax: "%s".'
                             % (property, tag, value))
        return ipy_instance

    def _check_mandatory_properties(self, tag, prop_dict):
        """Raise OmittedPropertyError for the first property still set to None."""
        for prop, value in prop_dict.items():
            if value is None:
                raise OmittedPropertyError(tag, prop)


def load(file):
    """
    Generate model objects from a given XML file.

    @param file: XML file-like object to import.
    @type file: file
    @rtype: descmodels.Desc
    """
    return Loader(file).build()


def _defined_attributes(obj):
    """Yield (name, value) pairs of obj.__dict__, skipping None and callables."""
    for name, value in obj.__dict__.items():
        if value is not None and not callable(value):
            yield name, value


class Dumper:
    """Serialize a descmodels.Desc object tree into a desc.xml document."""

    def __init__(self, desc):
        self.desc = desc
        self.doc = None

    def build(self):
        """Build the XML tree into self.doc (indented for readability)."""
        xml_network = etree.Element("network", version=self.desc.version, applied="1")
        self.build_fws(xml_network)
        self.build_nets(xml_network)
        indent(xml_network)
        self.doc = etree.ElementTree(xml_network)

    def build_fws(self, xml_network):
        """Append <fws> with one <fw> per firewall of the desc."""
        xml_fws = etree.SubElement(xml_network, "fws")
        for firewall in self.desc.firewalls:
            firewall_dict = {}
            for k, v in _defined_attributes(firewall):
                if k in ('type', 'name', 'queue', 'id'):
                    firewall_dict[k] = str(v)
            xml_fw = etree.SubElement(xml_fws, "fw", firewall_dict)
            self.build_interfaces(xml_fw, firewall)

    def build_interfaces(self, xml_fw, firewall):
        """Append <interfaces> with one <interface> per firewall interface.

        Format versions >= 1.1 nest <address> tags; older versions store a
        single "addr" property on <interface>.

        @raise ValueError: for versions < 1.1, if an interface does not hold
            exactly one address.
        """
        xml_interfaces = etree.SubElement(xml_fw, "interfaces")
        nested_addresses = compareVersion(self.desc.version, "1.1") >= 0
        for interface in firewall.interfaces:
            interface_dict = {}
            for k, v in _defined_attributes(interface):
                if k in ('name', 'id'):
                    interface_dict[k] = str(v)
                elif k == 'is_vlan':
                    interface_dict[k] = str(int(v))
            if nested_addresses:
                xml_interface = etree.SubElement(xml_interfaces, "interface", interface_dict)
                self.build_addresses(xml_interface, interface)
            else:
                addresses = list(interface)
                if len(addresses) != 1:
                    raise ValueError('Interface "%s" has %s addresses instead of exactly one'
                                     % (interface.name, len(addresses)))
                interface_dict['addr'] = str(addresses[0].addr)
                etree.SubElement(xml_interfaces, "interface", interface_dict)

    def build_addresses(self, xml_interface, interface):
        """Append one <address> per address held by *interface*."""
        for address in interface:
            address_dict = {}
            for k, v in _defined_attributes(address):
                if k in ('addr', 'id', 'type'):
                    address_dict[k] = str(v)
            etree.SubElement(xml_interface, "address", address_dict)

    def build_nets(self, xml_network):
        """Append <nets> with one <net> per network, including connections."""
        xml_nets = etree.SubElement(xml_network, "nets")
        for network in self.desc.networks:
            network_dict = {}
            for k, v in _defined_attributes(network):
                if k in ('type', 'name', 'id', 'addr', 'remote', 'local_id'):
                    v = str(v)
                elif k == 'enabled':
                    v = str(int(v))
                else:
                    continue
                # Loader requires "addr" to carry a netmask; a bare host
                # address is written as a /32 network.
                if k == 'addr' and '/' not in v:
                    v += '/32'
                network_dict[k] = v
            xml_net = etree.SubElement(xml_nets, "net", network_dict)
            self.build_connections(xml_net, network)

    def build_connections(self, xml_net, network):
        """Append one <connection> per connection of *network*."""
        for connection in network:
            connection_dict = {}
            for k, v in _defined_attributes(connection):
                if k == 'gateway':
                    connection_dict[k] = str(v)
                elif k == 'default_gateway':
                    connection_dict["dftgateway"] = str(v)
            connection_dict["iface"] = str(connection.interface.id)
            connection_dict["fwid"] = str(connection.interface.firewall.id)
            # Loader.build_connections only creates InternetConnection for
            # direct="1", so it is written back as direct whatever its base
            # class is.
            if isinstance(connection, (DirectConnection, InternetConnection)):
                connection_dict["direct"] = "1"
            else:
                connection_dict["direct"] = "0"
            etree.SubElement(xml_net, "connection", connection_dict)

    def write(self, file):
        """Write the document to *file*, followed by a newline.

        The tree is built first if build() has not been called yet.
        """
        if self.doc is None:
            self.build()
        self.doc.write(file)
        file.write("\n")


def dump(desc, file):
    """
    Generate an XML file from a given desc model.

    @param desc: data to export into XML file.
    @type desc: descmodels.Desc
    @param file: XML file-like object to export to.
    @type file: file
    """
    dumper = Dumper(desc)
    dumper.build()
    dumper.write(file)


def indent(elem, level=0):
    """Pretty-print *elem* in place by rewriting whitespace-only text/tails.

    Each nesting level adds one space.  Text and tails that already contain
    non-whitespace characters are left untouched.  The root's tail
    (level 0) is never modified.
    """
    i = "\n" + level * " "
    if len(elem):
        if not elem.text or not elem.text.strip():
            elem.text = i + " "
        # A non-root parent needs its own tail too, otherwise the following
        # sibling would start on the line of this element's closing tag.
        if level and (not elem.tail or not elem.tail.strip()):
            elem.tail = i
        for child in elem:
            indent(child, level + 1)
        # The last child's tail precedes the parent's closing tag, so it is
        # indented at the parent's level.
        if not child.tail or not child.tail.strip():
            child.tail = i
    else:
        if level and (not elem.tail or not elem.tail.strip()):
            elem.tail = i


class OmittedPropertyError(ValueError):
    """A mandatory property is missing from a desc.xml tag."""

    def __init__(self, tag, property):
        ValueError.__init__(self, 'Omitted %s property from "%s" tag; this property is mandatory.'
                            % (property, tag))
        self.tag = tag
        self.property = property


if __name__ == "__main__":
    import sys
    try:
        from StringIO import StringIO
    except ImportError:
        from io import StringIO
    if len(sys.argv) != 2:
        stderr.write("usage: %s desc.xml\n" % sys.argv[0])
        exit(2)
    s = StringIO()
    with open(sys.argv[1], 'r') as desc_file:
        desc = load(desc_file)
    dump(desc, s)
    print(s.getvalue())