########################################################################
#
# File Name:            DbmDatabase.py
#
# Documentation:        http://docs.4suite.org/Lib/DbmDatabase.py.html
#
"""
A utility class for anydbm databases.
WWW: http://4suite.org/         e-mail: support@4suite.org

Copyright (c) 1999 Fourthought Inc, USA.   All Rights Reserved.
See  http://4suite.org/COPYRIGHT  for license and copyright information
"""

import os, string, shutil, sys, re
import stat, time
from distutils import dir_util
from cPickle import dumps, loads
from Ft import LOCALSTATEDIR

__version__ = '0.5.0'

class DbmError(Exception):
    """
    Error raised by the database layer.

    cmd names the operation that failed (for example 'GET TABLE') and
    msg describes what went wrong; both are kept as attributes.
    """

    def __init__(self, cmd, msg):
        self.cmd, self.msg = cmd, msg

    def __str__(self):
        """Render the error as 'cmd: msg'."""
        parts = (self.cmd, self.msg)
        return '%s: %s' % parts

# Not referenced in this part of the file.
FT_DATABASE = 'ft_database'

# Directory that holds every database (one sub-directory per database).
# The FT_DATABASE_DIR environment variable, when set, overrides the default
# location below LOCALSTATEDIR.  abspath() normalizes the result and removes
# unnecessary slashes.
DATABASE_DIR = os.path.abspath(
    os.environ.get('FT_DATABASE_DIR',
                   os.path.join(LOCALSTATEDIR, 'DbmDatabases')))

def SysInit():
    """
    Make sure DATABASE_DIR exists and is stamped with this module's version.

    If DATABASE_DIR/VERSION exists, its contents must equal __version__
    exactly, otherwise DbmError is raised.  If it does not exist, the
    directory is created (if needed) and a fresh VERSION file is written.
    """
    version_file = os.path.join(DATABASE_DIR, 'VERSION')
    if os.path.exists(version_file):
        f = open(version_file)
        try:
            found = f.read()
        finally:
            # Close explicitly instead of relying on garbage collection.
            f.close()
        if found != __version__:
            raise DbmError('BOOTSTRAP',
                           'Different version of DbmDatabase at %s' %
                           DATABASE_DIR)
    else:
        dir_util.mkpath(DATABASE_DIR)
        f = open(version_file, 'w')
        try:
            f.write(__version__)
        finally:
            f.close()
    return

# Bootstrap the database directory as soon as the module is imported.  A
# version mismatch reported by SysInit() is deliberately ignored here so the
# import itself still succeeds.
try:
    SysInit()
except DbmError:
    pass

# Single-letter mode flags.  Nothing in this part of the file references
# them; presumably they mirror the anydbm.open() flag values -- TODO confirm.
CREATE = 'c'
NEW = 'n'
WRITEABLE = 'w'
READONLY = 'r'

class Table:
    """
    A single table, stored as one file that holds the repr() of a
    dictionary mapping each key to the pickled form of its value.

    A Table starts out inside a transaction.  Changes are buffered in
    memory and only reach the file on checkpoint() or commit():

      _cached  -- values unpickled from the file during this transaction;
                  they are written back on checkpoint because the objects
                  may have been changed in place
      _added   -- values assigned during this transaction
      _deleted -- keys removed during this transaction

    Once commit() or rollback() has run, every data access raises DbmError.
    """

    def __init__(self, name, file):
        """
        name is the table name, file the path of the file backing it.
        Nothing is read from or written to disk until the table is used.
        """
        self._name = name
        self._file = file
        self._mtime = -1        # lower than any real mtime: first use loads
        self._inTx = 1
        self._data = {}
        self._cached = {}
        self._added = {}
        self._deleted = []

    def __repr__(self):
        return "<DbmTable at %x: %s>" % (id(self), self._name)

    def _requireTx(self):
        """Raise DbmError unless the transaction is still in progress."""
        if not self._inTx:
            raise DbmError('TRANSACTION ERROR', "Transaction not in progress")

    def _load(self):
        """Return the dictionary stored in the table file."""
        f = open(self._file, 'r')
        try:
            text = f.read()
        finally:
            f.close()
        # NOTE(security): eval() executes whatever the table file contains,
        # and the values are unpickled later on.  Both are unsafe if the file
        # can be written by an untrusted party.  Left in place because it
        # defines the on-disk format.
        return eval(text)

    def _store(self, data):
        """Replace the table file with the repr() of data."""
        # repr() output is plain text, so write in text mode like _load()
        # reads it.
        f = open(self._file, 'w')
        try:
            f.write(repr(data))
        finally:
            f.close()

    def _update(self):
        """
        Refresh self._data from the table file, creating an empty table
        file if there is none yet.  Raises DbmError outside a transaction.
        """
        self._requireTx()
        if os.path.isfile(self._file):
            mtime = os.stat(self._file)[stat.ST_MTIME]
            # '>=' means the file is re-read even when its mtime did not
            # change (presumably so that a write within the same timestamp
            # tick is not missed).
            if mtime >= self._mtime:
                self._mtime = mtime
                self._data = self._load()
        else:
            self._store({})
            self._mtime = os.stat(self._file)[stat.ST_MTIME]
            self._data = {}

    def _matchingKeys(self, key_patt, ignoreCase):
        """Return the keys matched by key_patt with '$' appended to it."""
        flags = 0
        if ignoreCase:
            flags = re.I
        pattern = re.compile(key_patt + '$', flags)
        return [k for k in self.keys() if pattern.match(k)]

    def getName(self):
        return self._name

    def rollback(self):
        """
        End the transaction and throw away every pending change.
        Raises DbmError if the transaction is already over.
        """
        self._requireTx()
        self._cached.clear()
        self._added.clear()
        self._deleted = []
        self._inTx = 0

    def checkpoint(self):
        """
        Write all pending changes to the table file and stay inside the
        transaction.
        """
        self._update()
        data = self._data

        # Loaded entries are saved again because the objects they point to
        # might have been changed in place.
        for k, v in self._cached.items():
            data[k] = dumps(v, 1)

        # Add any new entries; from now on they count as loaded entries.
        for k, v in self._added.items():
            data[k] = dumps(v, 1)
        self._cached.update(self._added)
        self._added.clear()

        # Remove any deleted items from the database itself.
        for k in self._deleted:
            if k in data:
                del data[k]
        self._deleted = []

        self._store(data)

    def commit(self):
        """Write all pending changes and end the transaction."""
        self.checkpoint()
        self._inTx = 0

    def keys(self):
        """
        Return a list of the keys visible in this transaction: keys on
        disk plus pending additions, minus pending deletions.  Each key
        is listed once.
        """
        self._update()
        seen = {}
        result = []
        for k in list(self._data.keys()) + list(self._added.keys()):
            # A key that exists on disk and was overwritten before being
            # read lives in both dictionaries; report it only once.
            if k in seen or k in self._deleted:
                continue
            seen[k] = 1
            result.append(k)
        return result

    def values(self):
        """Return a list of the values, in the order of keys()."""
        return [self[k] for k in self.keys()]

    def items(self):
        """Return a list of (key, value) pairs, in the order of keys()."""
        return [(k, self[k]) for k in self.keys()]

    def has_key(self, key):
        """Return true if key is visible in this transaction."""
        return key in self.keys()

    # Without this, 'key in table' falls back to integer indexing.
    __contains__ = has_key

    def get(self, key, default=None):
        """Return the value stored under key, or default if there is none."""
        if self.has_key(key):
            return self[key]
        return default

    def has_key_regex(self, key_patt, ignoreCase=0):
        """
        Return 1 if at least one key matches the regular expression
        key_patt (with '$' appended), else 0.
        """
        if self._matchingKeys(key_patt, ignoreCase):
            return 1
        return 0

    def get_regex(self, key_patt, ignoreCase=0, default=None):
        """
        Return the first key (not its value) that matches the regular
        expression key_patt (with '$' appended), or default if no key
        matches.
        """
        matched = self._matchingKeys(key_patt, ignoreCase)
        if matched:
            return matched[0]
        return default

    def __len__(self):
        """Return the number of keys visible in this transaction."""
        return len(self.keys())

    def __getitem__(self, key):
        """
        Return the value stored under key.  Raises KeyError if the key is
        unknown or pending deletion, DbmError outside a transaction.
        """
        self._requireTx()
        if key in self._deleted:
            raise KeyError(key)
        if key in self._cached:
            return self._cached[key]
        if key in self._added:
            return self._added[key]
        self._update()
        data = loads(self._data[key])  # may raise KeyError
        self._cached[key] = data
        return data

    def __setitem__(self, key, value):
        """Store value under key; it reaches the file on checkpoint()."""
        self._requireTx()
        if key in self._deleted:
            self._deleted.remove(key)
        if key in self._cached:
            self._cached[key] = value
        else:
            self._added[key] = value

    def __delitem__(self, key):
        """
        Mark key as deleted; the file changes on checkpoint().
        Raises KeyError if the key is not visible in this transaction.
        """
        self._requireTx()
        if not self.has_key(key):
            raise KeyError(key)
        self._deleted.append(key)
        if key in self._cached:
            del self._cached[key]
        if key in self._added:
            del self._added[key]

    def __del__(self):
        # Best effort only: a destructor must never raise, and the
        # attributes may be missing if __init__ did not finish.
        try:
            if self._inTx:
                self.rollback()
        except Exception:
            pass


def CreateDatabase(dbName):
    """
    Create the directory for a new database named dbName and return a
    Database instance connected to it.

    Raises DbmError if a database of that name already exists.
    """
    CheckVersion()
    dbpath = os.path.join(DATABASE_DIR, dbName)
    if os.path.isdir(dbpath):
        raise DbmError('CREATE DATABASE', 'Database "%s" already exists' % dbName)

    # os.makedirs keeps no record of earlier calls, so a database that was
    # dropped can be created again without resetting the private
    # distutils.dir_util._path_created cache.
    os.makedirs(str(dbpath))
    return Database(dbName)

def DropDatabase(dbName):
    """
    Delete the database named dbName from disk.

    Every entry of the database directory is removed with os.unlink and
    then the directory itself with os.rmdir.  Raises DbmError if there is
    no directory for dbName.
    """
    CheckVersion()
    dbpath = os.path.join(DATABASE_DIR, dbName)
    if os.path.isdir(dbpath):
        # Delete all of the files, then the now empty directory
        for entry in os.listdir(dbpath):
            os.unlink(os.path.join(dbpath, entry))
        os.rmdir(dbpath)
        return
    raise DbmError('DROP DATABASE',
                   "Database '%s' does not exist (or is invalid)" % dbName)
    
def DatabaseExists(dbName):
    """Return true if a directory for the database dbName exists."""
    CheckVersion()
    return os.path.isdir(os.path.join(DATABASE_DIR, dbName))

def GetAllDatabaseNames():
    """
    Return the names of all databases, i.e. of every sub-directory of
    DATABASE_DIR.
    """
    CheckVersion()
    # Taken straight from the file system; a directory listing cannot
    # contain the same name twice.
    return [entry for entry in os.listdir(DATABASE_DIR)
            if os.path.isdir(os.path.join(DATABASE_DIR, entry))]


def CheckVersion():
    """
    Verify that DATABASE_DIR holds a VERSION file whose contents equal
    __version__ exactly.

    Raises DbmError if the file is missing or holds anything else.
    """
    version_file = os.path.join(DATABASE_DIR, 'VERSION')
    if not os.path.exists(version_file):
        # A real exception class is required here: raising a plain string
        # is itself a TypeError on Python 2.6 and later.
        raise DbmError('CHECK VERSION', "No database at %s" % DATABASE_DIR)
    f = open(version_file, 'r')
    try:
        found = f.read()
    finally:
        f.close()
    if found != __version__:
        raise DbmError('CHECK VERSION', "Wrong Database version")


class Database:
    """
    Each instance of a database represents a single transaction.

    A database is a directory below DATABASE_DIR; each table is one file
    in that directory.  Table objects handed out during the transaction
    are remembered so that commit(), checkpoint() and rollback() can be
    forwarded to them.
    """
    def __init__(self, name):
        """
        Connect to the existing database called name and start a
        transaction.  Raises DbmError if the database does not exist.
        """
        self._name = name
        CheckVersion()
        self._dbpath = os.path.join(DATABASE_DIR, name)

        if not os.path.exists(self._dbpath):
            raise DbmError('CONNECT DATABASE', 'Database "%s" does not exist' % name)

        self.__tables = {}
        self._inTx = 1

    def _clear(self):
        """Forget every table handed out so far."""
        self.__tables = {}

    def getName(self):
        return self._name

    def createTable(self, table):
        """
        Create and return a new table.  Raises DbmError if the transaction
        is over or a table of that name already exists.
        """
        if not self._inTx:
            raise DbmError('TRANSACTION ERROR', 'Transaction Not in Progress')
        table_file = os.path.join(self._dbpath, table)
        # A table created earlier in this transaction may not have a file
        # yet, so the in-memory tables have to be checked as well.
        if os.path.exists(table_file) or table in self.__tables:
            raise DbmError('CREATE TABLE', 'Table %s Already exists' % table)

        t = Table(table, table_file)
        self.__tables[table] = t
        return t

    def getAllTableNames(self):
        """Return a list of table names, in memory and on disk."""
        names = {}
        # Get the in-memory table names
        for table in self.__tables.values():
            names[table.getName()] = 1
        # Get all from the file system
        # NOTE(review): the extension is stripped here although getTable()
        # looks for a file with exactly the table name -- confirm whether
        # table files can carry an extension.
        for table in os.listdir(self._dbpath):
            names[os.path.splitext(table)[0]] = 1
        return list(names.keys())

    keys = getAllTableNames

    def getTable(self, table):
        """
        Return the Table called table, reusing the object handed out
        earlier in this transaction if there is one.  Raises DbmError if
        there is no such table.
        """
        if table in self.__tables:
            return self.__tables[table]
        table_file = os.path.join(self._dbpath, table)
        if not os.path.isfile(table_file):
            raise DbmError('GET TABLE', 'Table %s does not exists in database %s' % (table,self._name))

        t = Table(table, table_file)
        self.__tables[table] = t
        return t

    __getitem__ = getTable

    def dropTable(self, table):
        """
        Remove the table called table.  Returns 1 if there was something
        to remove, 0 otherwise.
        """
        dropped = 0
        # Forget the in-memory table too; otherwise getTable() keeps
        # returning it and commit() writes its file back.
        if table in self.__tables:
            del self.__tables[table]
            dropped = 1
        table_file = os.path.join(self._dbpath, table)
        if os.path.exists(table_file):
            os.remove(table_file)
            dropped = 1
        return dropped

    __delitem__ = dropTable


    def rollback(self):
        """Roll back every table and end the transaction."""
        #Tell all of our tables to roll back
        for t in self.__tables.values():
            t.rollback()
        self._clear()
        self._inTx = 0


    def checkpoint(self):
        """Checkpoint every table; the transaction stays open."""
        #Tell all of our tables to write their pending changes
        for t in self.__tables.values():
            t.checkpoint()

    def commit(self):
        """Commit every table and end the transaction."""
        for t in self.__tables.values():
            t.commit()
        self._clear()
        self._inTx = 0


    def __del__(self):
        # Best effort only: a destructor must never raise, and _inTx is
        # missing if __init__ failed.
        try:
            if self._inTx:
                self.rollback()
        except Exception:
            pass
        self._clear()

if __name__ == '__main__':
    # Command-line inspection helper:
    #   DbmDatabase.py                   -> names of all databases
    #   DbmDatabase.py DB                -> names of the tables in DB
    #   DbmDatabase.py DB TABLE          -> keys of TABLE
    #   DbmDatabase.py DB TABLE KEY      -> value stored under KEY
    # Any other number of arguments lists the databases as well.
    import sys
    argc = len(sys.argv)
    if argc in (2, 3, 4):
        d = Database(sys.argv[1])
        if argc == 2:
            print(d.getAllTableNames())
        else:
            t = d[sys.argv[2]]
            if argc == 3:
                print(t.keys())
            else:
                print(t[sys.argv[3]])
    else:
        print(GetAllDatabaseNames())
