#!/usr/bin/env python
#----------------------------------------------------------------------------
# Project Name: MOSS
# File Name: live.py
# Description: Live streaming routine.
#
# Created: 2004. 06. 26
# RCS-ID: $Id: live.py,v 1.69 2004/10/17 08:02:55 myunggoni Exp $
# Copyright: (c) 2004 by myunggoni
# License: GNU General Public License
# Author: Myung-Gon Park <myunggoni@users.kldp.net>
#----------------------------------------------------------------------------

import ConfigParser
import time
import random
import os
import string
import fileinfo
import socket
import select
import sys
import os
import threading
import re
import pwd
import fileinfo
from moss import __package__, __version__

class Timer:
	"""Collects requested wake-up delays and sleeps for the shortest one.

	Callers register candidate delays with stall(). sleep() then blocks for
	the smallest strictly positive registered delay, never longer than the
	configured interval, and forgets every registered delay afterwards.
	"""

	def __init__(self, interval):
		# Upper bound, in seconds, for a single sleep() call.
		self.interval = interval
		# Delays (seconds) requested since the last sleep().
		self.stalltimes = []

	def stall(self, seconds):
		"""Ask the next sleep() to last at most `seconds` (ignored if <= 0)."""
		self.stalltimes.append(seconds)

	def sleep(self):
		"""Sleep for min(interval, smallest positive stall), then reset."""
		candidates = [self.interval]

		for requested in self.stalltimes:
			if requested > 0:
				candidates.append(requested)

		time.sleep(min(candidates))
		self.stalltimes = []

class Source:
	"""Playlist-driven producer that feeds audio chunks to every client.

	Entries of `files` are played in order and wrap around at the end.
	Each file is read through a FileReader; every chunk read is appended
	to the buffer of each Client in `clients`.
	"""

	def __init__(self, files, timer, server):
		self.files = files
		self.timer = timer
		self.server = server
		self.clients = []  # Clients currently receiving this stream.
		self.stream = None
		self.buffer = []  # NOTE(review): not read anywhere in this module.
		self.position = 0  # Index into files of the next entry to play.

		self.new_file()

	def new_file(self):
		"""Open the next playlist entry that is an existing regular file.

		Entries that are missing or are not regular files are skipped. If
		a whole pass over the playlist finds nothing playable (or the
		playlist is empty), an error is written to stderr and SystemExit
		is raised instead of spinning forever.
		"""
		self.file = None
		attempts = 0

		while not self.file:
			if attempts >= len(self.files):
				sys.stderr.write("No playable source file\n")
				raise SystemExit

			candidate = self.next_file()
			attempts = attempts + 1

			# isfile() is already False for paths that do not exist.
			if os.path.isfile(candidate):
				self.file = candidate

		self.stream = FileReader(self.file, self.timer, self.server.read_size, self.server.stream_type)
		# Stream headers are re-read lazily for the new file.
		self.file_headers = None

	def next_file(self):
		"""Return the next playlist entry, wrapping to the start at the end."""
		if self.position >= len(self.files):
			self.position = 0

		path = self.files[self.position]
		self.position = self.position + 1

		return path

	def next_data(self):
		"""Read the next chunk; at end of file, switch to the next file.

		The empty chunk read at end of file is still returned (and is
		therefore broadcast by serve_file()). Afterwards the current
		stream is asked to register any pending stall with the timer.
		"""
		data = self.stream.read()

		if len(data) == 0:
			self.stream.close()
			self.new_file()

		self.stream.stall()

		return data

	def serve_file(self):
		"""Broadcast chunks until the stream is ahead of wall-clock time."""
		while not self.stream.caught_up():
			data = self.next_data()
			self.send_all(data)

	def send_all(self, data):
		"""Append `data` to the buffer of every registered client."""
		# Iterate over a snapshot: Client threads may unregister themselves
		# concurrently, which would otherwise make this loop skip a client.
		for client in self.clients[:]:
			client.add_to_buffer(data)

	def http_headers(self):
		"""Return the HTTP response head for the configured stream type.

		Writes an error and raises SystemExit for an unknown stream type.
		"""
		if self.server.stream_type == "mp3":
			header = "HTTP/1.0 200 OK\nContent-Type: audio/mpeg\n\n"
		elif self.server.stream_type == "ogg":
			header = "HTTP/1.0 200 OK\nContent-Type: application/ogg\n\n"
		else:
			sys.stderr.write("Invalid stream format\n")
			raise SystemExit

		return header

	def stream_headers(self):
		"""Return the bytes a newly connected client must receive first.

		For "ogg" streams this is the first 4422 bytes of the current file,
		read once per file and cached in file_headers. For "mp3" it is
		empty. Writes an error and raises SystemExit on an unknown stream
		type or when the file cannot be opened.
		"""
		if self.server.stream_type == "ogg":
			if self.file_headers is None:
				try:
					handle = open(self.file, "rb")
				except IOError:
					sys.stderr.write("Open failed\n")
					raise SystemExit

				try:
					# Adjust to header size. NOTE(review): 4422 is assumed
					# to cover the Ogg/Vorbis header pages; TODO confirm.
					self.file_headers = handle.read(4422)
				finally:
					handle.close()
		elif self.server.stream_type == "mp3":
			self.file_headers = ""
		else:
			sys.stderr.write("Invalid stream format\n")
			raise SystemExit

		return self.file_headers

class FileReader:
	"""Opens an audio file and forwards all stream operations to a meter.

	The meter class is picked from the file extension, case-insensitively:
	.mp3 -> MP3Meter, .ogg -> OggMeter. Any other extension, or a file that
	cannot be opened, is reported on stderr and raises SystemExit.
	"""

	def __init__(self, path, timer, read_size, stream_type):
		try:
			self.handle = open(path, "rb")
		except IOError:
			sys.stderr.write("Open failed\n")
			raise SystemExit

		extension = os.path.splitext(path)[1].lower()

		if extension == ".mp3":
			meter_class = MP3Meter
		elif extension == ".ogg":
			meter_class = OggMeter
		else:
			sys.stderr.write("Invalid file format\n")
			raise SystemExit

		self.meter = meter_class(self.handle, path, timer, read_size, stream_type)

	def read(self):
		"""Return the next chunk produced by the meter."""
		return self.meter.read()

	def close(self):
		"""Close the underlying file via the meter."""
		self.meter.close()

	def stall(self):
		"""Let the meter register its pending stall with the timer."""
		self.meter.stall()

	def caught_up(self):
		"""Return the meter's verdict on whether data is ahead of real time."""
		return self.meter.caught_up()

class Meter:
	"""Base pacing helper comparing audio time read with wall-clock time.

	Subclasses add the play time of every chunk they read to `data_time`.
	`real_time` holds the seconds elapsed since construction, as of its
	last refresh (by caught_up() or by a subclass's read()).
	"""

	def __init__(self, timer):
		self.timer = timer
		self.start_time = time.time()
		self.real_time = 0
		self.data_time = 0

	def caught_up(self):
		"""Refresh real_time; return True while data is ahead of the clock."""
		self.real_time = time.time() - self.start_time
		lead = self.data_time - self.real_time
		return lead > 0

	def stall(self):
		"""Register the current lead (if positive) as a stall on the timer.

		Uses the stored real_time; it is not refreshed here.
		"""
		lead = self.data_time - self.real_time

		if not lead > 0:
			return

		self.timer.stall(lead)
	
class MP3Meter(Meter):
	"""Meter for .mp3 source files.

	When the server streams "mp3", read() returns raw chunks of up to
	read_size bytes and advances data_time by their play time. When it
	streams "ogg", MP3 input is not converted and read() returns "", which
	Source.next_data() treats as end of file.
	"""

	def __init__(self, handle, path, timer, read_size, stream_type):
		Meter.__init__(self, timer)
		self.handle = handle  # Open binary file object owned by this meter.
		self.read_size = read_size
		self.stream_type = stream_type

		# InvalidFileException is not imported into this module by name.
		# Look it up on fileinfo (presumably where it lives; TODO confirm)
		# and fall back to Exception, so a parse failure reports "Invalid
		# file format" instead of raising NameError.
		invalid_file_error = getattr(fileinfo, "InvalidFileException", Exception)

		try:
			info = fileinfo.FileInfo(path)
		except invalid_file_error:
			sys.stderr.write("Invalid file format\n")
			raise SystemExit

		bitrate = info.bitrate

		# read() divides by the byte rate, so a missing/zero bitrate is fatal.
		if not bitrate or bitrate < 0:
			sys.stderr.write("Invalid file format\n")
			raise SystemExit

		# NOTE(review): assumes bitrate is in kbit/s. Using 1024 rather than
		# 1000 bits per kbit paces slightly faster than real time; confirm.
		self.byterate = float(bitrate * 1024) / 8

	def read(self):
		"""Return the next chunk for the configured stream type.

		Writes an error and raises SystemExit for an unknown stream type.
		"""
		if self.stream_type == "mp3":
			self.real_time = time.time() - self.start_time

			data = self.handle.read(self.read_size)
			self.data_time = self.data_time + float(len(data)) / self.byterate
		elif self.stream_type == "ogg":
			# MP3 -> Ogg Vorbis conversion is not implemented.
			data = ""
		else:
			sys.stderr.write("Invalid stream format\n")
			raise SystemExit

		return data

	def close(self):
		"""Close the underlying file handle."""
		self.handle.close()

class OggMeter(Meter):
	"""Meter for .ogg source files.

	When the server streams "ogg", read() returns raw chunks of up to
	read_size bytes and advances data_time by their play time. When it
	streams "mp3", Ogg input is not converted and read() returns "", which
	Source.next_data() treats as end of file.
	"""

	def __init__(self, handle, path, timer, read_size, stream_type):
		Meter.__init__(self, timer)
		self.handle = handle  # Open binary file object owned by this meter.
		self.read_size = read_size
		self.stream_type = stream_type

		# InvalidFileException is not imported into this module by name.
		# Look it up on fileinfo (presumably where it lives; TODO confirm)
		# and fall back to Exception, so a parse failure reports "Invalid
		# file format" instead of raising NameError.
		invalid_file_error = getattr(fileinfo, "InvalidFileException", Exception)

		try:
			info = fileinfo.FileInfo(path)
		except invalid_file_error:
			sys.stderr.write("Invalid file format\n")
			raise SystemExit

		bitrate = info.bitrate

		# read() divides by the byte rate, so a missing/zero bitrate is fatal.
		if not bitrate or bitrate < 0:
			sys.stderr.write("Invalid file format\n")
			raise SystemExit

		# NOTE(review): assumes bitrate is in kbit/s. Using 1024 rather than
		# 1000 bits per kbit paces slightly faster than real time; confirm.
		self.byterate = float(bitrate * 1024) / 8

	def read(self):
		"""Return the next chunk for the configured stream type.

		Writes an error and raises SystemExit for an unknown stream type.
		"""
		if self.stream_type == "ogg":
			self.real_time = time.time() - self.start_time

			data = self.handle.read(self.read_size)
			self.data_time = self.data_time + float(len(data)) / self.byterate
		elif self.stream_type == "mp3":
			# Ogg Vorbis -> MP3 conversion is not implemented.
			data = ""
		else:
			sys.stderr.write("Invalid stream format\n")
			raise SystemExit

		return data

	def close(self):
		"""Close the underlying file handle."""
		self.handle.close()

class Client(threading.Thread):
	"""Per-listener thread that drains a chunk queue into its socket.

	The main loop (Source.send_all) appends chunks with add_to_buffer();
	run() sends them in order. The thread ends when a send fails or once
	the client has been removed from source.clients (kicked or shut down).
	"""

	def __init__(self, sock, source, server):
		self.socket = sock
		self.address = self.socket.getpeername()
		self.source = source
		self.server = server
		# The HTTP response head and any per-stream header go out first.
		self.buffer = [self.source.http_headers(), self.source.stream_headers()]
		self.last_send = time.time()
		threading.Thread.__init__(self)

	def add_to_buffer(self, data):
		"""Queue `data`, dropping the oldest chunks beyond server.max_buffer.

		The list is trimmed in place so that run(), which works on the
		same list object, always sees newly queued data.
		NOTE(review): max_buffer limits the number of chunks, not bytes.
		"""
		buffer = self.buffer
		buffer.append(data)

		excess = len(buffer) - self.server.max_buffer

		if excess > 0:
			del buffer[:excess]

	def run(self):
		"""Send queued chunks until the socket fails or the client is dropped.

		EPIPE / ECONNRESET mean the listener went away: the client then
		unregisters itself. Other socket errors just end the thread; the
		server's send-timeout check cleans such clients up later.
		"""
		import errno

		closed_errors = (errno.EPIPE, errno.ECONNRESET)
		running = True

		while running and self in self.source.clients:
			buffer = self.buffer

			if not buffer:
				# Nothing queued: wait for the source to produce more.
				time.sleep(0.5)
				continue

			chunk = buffer[0]

			try:
				sent = self.socket.send(chunk)
			except socket.error:
				error = sys.exc_info()[1]
				running = False

				if error.args and error.args[0] in closed_errors:
					self._disconnect()
			else:
				# add_to_buffer() may have trimmed the queue while send()
				# was blocked; only consume the chunk if still at the front.
				if buffer and buffer[0] is chunk:
					if sent == len(chunk):
						del buffer[0]
					else:
						buffer[0] = chunk[sent:]

				self.last_send = time.time()

	def _disconnect(self):
		"""Unregister this client after the peer closed the connection."""
		try:
			self.source.clients.remove(self)
		except ValueError:
			# Already removed (and its socket closed) by the server.
			return

		self.server.connection = self.server.connection - 1
		self.server.log_message("Connection closed - %s" % self.address[0])
		self.server.debug_message("Total %d connection(s) established" % (self.server.connection, ))
		self.socket.close()

	def fileno(self):
		"""Return the socket's file descriptor (lets select() take a Client)."""
		return self.socket.fileno()
		
class Server:
	"""MOSS streaming server.

	The constructor does the following, in order:
	  * reads the configuration file and opens the log;
	  * loads the ACL and the optional Basic-auth table;
	  * builds the playlist and binds the listening socket;
	  * when run as root, drops to the configured user and group.

	serve_forever() then runs the single-threaded accept/stream/kick loop.
	Every accepted listener is handed to its own Client thread.
	"""

	def __init__(self, config_file):
		self.connection = 0  # Number of currently registered clients.
		self.running = False
		self.config = ConfigParser.ConfigParser()

		config_defaults = self.config.defaults()

		config_defaults["port"] = "8162"
		config_defaults["hostname"] = ""
		config_defaults["log"] = "/tmp/moss.log"
		config_defaults["user"] = "nobody"
		config_defaults["group"] = "nobody"
		config_defaults["max-connection"] = "10"
		config_defaults["read-size"] = "8192"
		config_defaults["stream-type"] = "mp3"
		config_defaults["max-buffer"] = "81920"
		config_defaults["send-timeout"] = "3"
		config_defaults["shuffle"] = "0"
		config_defaults["recursive"] = "1"

		self.config.read(config_file)

		# get() raises NoSectionError for a missing section even when the
		# option has a default, so make sure [server] exists.
		if not self.config.has_section("server"):
			self.config.add_section("server")

		user = self.config.get("server", "user")
		group = self.config.get("server", "group")
		self.port = self.config.getint("server", "port")
		self.hostname = self.config.get("server", "hostname")
		self.log = self.config.get("server", "log")
		self.max_connection = self.config.getint("server", "max-connection")
		self.read_size = self.config.getint("server", "read-size")
		self.stream_type = self.config.get("server", "stream-type")
		self.max_buffer = self.config.getint("server", "max-buffer")
		self.send_timeout = self.config.getint("server", "send-timeout")
		self.shuffle = self.config.getint("server", "shuffle")
		self.recursive = self.config.getint("server", "recursive")

		if debug_mode:
			self.log = sys.stdout
		else:
			try:
				self.log = open(self.log, "a")
			except IOError:
				self.log = sys.stdout

		if self.stream_type != "mp3" and self.stream_type != "ogg":
			self.log_message("Invalid stream format")
			sys.exit(1)

		self.acls = self._load_acls()
		self.auth_table = self._load_auth_table()
		self.dirs = self._load_source_dirs()

		self.timer = Timer(0.1)
		self.files = []

		for dir_path, label in self.dirs:
			self.files = self.files + self.make_list(dir_path, self.recursive, self.shuffle)

		if not self.files:
			self.log_message("Source file not found")
			sys.exit(1)

		self.source = Source(self.files, self.timer, self)
		self._bind_socket()

		# Privileges are dropped only after bind(), so privileged ports
		# can still be used when started as root.
		if os.getuid() == 0:
			self._drop_privileges(user, group)

	def _load_acls(self):
		"""Parse [acl] allow into a list of (address, netmask) integer pairs.

		Entries are separated by whitespace or commas. Each is a dotted
		quad or a host name, optionally followed by /prefix-length. An
		empty result means every address is allowed (see acl_ok()).
		Invalid entries are logged and make the server exit.
		"""
		acls = []

		try:
			allowed = re.split(r"[\s,]+", self.config.get("acl", "allow"))
		except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
			allowed = []

		for item in allowed:
			try:
				if '/' in item:
					addr, masklen = string.split(item, '/', 1)
					masklen = int(masklen)
				else:
					addr, masklen = item, 32

				if masklen < 0 or masklen > 32:
					raise ValueError(masklen)

				if not re.match(r"^\d+\.\d+\.\d+\.\d+$", addr):
					addr = socket.gethostbyname(addr)
			except (ValueError, socket.error):
				# Fail closed: silently skipping could widen access.
				self.log_message("Invalid ACL entry: %s" % item)
				sys.exit(1)

			mask = ~((1 << (32 - masklen)) - 1)
			entry = (self.dot2int(addr), mask)

			if entry not in acls:
				acls.append(entry)

		return acls

	def _load_auth_table(self):
		"""Read the optional [acl] auth-file into {user: md5-hex-password}.

		Each line is "user:md5hexdigest"; blank and malformed lines are
		skipped. With no auth-file configured an empty dict is returned,
		which leaves authentication disabled. A configured file that is
		missing or unreadable is logged and makes the server exit.
		"""
		auth_table = {}

		try:
			auth_file = self.config.get("acl", "auth-file")
		except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
			return auth_table

		if not os.path.isfile(auth_file):
			self.log_message("No such authenticate file")
			sys.exit(1)

		try:
			fp = open(auth_file, "r")

			try:
				for line in fp.readlines():
					line = line.strip()

					if not line or ':' not in line:
						continue

					name, passwd = line.split(':', 1)
					auth_table[name] = passwd
			finally:
				fp.close()
		except IOError:
			self.log_message("Invalid authenticate file")
			sys.exit(1)

		return auth_table

	def _load_source_dirs(self):
		"""Return [(directory, label), ...] from the [sources] section.

		Options named source-dirN are used in order of their numeric
		suffix N. Each value is "path" or "path:label". The label defaults
		to the directory's base name and must not contain '/'. Exits when
		no usable directory remains.
		"""
		try:
			options = self.config.options("sources")
		except ConfigParser.NoSectionError:
			options = []

		numbered = []

		for option in options:
			if option[:10] != "source-dir":
				continue

			try:
				index = int(option[10:])
			except ValueError:
				self.log_message("Invalid source option: %s" % option)
				continue

			numbered.append((index, self.config.get("sources", option)))

		if not numbered:
			self.log_message("Source directory not found")
			sys.exit(1)

		numbered.sort()
		dirs = []

		for index, value in numbered:
			parts = [part.strip() for part in value.split(':')]

			if len(parts) == 1:
				label = os.path.basename(parts[0])
			else:
				label = parts[1]

			if not os.path.isdir(parts[0]):
				self.log_message("Invalid source directory")
				continue

			if '/' in label:
				self.log_message("Invalid source label")
				continue

			dirs.append((parts[0], label))

		if not dirs:
			self.log_message("Source directory not found")
			sys.exit(1)

		return dirs

	def _bind_socket(self):
		"""Create the listening TCP socket and bind it to hostname:port."""
		try:
			self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

			if hasattr(socket, "SOL_SOCKET") and hasattr(socket, "SO_REUSEADDR"):
				self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

			self.socket.bind((self.hostname, self.port))
		except socket.error, error:
			if len(error.args) > 1:
				reason = error.args[1]
			else:
				reason = error

			self.log_message("%s" % str(reason))
			sys.exit(1)

	def _drop_privileges(self, user, group):
		"""Switch from root to the configured user and group.

		The group is looked up in the group database (grp). If it does not
		exist, the user's primary group is used instead and a message is
		logged. Supplementary groups are cleared, then setgid() runs before
		setuid(): after setuid() the process can no longer change its group.
		"""
		import grp

		try:
			pw_entry = pwd.getpwnam(user)
		except KeyError:
			self.log_message("Invalid user name")
			sys.exit(1)

		user_id = pw_entry[2]

		try:
			group_id = grp.getgrnam(group)[2]
		except KeyError:
			# e.g. some systems name the unprivileged group "nogroup".
			group_id = pw_entry[3]
			self.log_message("Invalid group name, using primary group of %s" % user)

		try:
			os.setgroups([])
		except OSError, error:
			self.log_message("Setgroups failed: %s" % (error.strerror))
			sys.exit(1)

		try:
			os.setgid(group_id)
		except OSError, error:
			self.log_message("Setgid failed: %s" % (error.strerror))
			sys.exit(1)

		try:
			os.setuid(user_id)
		except OSError, error:
			self.log_message("Setuid failed: %s" % (error.strerror))
			sys.exit(1)

	def server_close(self):
		"""Ask serve_forever() to stop after its current iteration."""
		self.running = False

	def serve_forever(self):
		"""Run the accept / stream / kick loop until server_close().

		However the loop ends (including by an exception such as
		KeyboardInterrupt), the listening socket and every client socket
		are closed.
		"""
		self.socket.listen(5)
		self.running = True

		try:
			while self.running:
				self.accept()
				self.timer.sleep()
				self.source.serve_file()
				self._kick_stalled_clients()
		finally:
			self.log_message("Please wait while the remaining streams finish...")
			self.socket.close()
			self._close_all_clients()

	def _kick_stalled_clients(self):
		"""Drop clients whose last successful send is older than send_timeout."""
		now = time.time()

		# Iterate over a copy: the list is mutated here and by Client threads.
		for client in self.source.clients[:]:
			if not now - client.last_send > self.send_timeout:
				continue

			try:
				self.source.clients.remove(client)
			except ValueError:
				# The client thread already unregistered itself.
				continue

			self.connection = self.connection - 1
			self.log_message("Client kicked - %s" % client.address[0])
			self.debug_message("Total %d connection(s) established" % (self.connection, ))
			client.socket.close()

	def _close_all_clients(self):
		"""Close and unregister every remaining client."""
		for client in self.source.clients[:]:
			try:
				self.source.clients.remove(client)
			except ValueError:
				continue

			client.socket.close()

	def accept(self):
		"""Accept at most one pending connection without blocking.

		Reads the request head, then enforces the ACL, optional Basic
		authentication and the connection limit. If all pass, a Client
		thread is started for the new listener.
		"""
		readable = select.select([self.socket], [], [], 0)[0]

		if not readable:
			return

		try:
			self.socket.settimeout(1)

			try:
				self.client_socket, self.client_address = self.socket.accept()
			finally:
				self.socket.settimeout(None)
		except socket.error:
			return

		try:
			self.client_socket.settimeout(3)
			self.headers = self.client_socket.recv(1024)

			if not self.acl_ok(self.client_address[0]):
				self.send_error(403, "Forbidden")
				self.client_socket.close()
				return

			if self.auth_table and not self.check_authorization():
				self.client_socket.close()
				return

			self.client_socket.settimeout(None)
		except socket.error:
			self.log_message("Connection timed out - %s" % self.client_address[0])
			self.client_socket.close()
			return
		except Exception:
			self.log_message("Invalid request - %s" % self.client_address[0])
			self.client_socket.close()
			return

		if self.connection >= self.max_connection:
			self.log_message("Max connection exceeded - %s" % self.client_address[0])
			self.client_socket.close()
			return

		try:
			# Client() calls getpeername(), which fails if the peer is gone.
			client = Client(self.client_socket, self.source, self)
		except socket.error:
			self.log_message("Connection closed - %s" % self.client_address[0])
			self.client_socket.close()
			return

		self.connection = self.connection + 1
		self.log_message("Connection established - %s" % self.client_address[0])
		self.debug_message("Total %d connection(s) established" % (self.connection, ))
		self.source.clients.append(client)
		client.start()

	def acl_ok(self, ipaddr):
		"""Return True if dotted-quad `ipaddr` matches an ACL entry.

		An empty ACL allows every address.
		"""
		if not self.acls:
			return True

		ipaddr = self.dot2int(ipaddr)

		for allowed, mask in self.acls:
			if (ipaddr & mask) == (allowed & mask):
				return True

		return False

	def check_authorization(self):
		"""Validate an HTTP Basic Authorization header against auth_table.

		Returns True for a known user whose MD5-hashed password matches.
		Without an Authorization header a 401 challenge is sent. Other
		schemes and malformed credentials are rejected and logged. Every
		failure returns False.
		"""
		auth = self.get_header("Authorization")

		if not auth:
			realm = "MOSS"
			self.send_response(401)
			self.send_header("WWW-Authenticate", "basic realm=\"%s\"" % (realm, ))
			self.end_headers()
			return False

		if string.lower(auth[:6]) != "basic ":
			self.log_message("Authorization failed: unsupported scheme - %s" % self.client_address[0])
			return False

		import base64
		import binascii

		try:
			from hashlib import md5 as md5_new
		except ImportError:
			from md5 import new as md5_new

		try:
			credentials = base64.decodestring(string.split(auth)[-1])
		except binascii.Error:
			credentials = ""

		if ':' not in credentials:
			self.log_message("Authorization failed: malformed credentials - %s" % self.client_address[0])
			return False

		# Split on the first ':' only: passwords may contain ':'.
		user, passwd = string.split(credentials, ':', 1)
		passwd = md5_new(passwd).hexdigest()

		if self.auth_table.get(user) == passwd:
			self.log_message("Authenticated user: %s - %s" % (user, self.client_address[0]))
			return True

		self.log_message("Authorization failed: %s - %s" % (user, self.client_address[0]))
		return False

	def get_header(self, header):
		"""Return the stripped value of request header `header`, or None.

		Header names are matched case-insensitively, as HTTP specifies.
		Only the first ':' on a line separates the name from the value.
		"""
		wanted = string.lower(header)

		for line in string.split(self.headers, '\n'):
			if ':' not in line:
				continue

			name, value = string.split(line, ':', 1)

			if string.lower(name) == wanted:
				return string.strip(value)

		return None

	def _send_all(self, content):
		"""Best-effort write of `content` to the client socket.

		Returns True when everything was sent. Stops and returns False on
		a socket error or when send() makes no progress, so that courtesy
		output such as error pages never crashes the server.
		"""
		total_sent = 0

		while total_sent < len(content):
			try:
				sent = self.client_socket.send(content[total_sent:])
			except socket.error:
				return False

			if sent == 0:
				return False

			total_sent = total_sent + sent

		return True

	def send_error(self, code, message = None):
		"""Send a complete HTML error response: status, headers and body."""
		try:
			short, explain = self.responses[code]
		except KeyError:
			short, explain = "???", "???"

		if message is None:
			message = short

		content = (self.error_message_format %
			{'code': code, 'message': message, 'explain': explain})

		self.send_response(code, message)
		self.send_header("Content-Type", "text/html")
		self.send_header('Connection', 'close')
		self.end_headers()

		# 1xx, 204 and 304 responses must not carry a body.
		if code >= 200 and code not in (204, 304):
			self._send_all(content)

	def send_response(self, code, message = None):
		"""Send the status line, then the Server and Date headers."""
		if message is None:
			if code in self.responses:
				message = self.responses[code][0]
			else:
				message = ''

		status_line = "%s %d %s\n" % (self.protocol_version, code, message)

		if not self._send_all(status_line):
			return

		self.send_header("Server", self.version_string())
		self.send_header("Date", self.date_time_string())

	def send_header(self, keyword, value):
		"""Send one "keyword: value" header line."""
		self._send_all("%s: %s\n" % (keyword, value))

	def end_headers(self):
		"""Send the blank line that terminates the header block."""
		self._send_all("\n")

	def version_string(self):
		"""Return the value used for the Server header."""
		return self.server_version

	def date_time_string(self):
		"""Return the current time in HTTP date format (GMT)."""
		now = time.time()
		year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now)
		s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (self.weekdayname[wd],
		day, self.monthname[month], year, hh, mm, ss)

		return s

	def log_date_time_string(self):
		"""Return the local time as "YYYY/MM/DD HH:MM" for log lines."""
		now = time.time()
		year, month, day, hh, mm, ss, x, y, z = time.localtime(now)
		s = "%04d/%02d/%02d %02d:%02d" % (year, month, day, hh, mm)

		return s

	def make_list(self, fullpath, recursive, shuffle, files = None):
		"""Return the music files found under `fullpath`.

		A file matches when its extension (case-insensitive) is a key of
		music_extensions. Entries are visited in sorted order, and
		subdirectories are descended into when `recursive` is true.
		Matches are appended to `files` (a new list when omitted), which
		is returned. When `shuffle` is true the whole list is shuffled
		once at the end.
		"""
		if files is None:
			files = []

		self._collect_music_files(fullpath, recursive, files)

		if shuffle:
			random.shuffle(files)

		return files

	def _collect_music_files(self, fullpath, recursive, files):
		"""Append matching paths below `fullpath` to `files` (no shuffling)."""
		try:
			names = os.listdir(fullpath)
		except OSError, error:
			self.log_message("Cannot read directory %s: %s" % (fullpath, error.strerror))
			return

		names.sort()

		for name in names:
			path = os.path.join(fullpath, name)
			ext = os.path.splitext(name)[1]

			if string.lower(ext) in music_extensions:
				files.append(path)

			if recursive and os.path.isdir(path):
				self._collect_music_files(path, recursive, files)

	def log_message(self, message):
		"""Write a timestamped line to the log; write errors are ignored."""
		if self.log:
			try:
				self.log.write("%s [-] %s\n" % (self.log_date_time_string(), message))
				self.log.flush()
			except IOError:
				pass

	def debug_message(self, message):
		"""Log `message` only when debug_mode is enabled."""
		if debug_mode > 0:
			self.log_message(message)

	def dot2int(self, dotted_addr):
		"""Convert a dotted-quad IPv4 string to a 32-bit integer."""
		a, b, c, d = map(int, string.split(dotted_addr, '.'))

		return (a << 24) + (b << 16) + (c << 8) + (d << 0)

	error_message_format = '''<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">
<HTML><HEAD>
<TITLE>MOSS - MP3/OGG Streaming Server</TITLE>
</HEAD><BODY>
<H1>Error response</H1>
<P>Error code %(code)d</P>
<P>Message: %(message)s</P>
<P>Error code explanation: %(explain)s</P>
</BODY></HTML>
'''
	# Maps status code -> (short message, long explanation).
	responses = {
        100: ('Continue', 'Request received, please continue'),
        101: ('Switching Protocols',
              'Switching to new protocol; obey Upgrade header'),

        200: ('OK', 'Request fulfilled, document follows'),
        201: ('Created', 'Document created, URL follows'),
        202: ('Accepted',
              'Request accepted, processing continues off-line'),
        203: ('Non-Authoritative Information', 'Request fulfilled from cache'),
        204: ('No response', 'Request fulfilled, nothing follows'),
        205: ('Reset Content', 'Clear input form for further input.'),
        206: ('Partial Content', 'Partial content follows.'),

        300: ('Multiple Choices',
              'Object has several resources -- see URI list'),
        301: ('Moved Permanently', 'Object moved permanently -- see URI list'),
        302: ('Found', 'Object moved temporarily -- see URI list'),
        303: ('See Other', 'Object moved -- see Method and URL list'),
        304: ('Not modified',
              'Document has not changed since given time'),
        305: ('Use Proxy',
              'You must use proxy specified in Location to access this '
              'resource.'),
        307: ('Temporary Redirect',
              'Object moved temporarily -- see URI list'),

        400: ('Bad request',
              'Bad request syntax or unsupported method'),
        401: ('Unauthorized',
              'No permission -- see authorization schemes'),
        402: ('Payment required',
              'No payment -- see charging schemes'),
        403: ('Forbidden',
              'Request forbidden -- authorization will not help'),
        404: ('Not Found', 'Nothing matches the given URI'),
        405: ('Method Not Allowed',
              'Specified method is invalid for this server.'),
        406: ('Not Acceptable', 'URI not available in preferred format.'),
        407: ('Proxy Authentication Required', 'You must authenticate with '
              'this proxy before proceeding.'),
        408: ('Request Time-out', 'Request timed out; try again later.'),
        409: ('Conflict', 'Request conflict.'),
        410: ('Gone',
              'URI no longer exists and has been permanently removed.'),
        411: ('Length Required', 'Client must specify Content-Length.'),
        412: ('Precondition Failed', 'Precondition in headers is false.'),
        413: ('Request Entity Too Large', 'Entity is too large.'),
        414: ('Request-URI Too Long', 'URI is too long.'),
        415: ('Unsupported Media Type', 'Entity body in unsupported format.'),
        416: ('Requested Range Not Satisfiable',
              'Cannot satisfy request range.'),
        417: ('Expectation Failed',
              'Expect condition could not be satisfied.'),

        500: ('Internal error', 'Server got itself in trouble'),
        501: ('Not Implemented',
              'Server does not support this operation'),
        502: ('Bad Gateway', 'Invalid responses from another server/proxy.'),
        503: ('Service temporarily overloaded',
              'The server cannot process the request due to a high load'),
        504: ('Gateway timeout',
              'The gateway server did not receive a timely response'),
        505: ('HTTP Version not supported', 'Cannot fulfill request.'),
        }

	server_version = __package__ + '/' + __version__
	protocol_version = "HTTP/1.0"
	weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
	monthname = [None,
                 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

# Source-file extensions (lower-case, with the leading dot) mapped to a
# MIME type string. Only the keys are consulted in this module (playlist
# building in Server.make_list).
music_extensions = dict([
	(".mp3", "audio/mpeg"),
	(".ogg", "application/ogg"),
	])

# When true, Server logs to stdout instead of the configured log file and
# debug_message() output is written.
debug_mode = False
