Sophie

Sophie

distrib > Mandriva > 2010.0 > i586 > media > contrib-release > by-pkgid > dca483b59ba61f3fa092de932ddd570e > files > 780

nuface-2.0.14-2mdv2009.1.i586.rpm

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Copyright(C) 2007 INL
Written by Damien Boucard <damien.boucard AT inl.fr>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, see <http://www.gnu.org/licenses/>.

---
descxml.py is a module for exporting or importing a desc.xml file.
"""

from checkdesc.descmodels import Desc, Firewall, Interface, Address, \
    Network, InternetConnection, DirectConnection, RoutedConnection
from checkdesc.desc_warnings import UselessWarning, VersionWarning
from checkdesc.tools import compareVersion
import re
from sys import exit, stderr
from IPy import IP
from warnings import warn
from nupyf.nupyf_etree import etree

class Loader:
    """
    Import a desc.xml document and turn it into checkdesc model objects.

    The file is parsed by the constructor; build() then walks the <fws>
    and <nets> sections of the tree and returns the resulting Desc.
    """

    # Version reported in the warning when the root tag has no "version".
    earlier_version = "1.2"
    # Values accepted for the "direct" property of a <connection> tag.
    direct_set = ("0", "1")
    # Syntax check only: four dot-separated groups of 1 to 3 digits; the
    # numeric value of each group is not range-checked by these patterns.
    ipv4_regex_base = r'(\d{1,3}\.){3}\d{1,3}'
    ipv4_address_regex = re.compile('^%s$' %(ipv4_regex_base))
    ipv4_network_regex = re.compile("^%s/(%s|%s)$" %(ipv4_regex_base, ipv4_regex_base, r'\d{1,2}'))

    def __init__(self, file):
        """
        @param file: XML file (name or file-like object) handed to etree.parse().
        """
        self.doc = etree.parse(file)
        self.desc = None

    def build(self, check_version=None):
        """
        Build the Desc model from the parsed document.

        @param check_version: when set, the exact version the document must
            declare; on mismatch an error is written to stderr and the
            process exits with status 1.
        @return: the populated Desc instance (also stored in self.desc).
        """
        # Lookup tables filled while reading <fws>; <connection> tags use
        # the interface table to resolve their fwid/iface references.
        self.__firewall_dict = {}
        self.__interface_dict = {}
        xml_network = self.doc.getroot()
        desc_version = xml_network.get("version")
        if check_version and desc_version != check_version:
            stderr.write(
                "Error: Version %r is required (current version: %r)\n"
                % (check_version, desc_version))
            exit(1)
        if desc_version is None:
            warn(VersionWarning('Omitted version property; set to "%s" by default.'
                                         %(self.earlier_version)), stacklevel=2)
        self.desc = Desc()
        # Firewalls must be read first: networks reference their interfaces.
        self.desc.firewalls = self.build_firewalls(xml_network)
        self.desc.networks = self.build_networks(xml_network)
        return self.desc

    def build_firewalls(self, xml_network):
        """
        Create a Firewall (with its interfaces) for every <fw> under <fws>.

        @raise ValueError: if the <fws> element is absent.
        @raise OmittedPropertyError: if a <fw> has no "queue" property.
        @return: list of Firewall instances, in document order.
        """
        firewalls = []
        xml_fws = xml_network.find('fws')
        # Compare with None: an element without children is falsy, so a
        # plain truth test would reject a present but empty <fws/>.
        if xml_fws is None:
            raise ValueError('Unable to find <fws>')
        for xml_fw in xml_fws.findall('fw'):
            firewall_args = {"queue": None}
            firewall_kw = {}
            for prop, value in xml_fw.items():
                if prop in ("queue",):
                    firewall_args[prop] = value
                elif prop in ("name", "type", "id",):
                    firewall_kw[prop] = value
                else:
                    warn(UselessWarning('Unexpected property in "fw" tag: %s="%s"'
                                              %(prop, value)),stacklevel=2)
            self._check_mandatory_properties("fw", firewall_args)
            firewall = Firewall(firewall_args["queue"], **firewall_kw)
            firewall.interfaces = self.build_interfaces(xml_fw, firewall)

            self.__firewall_dict[firewall.id] = firewall
            firewalls.append(firewall)
        return firewalls

    def build_interfaces(self, xml_fw, firewall):
        """
        Create an Interface for every <interface> under <interfaces>.

        Each interface is also registered under the key "<fwid>,<ifaceid>"
        so that <connection> tags can refer to it later.

        @raise ValueError: if the <interfaces> element is absent.
        @raise OmittedPropertyError: if an <interface> has no "name".
        @return: list of Interface instances, in document order.
        """
        interfaces = []
        xml_interfaces = xml_fw.find('interfaces')
        # See build_firewalls(): "is None" keeps an empty element valid.
        if xml_interfaces is None:
            raise ValueError('Unable to find <interfaces>')
        for xml_interface in xml_interfaces.findall('interface'):
            interface_args = {"name": None}
            interface_kw = {}
            for prop, value in xml_interface.items():
                if prop in ("name",):
                    interface_args[prop] = value
                elif prop in ("id",):
                    interface_kw[prop] = value
                elif prop == "is_vlan":
                    interface_kw[prop] = (value == "1")
                else:
                    warn(UselessWarning('Unexpected property in "interface" tag: %s="%s"'
                                              %(prop, value)),stacklevel=2)
            self._check_mandatory_properties("interface", interface_args)
            interface = Interface(firewall, interface_args["name"], **interface_kw)
            interface.extend(self.build_addresses(xml_interface))

            interface_key = "%s,%s"%(firewall.id, interface.id)
            self.__interface_dict[interface_key] = interface
            interfaces.append(interface)
        return interfaces

    def build_addresses(self, xml_interface):
        """
        Create an Address for every <address> child of an interface tag.

        @raise ValueError: if an "addr" value is not a valid IPv4 address.
        @raise OmittedPropertyError: if an <address> has no "addr".
        @return: list of Address instances, in document order.
        """
        addresses = []
        for xml_address in xml_interface.findall('address'):
            address_args = {"addr": None}
            address_kw = {}
            for prop, value in xml_address.items():
                if prop == "addr":
                    address_args["addr"] = self._get_ipv4_address("address", "addr", value)
                elif prop in ("id", "type"):
                    address_kw[prop] = value
                else:
                    warn(UselessWarning('Unexpected property in "address" tag: %s="%s"'
                                              %(prop, value)),stacklevel=2)
            self._check_mandatory_properties("address", address_args)
            address = Address(address_args["addr"], **address_kw)
            addresses.append(address)
        return addresses

    def build_networks(self, xml_network):
        """
        Create a Network (with its connections) for every <net> under <nets>.

        @raise ValueError: if the <nets> element is absent or an address
            property is malformed.
        @raise OmittedPropertyError: if "name", "type" or "addr" is missing.
        @return: list of Network instances, in document order.
        """
        networks = []
        xml_nets = xml_network.find('nets')
        # See build_firewalls(): "is None" keeps an empty element valid.
        if xml_nets is None:
            raise ValueError('Unable to find <nets>')
        for xml_net in xml_nets.findall("net"):
            network_args = {"name": None, "type": None, "addr": None}
            network_kw = {}
            for prop, value in xml_net.items():
                if prop in ("name", "type",):
                    network_args[prop] = value
                elif prop == "remote":
                    network_kw[prop] = self._get_ipv4_address("net", prop, value)
                elif prop == "addr":
                    network_args[prop] = self._get_ipv4_network("net", prop, value)
                elif prop in ("id", "local_id"):
                    network_kw[prop] = value
                elif prop == "enabled":
                    network_kw[prop] = (value == "1")
                else:
                    warn(UselessWarning('Unexpected property in "net" tag: %s="%s"'
                                              %(prop, value)),stacklevel=2)
            self._check_mandatory_properties("net", network_args)
            network = Network(network_args["name"], network_args["type"], network_args["addr"], **network_kw)
            self.build_connections(xml_net, network)
            networks.append(network)
        return networks

    def build_connections(self, xml_net, network):
        """
        Instantiate a connection object for every <connection> of a <net>.

        The created objects are not collected here.
        NOTE(review): presumably the connection constructors attach
        themselves to the network/interface -- confirm in descmodels.

        @raise OmittedPropertyError: if "fwid", "iface" or "direct" is
            missing, or if a routed connection has no gateway.
        @raise KeyError: if fwid/iface do not match a known interface.
        @raise ValueError: if "direct" or a gateway value is invalid.
        """
        for xml_connection in xml_net.findall('connection'):
            connection_dict = dict(xml_connection.items())
            # Get the interface instance referenced by fwid and iface properties
            for mandatory in ("fwid", "iface"):
                if mandatory not in connection_dict:
                    raise OmittedPropertyError("connection", mandatory)
            fwid = connection_dict.pop("fwid")
            iface = connection_dict.pop("iface")
            interface_key = "%s,%s" %(fwid, iface)
            if interface_key not in self.__interface_dict:
                raise KeyError("No interface with ID #%s in firewall #%s" % (
                    iface, fwid))
            interface = self.__interface_dict[interface_key]
            # Check direct property. pop() with a default so that a missing
            # property reaches the OmittedPropertyError below instead of
            # failing with a bare KeyError.
            direct = connection_dict.pop("direct", None)
            if direct is None:
                raise OmittedPropertyError("connection", "direct")
            if direct not in self.direct_set:
                raise ValueError('Invalid direct value in "connection" tag: "%s"; must be one of these values : "%s"'
                                       %(direct, '", "'.join(self.direct_set)))
            self.direct = (direct != "0")
            # Construct the appropriate Connection instance
            if self.direct:
                if "dftgateway" in connection_dict:
                    connection_dftgateway = self._get_ipv4_address(
                        "connection", "dftgateway", connection_dict.pop("dftgateway"))
                    InternetConnection(network, interface, connection_dftgateway)
                else:
                    DirectConnection(network, interface)
            else:
                # The network's own "remote" address takes precedence over
                # the connection's "gateway" property.
                if network.remote:
                    RoutedConnection(network, interface, network.remote)
                elif "gateway" in connection_dict:
                    connection_gateway = self._get_ipv4_address(
                        "connection", "gateway", connection_dict.pop("gateway"))
                    RoutedConnection(network, interface, connection_gateway)
                else:
                    raise OmittedPropertyError("connection", "gateway")
            # Warn superfluity properties
            for prop, value in connection_dict.items():
                warn(UselessWarning('Unexpected property in "connection" tag: %s="%s"'
                                              %(prop, value)),stacklevel=2)

    def _get_ipv4_address(self, tag, property, value):
        """
        Convert a property value into an IPy.IP host address.

        @param tag: tag name, used in error messages only.
        @param property: property name, used in error messages only.
        @param value: text to convert.
        @raise ValueError: if the text is not dotted-quad or IPy rejects it.
        """
        text = str(value)
        if not self.ipv4_address_regex.match(text):
            raise ValueError('%s property from "%s" tag has not a clean IPv4 syntax: "%s".'
                                         %(property, tag, value))
        try:
            return IP(text)
        except ValueError as e:
            # The syntax was already validated above, so report IPy's own
            # explanation of what is wrong with the value.
            raise ValueError('Invalid %s value in "%s" tag: "%s".'
                                                       %(property, tag, e))

    def _get_ipv4_network(self, tag, property, value):
        """
        Convert a property value into an IPy.IP network.

        The text must look like "a.b.c.d/len" or "a.b.c.d/e.f.g.h".

        @param tag: tag name, used in error messages only.
        @param property: property name, used in error messages only.
        @param value: text to convert.
        @raise ValueError: if IPy rejects the value or the syntax is wrong.
        """
        text = str(value)
        try:
            ipy_instance = IP(text)
        except ValueError as e:
            if self.ipv4_network_regex.match(text):
                # Well-formed but semantically wrong: pass on IPy's reason.
                raise ValueError('Invalid %s value in "%s" tag: "%s".'
                                                           %(property, tag, e))
            raise ValueError('Invalid %s value in "%s" tag: "%s"; must be an IP address with its netmask.'
                                                        %(property, tag, value))
        # IPy accepts other notations too; only address/mask is allowed here.
        if not self.ipv4_network_regex.match(text):
            raise ValueError('%s property from "%s" tag has not a clean IPv4 syntax: "%s".'
                                                       %(property, tag, value))
        return ipy_instance

    def _check_mandatory_properties(self, tag, prop_dict):
        """
        Raise OmittedPropertyError for the first property still set to None.

        @param tag: tag name reported in the error.
        @param prop_dict: mapping of mandatory property name to parsed value.
        """
        for prop, value in prop_dict.items():
            if value is None:
                raise OmittedPropertyError(tag, prop)


def load(file):
    """
    Read a desc.xml document and return the matching model objects.

    @param file: XML file-like object to import.
    @type file: file
    @return: the model built by Loader.build() with its default arguments.
    @rtype: descmodels.Desc
    """
    loader = Loader(file)
    desc = loader.build()
    return desc


class Dumper:
    """
    Export a Desc model as a desc.xml document.

    build() creates the element tree in self.doc; write() serialises it.
    """

    def __init__(self, desc):
        """
        @param desc: model to export.
        @type desc: descmodels.Desc
        """
        self.desc = desc
        self.doc = None

    def build(self):
        """
        Build the whole <network> tree, indent it and store it in self.doc.
        """
        xml_network = etree.Element("network", version=self.desc.version, applied="1")

        self.build_fws(xml_network)
        self.build_nets(xml_network)

        indent(xml_network)
        self.doc = etree.ElementTree(xml_network)

    def build_fws(self, xml_network):
        """
        Append a <fws> element holding one <fw> per firewall of the model.

        Only the type, name, queue and id attributes are exported; None and
        callable values are skipped.
        """
        xml_fws = etree.SubElement(xml_network, "fws")
        for firewall in self.desc.firewalls:
            firewall_dict = {}
            for (k, v) in firewall.__dict__.items():
                if v is None or callable(v):
                    continue
                elif k in ('type', 'name', 'queue', 'id'):
                    firewall_dict[k] = str(v)
            xml_fw = etree.SubElement(xml_fws, "fw", firewall_dict)
            self.build_interfaces(xml_fw, firewall)

    def build_interfaces(self, xml_fw, firewall):
        """
        Append an <interfaces> element holding one <interface> per interface.

        From version 1.1 on, addresses are written as nested <address>
        tags; for older versions the single address goes into an "addr"
        property of the <interface> tag.

        @raise ValueError: for a pre-1.1 document when an interface does
            not have exactly one address.
        """
        xml_interfaces = etree.SubElement(xml_fw, "interfaces")
        for interface in firewall.interfaces:
            interface_dict = {}
            for (k, v) in interface.__dict__.items():
                if v is None or callable(v):
                    continue
                elif k in ('name', 'id'):
                    interface_dict[k] = str(v)
                elif k in ('is_vlan',):
                    # Booleans are stored as "0" / "1".
                    interface_dict[k] = str(int(v))
            if compareVersion(self.desc.version, "1.1") >= 0:
                xml_interface = etree.SubElement(xml_interfaces, "interface", interface_dict)
                self.build_addresses(xml_interface, interface)
            else:
                address = list(iter(interface))
                if len(address) != 1:
                    raise ValueError('Interface "%s" has %s addresses instead of exactly one' \
                        % (interface.name, len(address)))
                address = address[0]
                interface_dict['addr'] = str(address.addr)
                etree.SubElement(xml_interfaces, "interface", interface_dict)

    def build_addresses(self, xml_interface, interface):
        """
        Append one <address> tag per address yielded by the interface.

        Only the addr, id and type attributes are exported; None and
        callable values are skipped.
        """
        for address in interface:
            address_dict = {}
            for (k, v) in address.__dict__.items():
                if v is None or callable(v):
                    continue
                elif k in ('addr', 'id', 'type'):
                    address_dict[k] = str(v)
            etree.SubElement(xml_interface, "address", address_dict)

    def build_nets(self, xml_network):
        """
        Append a <nets> element holding one <net> per network of the model.
        """
        xml_nets = etree.SubElement(xml_network, "nets")
        for network in self.desc.networks:
            network_dict = {}
            for (k, v) in network.__dict__.items():
                if v is None or callable(v):
                    continue
                if k in ('type', 'name', 'id', 'addr', 'remote', 'local_id'):
                    v = str(v)
                elif k == 'enabled':
                    # Booleans are stored as "0" / "1".
                    v = str(int(v))
                else:
                    continue

                # A network address written without a mask gets "/32".
                if k == 'addr' and '/' not in v:
                    v += '/32'
                network_dict[k] = v
            xml_net = etree.SubElement(xml_nets, "net", network_dict)
            self.build_connections(xml_net, network)

    def build_connections(self, xml_net, network):
        """
        Append one <connection> tag per connection yielded by the network.

        The iface, fwid and direct properties are always written; gateway
        and dftgateway are written when the matching attribute is set.
        """
        for connection in network:
            # These three do not depend on the attribute scan below, so
            # they are set once up front and are present even when no
            # attribute passes the filter.
            # NOTE(review): "direct" relies on isinstance(DirectConnection);
            # confirm InternetConnection derives from it in descmodels.
            connection_dict = {
                "iface": str(connection.interface.id),
                "fwid": str(connection.interface.firewall.id),
                "direct": "1" if isinstance(connection, DirectConnection) else "0",
            }
            for (k, v) in connection.__dict__.items():
                if v is None or callable(v):
                    continue
                elif k in ('gateway',):
                    connection_dict[k] = str(v)
                elif k == 'default_gateway':
                    connection_dict["dftgateway"] = str(v)
            etree.SubElement(xml_net, "connection", connection_dict)

    def write(self, file):
        """
        Serialise the tree built by build() to file, followed by a newline.

        @param file: file-like object to write to.
        """
        self.doc.write(file)
        file.write("\n")


def dump(desc, file):
    """
    Serialise a desc model as XML into the given file.

    @param desc: data to export into XML file.
    @type desc: descmodels.Desc
    @param file: XML file-like object to export to.
    @type file: file
    """
    exporter = Dumper(desc)
    # The tree has to be built before it can be written out.
    exporter.build()
    exporter.write(file)


def indent(elem, level=0):
    """
    Pretty-print an element tree in place using two-space indentation.

    Only whitespace-only (or missing) text and tail values are replaced,
    so real character data is left alone.

    @param elem: element whose subtree is re-indented.
    @param level: nesting depth of elem; 0 for the root.
    """
    pad = "\n" + level * "  "
    if len(elem):
        if not elem.text or not elem.text.strip():
            # First child starts on its own line, one level deeper.
            elem.text = pad + "  "
        last = None
        for child in elem:
            indent(child, level + 1)
            last = child
        # The closing tag of elem follows the last child: dedent its tail.
        if not last.tail or not last.tail.strip():
            last.tail = pad
    # Every non-root element, with or without children, is followed by a
    # line break; otherwise a sibling after an element that has children
    # would start on the same line as that element's closing tag.
    if level and (not elem.tail or not elem.tail.strip()):
        elem.tail = pad


class OmittedPropertyError(ValueError):
    """
    Error raised when a mandatory property is missing from an XML tag.

    The tag and property names are kept on the instance as the "tag" and
    "property" attributes.
    """

    def __init__(self, tag, property):
        message = 'Omitted %s property from "%s" tag; this property is mandatory.' % (property, tag)
        super(OmittedPropertyError, self).__init__(message)
        self.property = property
        self.tag = tag


if __name__ == "__main__":
    # Round-trip check: load the desc.xml file named on the command line,
    # dump the resulting model again and print the regenerated XML.
    import sys
    from StringIO import StringIO
    if len(sys.argv) < 2:
        sys.stderr.write("Usage: %s <desc.xml>\n" % sys.argv[0])
        sys.exit(2)
    source = open(sys.argv[1], 'r')
    try:
        desc = load(source)
    finally:
        # Release the input file even when loading fails.
        source.close()
    # Dump into a buffer first so nothing is printed if the export fails.
    s = StringIO()
    dump(desc, s)
    print(s.getvalue())