# starsim/universedb.py
"""
StarSim universe database module

Keeps track of all of the objects in the universe by ID and by location.
Loads and saves the universe from and to files.
"""

import ConfigParser
import cPickle
import csv
import glob
from math import floor
import os
import re
import sets
import types

import starsim
from starsim.server import model
from starsim.const import SCAN_LIMIT

# Format version of the universe data layout. save() writes it into the file
# header and load() refuses any file whose header carries a different value.
UNIVERSE_DB_VERSION = 1


def create(module_name=None):
    """
    Create a new Universe object and populate it with items.

    module_name:
        The name of a module containing a create() function, that will
        create the universe. If no module name is given, a default universe
        is created.

        The module name is the name of a Python .py file, without the .py
        part. The module is searched for first in the starsim.scenarios
        package, and then sys.path is searched in the normal manner.

    Returns whatever the selected create() function returns.
    Raises ImportError if the module is found in neither place.
    """
    if not module_name:
        return createDefaultUniverse()
    try:
        mod = _getModule('starsim.scenarios.' + module_name)
    except ImportError:
        # Not a bundled scenario: fall back to an ordinary sys.path import,
        # as the docstring promises. (An ImportError raised *inside* a
        # bundled scenario also lands here; the fallback import then fails
        # with its own ImportError.)
        mod = _getModule(module_name)
    return mod.create()


def createDefaultUniverse():
    """
    Build and return a small built-in universe.

    It holds one starbase ('EarthBase') at the origin, stored as the
    universe's home_base, plus one ship of each sample design parked at
    rest on the four diagonals. For richer worlds, call the create()
    function of a module in the starsim.scenarios package instead.
    """
    universe = Universe()
    universe.home_base = model.StarBase(universe, (0.0, 0.0),
                                        color='lightgreen', name='EarthBase')
    # A plain model.StarShip could be placed the same way, e.g. at
    # (+0.100, +0.100) with color='gray'.
    ship_specs = (
        (model.Garrett1, (+0.200, +0.200), 'red'),
        (model.Nathan1, (+0.200, -0.200), 'magenta'),
        (model.Nathan2, (-0.200, -0.200), 'blue'),
        (model.Evan1, (-0.200, +0.200), 'green'),
        )
    for ship_class, start_loc, ship_color in ship_specs:
        # Unowned, facing angle 0, not moving and not rotating.
        ship_class(universe, start_loc, color=ship_color,
                   user=None, angle=0.0, vel=(0.0, 0.0), va=0.0)
    return universe


def _getModule(module_name):
    mod = __import__(module_name)
    components = module_name.split('.')
    for comp in components[1:]:
        mod = getattr(mod, comp)
    return mod


def load(data_file):
    """
    Load a Universe from a file written by save().

    The file holds two pickles: a header dict, then the Universe itself.
    The header's 'universe_db_version' must equal UNIVERSE_DB_VERSION.

    Raises IOError if the file does not contain universe data or was
    written with a different universe db version.

    Security note: this unpickles the file, and unpickling can execute
    arbitrary code. Only load data files from trusted sources.
    """
    # sys is not imported at module level; without this the error paths
    # below would raise NameError instead of the intended IOError.
    import sys
    print('Loading "%s"' % data_file)
    f = open(data_file, 'rb')
    try:
        u = cPickle.Unpickler(f)
        header = u.load()
        if not isinstance(header, dict):
            # Some other pickle: treat it as "no universe data" below
            # rather than failing with an AttributeError on .get().
            header = {}
        program = header.get('program')
        version = header.get('version')
        print("Data file was saved by %s version %s" % (program, version))
        universe_db_version = header.get('universe_db_version')
        if universe_db_version is None:
            msg = "Data file does not contain universe data"
            sys.stderr.write(msg + '\n')
            raise IOError(msg)
        if universe_db_version != UNIVERSE_DB_VERSION:
            msg = "Universe db version in file is %s, expecting %s" % (
                universe_db_version, UNIVERSE_DB_VERSION)
            sys.stderr.write(msg + '\n')
            raise IOError(msg)
        universe = u.load()
    finally:
        # Close the file on every path, including the error paths above.
        f.close()
    universe.printStats()
    return universe


def save(data_file, universe):
    """
    Save a Universe to a file.

    Writes two pickles to data_file: a header dict (program name, starsim
    version and UNIVERSE_DB_VERSION) followed by the universe itself. This
    is the layout load() expects. Prints universe statistics afterwards.
    """
    print('Saving "%s"' % data_file)
    f = open(data_file, 'wb')
    try:
        p = cPickle.Pickler(f, 2) # protocol 2 required: new-style classes
        header = {
            'program': 'StarSim',
            'version': starsim.__version__,
            'universe_db_version': UNIVERSE_DB_VERSION,
            }
        p.dump(header)
        p.dump(universe)
    finally:
        # Release the handle even when pickling raises.
        f.close()
    universe.printStats()


class Sector(dict):
    """
    Collection of objects in a sector.
    Keeps track of its own sector key.

    Only do the following with instances of this class:
        Read the key attribute
        Call add and discard methods
        Use list(Sector()) to get contents
        Evaluate as true/false to see if it is empty
    Anything else will void the warranty!
    """
    # A dict stands in for a set because the built-in set type needs
    # Python 2.4; only the keys matter and every value is True.

    def __init__(self, key):
        dict.__init__(self)
        # The key this sector was created with (SectorDB passes the tuple
        # produced by its _calcSectorKey()).
        self.key = key

    def add(self, obj):
        """Insert obj; adding an object that is already present is a no-op."""
        dict.__setitem__(self, obj, True)

    def discard(self, obj):
        """Remove obj if present; missing objects are silently ignored."""
        if obj in self:
            del self[obj]


class SectorDB:
    """
    Database for storing objects indexed by location.

    Objects are bucketed into square sectors of SECTOR_SIZE units. Every
    stored object carries a private back-reference to the Sector it is
    in, so moves and deletes need no search. Empty sectors are removed
    from the index. Display lists are cached per sector until
    clearDisplayCache() is called.

    Doctest:
    >>> import starsim.universedb as udb
    >>> sdb = udb.SectorDB()
    >>> import starsim.model as m
    >>> s1 = m.StarShip(1, (1.5, -1.5))
    >>> sdb._calcSectorKey(s1.loc)
    (1, -2)
    >>> sdb.storeObject(s1)
    >>> s1._SectorDB__sector.key
    (1, -2)
    >>> s1.loc = (-1.5, 0.5)
    >>> sdb.updateObject(s1)
    >>> s1._SectorDB__sector.key
    (-2, 0)
    >>> display_list = sdb.getDisplayList((-1.25, 0.25))
    >>> print len(display_list)
    1
    >>> print display_list[0][:2]
    (1, (-1.5, 0.5))
    """

    # This constant is used for calculating the sector key from a location.
    SECTOR_SIZE = SCAN_LIMIT * 2

    def __init__(self):
        self.__db = {} # { sector_key: sector }
        self.__display_cache = {} # { sector_key: display_list }

    def clearDisplayCache(self):
        """Forget every cached per-sector display list."""
        self.__display_cache = {} # { sector_key: display_list }

    def deleteObject(self, sim_obj):
        """
        Delete an object from the sector database.
        Objects that were never stored, or were already deleted, are ignored.
        """
        try:
            sector = sim_obj.__sector
        except AttributeError:
            return
        # Drop the back-reference so the object no longer pins a sector and
        # a later updateObject() treats it as new, not as still indexed.
        del sim_obj.__sector
        sector.discard(sim_obj)
        self._dropIfEmpty(sector)

    def getDisplayList(self, loc):
        """
        Return a sequence of display info tuples within SCAN_LIMIT units
        of the location at (loc).
        """
        x, y = loc
        result = []
        limit = SCAN_LIMIT
        key_min_x, key_min_y = self._calcSectorKey((x - limit, y - limit))
        key_max_x, key_max_y = self._calcSectorKey((x + limit, y + limit))
        for i in xrange(key_min_x, key_max_x + 1):
            for j in xrange(key_min_y, key_max_y + 1):
                for display_info in self._getSectorDisplayList((i, j)):
                    # display_info[1] is read as the object's (x, y);
                    # sectors overlap the square only partially, so filter.
                    obj_loc = display_info[1]
                    if (abs(obj_loc[0] - x) > limit or
                        abs(obj_loc[1] - y) > limit):
                        continue
                    result.append(display_info)
        return result

    def updateObject(self, sim_obj):
        """
        Update the database with the new position of an object.
        An object not yet stored is stored for the first time.
        """
        try:
            old_sector = sim_obj.__sector
        except AttributeError:
            self.storeObject(sim_obj)
            return
        sector_key = self._calcSectorKey(sim_obj.loc)
        new_sector = self.__db.get(sector_key)
        if old_sector is new_sector:
            return
        if new_sector is None:
            new_sector = Sector(sector_key)
            self.__db[sector_key] = new_sector
        sim_obj.__sector = new_sector
        new_sector.add(sim_obj)
        old_sector.discard(sim_obj)
        self._dropIfEmpty(old_sector)

    def storeObject(self, sim_obj):
        """
        Store an object in the sector database for the first time.
        """
        sector_key = self._calcSectorKey(sim_obj.loc)
        sector = self.__db.get(sector_key)
        if sector is None:
            sector = Sector(sector_key)
            self.__db[sector_key] = sector
        sim_obj.__sector = sector
        sector.add(sim_obj)

    def _dropIfEmpty(self, sector):
        """Last one out turns off the lights: unindex an empty sector."""
        # Only evict this very Sector; a stale one must never knock out a
        # different Sector that is now filed under the same key.
        if not sector and self.__db.get(sector.key) is sector:
            del self.__db[sector.key]

    def _calcSectorKey(self, loc):
        """Map an (x, y) location to the integer key of its sector."""
        return (int(floor(loc[0] / self.SECTOR_SIZE)), 
                int(floor(loc[1] / self.SECTOR_SIZE)))

    def _getSectorDisplayList(self, sector_key):
        """Return (and cache) the views of the non-private objects in a sector."""
        display_list = self.__display_cache.get(sector_key)
        if display_list is not None:
            return display_list
        display_list = []
        # I don't use an iterator here any more because I once got a
        # "RuntimeError: dictionary changed size during iteration" message.
        for sim_obj in list(self.__db.get(sector_key, ())):
            if sim_obj.private:
                continue
            display_list.append(sim_obj.getView())
        self.__display_cache[sector_key] = display_list
        return display_list


class Universe:

    """
    Class responsibilities:
    * Maintain a database of objects queryable by id, by location, by state
      (active or not active), and by owning user.
    * Provide a display list cache such that during an update cycle, when
      many objects are getting display lists, the results are not computed
      more than once for any sector (each sector may be filtered multiple
      times, however).
    * Provide queries and operations on groups of objects in the database.
    * Keep track of the current simulation time (number of updates since the
      simulation began).
    * Must be picklable for persistence.
    """

    def __init__(self):
        self.__version__ = starsim.__version__
        self.ticks = 0
        self.__cur_id = 0
        self.__objects = {} # { id: sim_obj }
        self.__sector_db = SectorDB()
        self.__active_objs = {} # { sim_obj: True }
        self.__user_obj_map = {} # { user: [sim_obj,] }
        self.home_base = None # a create() method should fill this in

    def __getstate__(self):
        state = self.__dict__.copy()
        # delete unpicklable attributes (if any)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # restore unpicklable attributes (if any)

    def _unindexUserObject(self, user, sim_obj):
        """Remove sim_obj from user's object list; drop the list once empty."""
        user_objs = self.__user_obj_map.get(user)
        if user_objs is None:
            return
        try:
            user_objs.remove(sim_obj)
        except ValueError:
            pass
        if not user_objs:
            del self.__user_obj_map[user]

    def deleteObject(self, sim_obj):
        """
        Remove sim_obj from every index: by id, by location, the active set
        and its user's object list. A false sim_obj (e.g. None) is ignored.
        """
        if not sim_obj:
            return
        self.__objects.pop(sim_obj.id, None)
        self.__sector_db.deleteObject(sim_obj)
        self.__active_objs.pop(sim_obj, None)
        # Remove only this object; the user's other objects stay indexed.
        self._unindexUserObject(getattr(sim_obj, 'user', None), sim_obj)

    def getActiveObjects(self):
        """Return a new list of the objects registered as active."""
        return list(self.__active_objs)

    def getDisplayList(self, loc):
        """
        Return a sequence of tuples describing the visible state of all
        objects within scan range of location 'loc'.

        Implementation note: a cache speeds this up. The update() method
        calls clearDisplayCache() to refresh the cache.
        """
        return self.__sector_db.getDisplayList(loc)

    def getObject(self, id):
        """
        Get an object by its persistent integer ID.
        Return None if there is no object (any more) with that ID.
        """
        return self.__objects.get(id)

    def getObjIds(self):
        """Use this method to avoid iterating over the object map directly.
        Bad things happen when something changes the map while you are
        iterating on it."""
        return list(self.__objects)

    def getObjsForUser(self, user):
        """
        Return a sequence of all of the objects in the universe owned by the
        user. (Objects that have a user attribute that matches the user
        parameter - not all objects have a user attribute.)

        The result is a copy, so callers may delete objects while iterating.
        """
        return list(self.__user_obj_map.get(user, ()))

    def nextId(self):
        """Return a fresh integer ID (one more than the last one issued)."""
        self.__cur_id += 1
        return self.__cur_id

    def registerObject(self, sim_obj):
        """Register a new simulation object in the universe database.

        Special indexing is done based on the following attributes:
            sim_obj.active
            sim_obj.user
            sim_obj.loc  -- (x, y) location

        Return the new, unique ID of the object.
        """
        obj_id = self.nextId()
        self.__objects[obj_id] = sim_obj
        self.__sector_db.storeObject(sim_obj)
        if getattr(sim_obj, 'active', False):
            self.__active_objs[sim_obj] = True
        user = getattr(sim_obj, 'user', None)
        if user:
            self.__user_obj_map.setdefault(user, []).append(sim_obj)
        return obj_id

    def printStats(self):
        """
        Print out some statistics about the universe.
        """
        class_count_map = {}
        obj_ids = self.getObjIds()
        for obj_id in obj_ids:
            sim_obj = self.getObject(obj_id)
            if not sim_obj:
                continue
            class_name = sim_obj.__class__.__name__
            class_count_map[class_name] = class_count_map.get(class_name, 0) + 1
        class_list = list(class_count_map)
        class_list.sort()
        for class_name in class_list:
            print("%-20.20s %5.5d" % (class_name, class_count_map[class_name]))
        print("%d objects in universe" % len(obj_ids))

    def queryInitInfo(self, obj_ids):
        """
        Get invariant view information about a list of objects:
            (id, ClassName, ...)
        Unknown IDs yield None in the corresponding position.
        """
        result = []
        for obj_id in obj_ids:
            sim_obj = self.__objects.get(obj_id, None)
            if sim_obj:
                result.append(sim_obj.getInitInfo())
            else:
                result.append(None)
        return result

    def setObjectUser(self, sim_obj, user):
        """
        Change sim_obj's owning user and move it between user indexes.
        """
        old_user = getattr(sim_obj, 'user', None)
        if user == old_user:
            return
        sim_obj.user = user
        if old_user:
            # Unindex from the PREVIOUS owner, leaving its other objects alone.
            self._unindexUserObject(old_user, sim_obj)
        if user:
            self.__user_obj_map.setdefault(user, []).append(sim_obj)

    def update(self):
        """
        Update the state of every active object in the universe.
        Return the number of updates (time ticks) since the universe
        started.
        """
        self.ticks += 1
        active_objs = self.__active_objs
        for sim_obj in self.getActiveObjects():
            # An object updated earlier this tick may have deleted this one.
            if sim_obj not in active_objs:
                continue
            old_loc = sim_obj.loc
            sim_obj.update()
            # Objects that deleted themselves must not be re-indexed:
            # updateObject() would put them back in the sector database.
            if sim_obj not in active_objs:
                continue
            if sim_obj.loc != old_loc:
                self.__sector_db.updateObject(sim_obj)
        self.__sector_db.clearDisplayCache()
        return self.ticks


def test():
    """Run the doctests embedded in this module's docstrings."""
    import sys
    import doctest
    doctest.testmod(sys.modules[__name__])


# Executing this file as a script runs its doctests via test().
if __name__ == '__main__':
    test()

# end-of-file
