#!/usr/bin/env python
#----------------------------------------------------------------------------
# Project Name: MOSS
# File Name: web.py
# Description: Web-based streaming server routine.
#
# Created: 2004. 06. 20
# RCS-ID: $Id: web.py,v 1.43 2004/10/16 12:32:51 myunggoni Exp $
# Copyright: (c) 2004 by myunggoni
# License: GNU General Public License
# Author: Myung-Gon Park <myunggoni@users.kldp.net>
#----------------------------------------------------------------------------

import SocketServer
import BaseHTTPServer
import ConfigParser
import sys
import string
import os
import cgi
import urllib
import socket
import stat
import random
import time
import re
import pwd
import template
import fileinfo
from moss import __package__, __version__

try:
	import cStringIO
	StringIO = cStringIO
except ImportError:
	import StringIO


class Server(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
	"""Threaded HTTP server that streams media from configured directories.

	Settings are read from an INI-style file with the sections "server",
	"extra", "acl" and "sources".  Unusable settings are logged and end
	the process with exit status 1.

	Attributes set by __init__ and read by the request handler:
	connection     -- number of streaming connections currently counted
	max_connection -- upper bound for connection
	read_size      -- chunk size used when sending file data
	template       -- template.Template used to render listing pages
	acls           -- list of (address, mask) integer pairs; empty allows all
	auth_table     -- {user: stored password}; empty disables authentication
	dirs           -- list of (directory, label) pairs being served
	"""

	def __init__(self, config_file):
		"""Read config_file, bind the listening socket and drop privileges.

		config_file -- file name(s) handed to ConfigParser.read().

		Calls sys.exit(1) after logging when the template, the ACL, the
		auth file, the source directories, the listening address or the
		target user/group cannot be used.
		"""
		self.connection = 0
		self.config = ConfigParser.ConfigParser()

		config_defaults = self.config.defaults()

		config_defaults["port"] = "8162"
		config_defaults["hostname"] = ""
		config_defaults["user"] = "nobody"
		config_defaults["group"] = "nobody"
		config_defaults["log"] = "/tmp/moss.log"
		config_defaults["max-connection"] = "10"
		config_defaults["read-size"] = "8192"
		config_defaults["template"] = "default.tmpl"
		self.config.read(config_file)

		# Defaults are only reachable through an existing section, so a
		# configuration without a [server] block must still get one.
		if not self.config.has_section("server"):
			self.config.add_section("server")

		user = self.config.get("server", "user")
		group = self.config.get("server", "group")
		log_path = self.config.get("server", "log")
		self.hostname = self.config.get("server", "hostname")
		self.port = self.config.getint("server", "port")
		self.max_connection = self.config.getint("server", "max-connection")
		self.read_size = self.config.getint("server", "read-size")

		# Log to stdout in debug mode or when the log file cannot be opened.
		self.log = sys.stdout

		if not debug_mode:
			try:
				self.log = open(log_path, "a")
			except IOError:
				pass

		self.template = self._load_template()
		self.acls = self._load_acls()
		self.auth_table = self._load_auth_table()
		self.dirs = self._load_source_dirs()

		try:
			SocketServer.TCPServer.__init__(self, (self.hostname, self.port), HTTPRequestHandler)
		except socket.error:
			# sys.exc_info() is used so the handler needs no
			# version-specific "except ..., name" syntax.
			error = sys.exc_info()[1]

			# socket.error usually carries (errno, message).
			if len(error.args) > 1:
				self.log_message("%s" % str(error.args[1]))
			else:
				self.log_message("%s" % str(error))

			sys.exit(1)

		# Privileges are dropped only after the socket is bound.
		if os.getuid() == 0:
			self._drop_privileges(user, group)

	def _load_template(self):
		"""Return the listing template named in the [extra] section."""
		try:
			template_path = self.config.get("extra", "template-dir")
		except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
			self.log_message("Template directory was not specified")
			sys.exit(1)

		template_file = self.config.get("extra", "template")
		template_file = os.path.join(template_path, template_file)

		if not os.path.isfile(template_file):
			self.log_message("No such template file")
			sys.exit(1)

		try:
			return template.Template(template_file)
		except SystemExit:
			self.log_message("Template error occured")
			sys.exit(1)

	def _load_acls(self):
		"""Return the (address, mask) pairs listed in [acl] allow.

		Entries are separated by whitespace or commas and may be a dotted
		address, a host name, or either followed by "/masklen".  A
		missing or empty option yields an empty list (no restriction).
		An entry that cannot be parsed or resolved is fatal, because
		skipping it could leave the list empty and thereby allow everyone.
		"""
		try:
			allowed = re.split(r"[\s\n,]+", self.config.get("acl", "allow"))
		except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
			allowed = []

		acls = []

		for spec in allowed:
			# A trailing separator or an empty value produces ''.
			if not spec:
				continue

			try:
				if '/' in spec:
					addr, masklen = string.split(spec, '/')
					masklen = int(masklen)
				else:
					addr = spec
					masklen = 32

				if masklen < 0 or masklen > 32:
					raise ValueError(spec)

				if not re.match(r"^\d+\.\d+\.\d+\.\d+$", addr):
					addr = socket.gethostbyname(addr)

				mask = ~((1 << (32 - masklen)) - 1)
				entry = (self.dot2int(addr), mask)
			except (ValueError, socket.error):
				self.log_message("Invalid ACL entry: %s" % spec)
				sys.exit(1)

			if entry not in acls:
				acls.append(entry)

		return acls

	def _load_auth_table(self):
		"""Return {user: password} parsed from the [acl] auth-file.

		Each line has the form "user:password"; lines without a colon are
		skipped.  Without an auth-file option the table is empty, which
		the request handler treats as "no authentication".  A configured
		file that is missing or unreadable is fatal: carrying on would
		silently serve everything without authentication.
		"""
		auth_table = {}

		try:
			auth_file = self.config.get("acl", "auth-file")
		except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
			return auth_table

		if not os.path.isfile(auth_file):
			self.log_message("No such authenticate file")
			sys.exit(1)

		try:
			fp = open(auth_file, "r")

			try:
				for line in fp.readlines():
					line = string.strip(line)

					if ':' not in line:
						continue

					# Split once only so the value may contain ':'.
					name, passwd = string.split(line, ':', 1)
					auth_table[name] = passwd
			finally:
				fp.close()
		except IOError:
			self.log_message("Invalid authenticate file")
			sys.exit(1)

		return auth_table

	def _load_source_dirs(self):
		"""Return the (directory, label) pairs from the [sources] section.

		Options are named "source-dir<N>" and processed in ascending N.
		A value is "path" or "path:label"; the label defaults to the last
		path component.  Invalid entries are logged and skipped; ending up
		with no usable directory is fatal.
		"""
		try:
			options = self.config.options("sources")
		except ConfigParser.NoSectionError:
			options = []

		numbered = []

		for option in options:
			if option[:10] != "source-dir":
				continue

			try:
				order = int(option[10:])
			except ValueError:
				self.log_message("Invalid source option: %s" % option)
				continue

			numbered.append((order, self.config.get("sources", option)))

		if not numbered:
			self.log_message("Source directory not found")
			sys.exit(1)

		numbered.sort()
		dirs = []

		for order, value in numbered:
			fields = [string.strip(field) for field in string.split(value, ':')]
			path = fields[0]

			if len(fields) == 1:
				# normpath drops a trailing slash that would otherwise
				# make basename() return an empty label.
				label = os.path.basename(os.path.normpath(path))
			else:
				label = fields[1]

			if not os.path.isdir(path):
				self.log_message("Invalid source directory")
				continue

			if string.find(label, '/') != -1:
				self.log_message("Invalid source label")
				continue

			dirs.append((path, label))

		if not dirs:
			self.log_message("Source directory not found")
			sys.exit(1)

		return dirs

	def _drop_privileges(self, user, group):
		"""Switch the process to the given user and group names.

		The group is looked up in the group database first.  When no such
		group exists, the primary group of the passwd entry with that
		name is used instead.  Any failure is logged and fatal.
		"""
		import grp

		try:
			user_id = pwd.getpwnam(user)[2]
		except KeyError:
			self.log_message("Invalid user name")
			sys.exit(1)

		try:
			group_id = grp.getgrnam(group)[2]
		except KeyError:
			try:
				group_id = pwd.getpwnam(group)[3]
			except KeyError:
				self.log_message("Invalid group name")
				sys.exit(1)

		try:
			# Clear root's supplementary groups as well; they would
			# otherwise survive the setgid()/setuid() calls below.
			os.setgroups([group_id])
			os.setgid(group_id)
		except OSError:
			error = sys.exc_info()[1]
			self.log_message("Setgid failed: %s" % (error.strerror))
			sys.exit(1)

		try:
			os.setuid(user_id)
		except OSError:
			error = sys.exc_info()[1]
			self.log_message("Setuid failed: %s" % (error.strerror))
			sys.exit(1)

	def server_bind(self):
		"""Bind the listening socket with SO_REUSEADDR and a 3 second timeout."""
		if hasattr(socket, "SOL_SOCKET") and hasattr(socket, "SO_REUSEADDR"):
			self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

		BaseHTTPServer.HTTPServer.server_bind(self)
		self.socket.settimeout(3)

	def acl_ok(self, ipaddr):
		"""Return 1 when dotted address ipaddr passes the ACL, else 0.

		An empty ACL allows every address.
		"""
		if not self.acls:
			return 1

		ipaddr = self.dot2int(ipaddr)

		for allowed, mask in self.acls:
			if (ipaddr & mask) == (allowed & mask):
				return 1

		return 0

	def log_date_time_string(self):
		"""Return the current local time as "YYYY/MM/DD hh:mm"."""
		year, month, day, hh, mm = time.localtime(time.time())[:5]

		return "%04d/%02d/%02d %02d:%02d" % (year, month, day, hh, mm)

	def log_message(self, message):
		"""Append a time-stamped line to the log; write errors are ignored."""
		if self.log:
			try:
				self.log.write("%s [-] %s\n" % (self.log_date_time_string(), message))
				self.log.flush()
			except IOError:
				# Logging is best effort and must never stop the server.
				pass

	def debug_message(self, message):
		"""Log message only when the module-level debug_mode is enabled."""
		if debug_mode > 0:
			self.log_message(message)

	def dot2int(self, dotted_addr):
		"""Convert a dotted-quad string such as "10.0.0.1" to an integer.

		Raises ValueError unless the string has four integer fields.
		"""
		a, b, c, d = [int(field) for field in string.split(dotted_addr, '.')]

		return (a << 24) + (b << 16) + (c << 8) + (d << 0)

class HTTPRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
	"""Answer one GET request with a listing page, a playlist or a file.

	Every response except listing pages and images is counted against the
	server's max_connection limit: _acquire_connection() takes a slot and
	serve_file() always gives it back.
	"""

	def do_GET(self):
		"""Run perform_GET() and log, rather than propagate, any failure."""
		try:
			self.perform_GET()
		except (KeyboardInterrupt, SystemExit):
			raise
		except:
			# Deliberately broad so a broken request cannot take the
			# handler down; the cause is logged.
			kind, value = sys.exc_info()[:2]
			self.server.log_message(
			"Exception happened during processing of request from %s (%s: %s)"
			% (self.client_address[0], kind, value))

	def perform_GET(self):
		"""Check access, resolve the request path and send the response.

		The path is walked one component at a time below the matching
		source directory.  The first component that names a file, or one
		of the virtual names "playall.m3u" / "playrecursively.m3u", ends
		the walk and is served.  A path made only of directories produces
		a listing page.
		"""
		if not self.server.acl_ok(self.client_address[0]):
			self.send_error(403, "Forbidden")
			return

		# check_authorization() has already answered when it fails.
		if self.server.auth_table and not self.check_authorization():
			return

		path = self.translate_path()

		if path is None:
			self.send_error(400, "Illegal URL Construction")
			return

		if not path and len(self.server.dirs) > 1:
			# Top level of a multi-directory server: one entry per label.
			subdirs = []

			for directory, label in self.server.dirs:
				href = '/' + urllib.quote(label)
				subdirs.append(self._subdir_entry(directory, label, href))

			self.display_page('/', subdirs)
			return

		if len(self.server.dirs) == 1:
			url = '/'
			curdir = self.server.dirs[0][0]
		else:
			url = '/' + urllib.quote(path[0])
			curdir = None

			for directory, label in self.server.dirs:
				if path[0] == label:
					curdir = directory
					path.pop(0)
					break

			if curdir is None:
				# Not a label: only the playlists spanning every source
				# directory are valid here.
				if path[0] == "playall.m3u" or path[0] == "playrecursively.m3u":
					if self._acquire_connection():
						self.serve_file(path[0], None, '/')
				else:
					self.send_error(404)

				return

		for part in path:
			if part == "playall.m3u" or part == "playrecursively.m3u":
				if self._acquire_connection():
					self.serve_file(part, curdir, url)

				return

			pathname = os.path.join(curdir, part)
			base, ext = os.path.splitext(part)
			part_ext = string.lower(ext)

			if part_ext in playlist_extensions:
				# "<song>.<music ext>.<playlist ext>" is a generated
				# playlist for the song file of the same name.
				base, ext = os.path.splitext(base)

				if string.lower(ext) in music_extensions:
					pathname = os.path.join(curdir, base + ext)

			if not os.path.exists(pathname):
				self.send_error(404)
				return

			if os.path.isfile(pathname):
				# Must be the same test serve_file() applies, otherwise
				# the connection counter goes out of balance.
				if part_ext in image_extensions:
					self.serve_file(part, pathname, url)
				elif self._acquire_connection():
					self.serve_file(part, pathname, url, self.headers.getheader("range"))

				return

			curdir = pathname
			url = self._join_url(url, part)

		subdirs = []
		songs = []

		names = os.listdir(curdir)
		names.sort()

		for name in names:
			base, ext = os.path.splitext(name)
			fullpath = os.path.join(curdir, name)
			href = self._join_url(url, name)

			if string.lower(ext) in music_extensions:
				try:
					info = fileinfo.FileInfo(fullpath)
				except fileinfo.InvalidFileException:
					self.server.log_message("Unknown file format")
					info = None

				song = {}
				song["filename"] = base
				song["fileinfo"] = info
				song["href"] = href + ".m3u"
				songs.append(song)
			elif os.path.isdir(fullpath):
				subdirs.append(self._subdir_entry(fullpath, name, href))

		self.display_page(url, subdirs, songs)

	def _join_url(self, url, name):
		"""Return url extended by the URL-quoted path component name."""
		if url == '/':
			return '/' + urllib.quote(name)

		return url + '/' + urllib.quote(name)

	def _subdir_entry(self, fullpath, label, href):
		"""Return the template dictionary describing one sub-directory."""
		entry = {}
		entry["directory"] = label
		entry["href"] = href
		entry["files"] = len(self.make_list(fullpath, href, 1))
		entry["playall"] = href + '/' + "playall.m3u"
		entry["playrecursively"] = href + '/' + "playrecursively.m3u"

		return entry

	def _acquire_connection(self):
		"""Take one connection slot; return 1 on success, 0 otherwise.

		When the limit is reached a 503 response is sent, so the caller
		only has to return.
		NOTE(review): the counter is updated without a lock although the
		server handles requests in threads -- confirm this is acceptable.
		"""
		server = self.server
		client = self.client_address[0]

		if server.connection < server.max_connection:
			server.connection = server.connection + 1
			server.log_message("Connection established - %s" % client)
			server.debug_message("Total %d connection(s) established" % (server.connection, ))
			return 1

		server.log_message("Max connection exceeded - %s" % client)
		self.send_error(503, "Max Connection Exceeded")
		return 0

	def _release_connection(self):
		"""Give back a slot taken by _acquire_connection()."""
		server = self.server
		server.connection = server.connection - 1
		server.log_message("Connection closed - %s" % self.client_address[0])
		server.debug_message("Total %d connection(s) established" % (server.connection, ))

	def check_authorization(self):
		"""Validate HTTP basic credentials against the server's auth table.

		Returns 1 when the user is known and the MD5 hex digest of the
		supplied password equals the stored value.  In every other case
		(no header, unusable header, wrong credentials) a 401 challenge
		is sent and 0 is returned; the caller must not respond again.
		"""
		auth_table = self.server.auth_table
		auth = self.headers.getheader("Authorization")
		client = self.client_address[0]

		if auth:
			credentials = self._basic_credentials(auth)

			if credentials is None:
				self.server.log_message("Malformed authorization header - %s" % client)
			else:
				user, passwd = credentials

				if user in auth_table and auth_table[user] == passwd:
					self.server.log_message("Authenticated user: %s - %s" % (user, client))
					return 1

				self.server.log_message("Authorization failed: %s - %s" % (user, client))

		realm = "MOSS"
		self.send_response(401)
		self.send_header("WWW-Authenticate", "basic realm=\"%s\"" % (realm, ))
		self.end_headers()
		return 0

	def _basic_credentials(self, auth):
		"""Return (user, md5 hex digest of password) from a basic auth header.

		Returns None when the header uses another scheme, is not valid
		base64 or lacks the "user:password" separator.
		"""
		import base64
		import binascii

		try:
			from hashlib import md5 as new_md5
		except ImportError:
			# Interpreters that predate hashlib.
			from md5 import new as new_md5

		fields = string.split(auth)

		if len(fields) != 2 or string.lower(fields[0]) != "basic":
			return None

		try:
			decoded = base64.decodestring(fields[1])
		except binascii.Error:
			return None

		if ':' not in decoded:
			return None

		# Only the first ':' separates; the password may contain more.
		user, passwd = string.split(decoded, ':', 1)

		return (user, new_md5(passwd).hexdigest())

	def display_page(self, url, subdirs, songs = None):
		"""Send a 200 text/html response rendered by the server template.

		url     -- URL path of the directory being shown
		subdirs -- list of sub-directory dictionaries
		songs   -- list of song dictionaries; defaults to an empty list
		"""
		if songs is None:
			songs = []

		self.send_response(200)
		self.send_header("Content-Type", "text/html")
		self.end_headers()

		self.server.template.generate(self.wfile, url, subdirs, songs)

	def translate_path(self):
		"""Split the unquoted request path into its components.

		Empty and "." components are dropped and each ".." removes the
		component before it.  Returns None when a ".." would climb above
		the root.
		"""
		parts = string.split(urllib.unquote(self.path), '/')
		parts = [part for part in parts if part and part != '.']

		while ".." in parts:
			idx = parts.index("..")

			if idx == 0:
				return None

			del parts[idx - 1:idx + 1]

		return parts

	def serve_file(self, name, fullpath, url, range = None):
		"""Send the file or playlist called name and release its slot.

		name     -- last URL component; its extension selects the handling
		fullpath -- file to send or directory to list; None for a playlist
		            spanning every source directory
		url      -- URL path of the containing directory
		range    -- value of the request's Range header, if any

		Everything except images is expected to have taken a slot through
		_acquire_connection(); it is released here even when sending
		fails, so a broken transfer cannot use up the limit for good.
		"""
		base, ext = os.path.splitext(name)
		counted = string.lower(ext) not in image_extensions

		try:
			self._send_file(name, fullpath, url, range)
		finally:
			if counted:
				self._release_connection()

	def _send_file(self, name, fullpath, url, byte_range):
		"""Write the response for serve_file(); see there for the arguments."""
		base, ext = os.path.splitext(name)
		ext = string.lower(ext)

		if ext in supported_extensions:
			content_type = supported_extensions[ext]

			try:
				source = open(fullpath, "rb")
			except IOError:
				self.send_error(404)
				return

			file_len = os.fstat(source.fileno())[stat.ST_SIZE]
		elif ext in playlist_extensions:
			content_type = playlist_extensions[ext]
			body = self._playlist_body(name, base, fullpath, url)

			if body is None:
				self.send_error(404)
				return

			source = StringIO.StringIO(body)
			file_len = len(body)
		else:
			self.send_error(404)
			return

		try:
			start = 0
			length = file_len
			span = self._parse_range(byte_range, file_len)

			if span is None:
				self.send_response(200)
			else:
				start, end = span
				length = end - start + 1
				self.send_response(206)
				self.send_header("Content-Range", "bytes %d-%d/%d" % (start, end, file_len))

			self.send_header("Content-Type", content_type)
			self.send_header("Content-Length", length)
			self.send_header("Connection", "close")

			if ext in music_extensions:
				self.send_header("icy-name", base)

			self.end_headers()

			source.seek(start)
			remaining = length

			while remaining > 0:
				data = source.read(min(self.server.read_size, remaining))

				if not data:
					break

				try:
					self.wfile.write(data)
				except (ClientAbortedException, socket.error):
					# The client went away; stop quietly.
					break

				remaining = remaining - len(data)
		finally:
			source.close()

	def _playlist_body(self, name, base, fullpath, url):
		"""Return the text of the playlist called name, or None.

		base is name without its playlist extension.  The two virtual
		names list a whole directory tree; any other name must be
		"<song>.<music ext>.<playlist ext>" and yields the song's URL.
		"""
		if name == "playall.m3u" or name == "playrecursively.m3u":
			recursive = name == "playrecursively.m3u"

			if not fullpath:
				songs = []

				for directory, label in self.server.dirs:
					songs = songs + self.make_list(directory, '/' + urllib.quote(label), recursive)
			else:
				songs = self.make_list(fullpath, url, recursive)

			return string.join(songs, '')

		song_base, song_ext = os.path.splitext(base)

		if string.lower(song_ext) in music_extensions:
			return self.build_url(url, song_base) + song_ext + '\n'

		return None

	def _parse_range(self, byte_range, file_len):
		"""Return (first, last) byte offsets for a Range header, or None.

		Only a single "bytes=first-" or "bytes=first-last" range that
		starts inside the content is honoured.  Anything else returns
		None and the complete content is sent.
		"""
		if not byte_range:
			return None

		match = re.match(r"^\s*bytes\s*=\s*(\d+)\s*-\s*(\d*)\s*$", byte_range, re.I)

		if match is None:
			return None

		start = int(match.group(1))

		if start >= file_len:
			return None

		end = file_len - 1

		if match.group(2):
			end = min(int(match.group(2)), end)

		if end < start:
			return None

		return (start, end)

	def build_url(self, path, file = ""):
		"""Return the absolute http URL of file below the URL path.

		The host comes from the request's Host header, falling back to
		the server name; the server port is appended when the host has
		none.  Only file is URL-quoted; path is used as given.
		"""
		host = self.headers.getheader("host") or self.server.server_name

		if string.find(host, ':') == -1:
			host = "%s:%s" % (host, self.server.server_port)

		if path != '/':
			path = path + '/'

		return "http://%s%s%s" % (host, path, urllib.quote(file))

	def make_list(self, fullpath, url, recursive, songs = None):
		"""Return playlist lines for the music files in a directory.

		fullpath  -- directory to scan, in sorted name order
		url       -- URL path corresponding to fullpath
		recursive -- when true, descend into sub-directories as well
		songs     -- list to extend; a new list is created when None

		Each entry is an absolute URL followed by a newline.
		"""
		if songs is None:
			songs = []

		names = os.listdir(fullpath)
		names.sort()

		for name in names:
			base, ext = os.path.splitext(name)
			child = os.path.join(fullpath, name)

			if string.lower(ext) in music_extensions:
				songs.append(self.build_url(url, name) + '\n')

			if recursive and os.path.isdir(child):
				songs = self.make_list(child,
							url + '/' + urllib.quote(name),
							recursive, songs)

		return songs

	def log_message(self, format, *args):
		"""Forward a BaseHTTPServer log line, plus the client, to the server log."""
		if self.server.log:
			message = "%s %s" % (format % args, self.address_string())
			self.server.log_message(message)

	def setup(self):
		"""Make the connection blocking and wrap wfile in a SocketWriter."""
		self.request.settimeout(None)
		SocketServer.StreamRequestHandler.setup(self)
		self.wfile = SocketWriter(self.wfile)

	def finish(self):
		"""Close both stream files, ignoring socket errors from a dead peer."""
		try:
			self.wfile.close()
		except socket.error:
			pass

		try:
			self.rfile.close()
		except socket.error:
			pass

	def version_string(self):
		"""Return the Server header value without the Python version."""
		return self.server_version

	server_version = __package__ + '/' + __version__
	
	error_message_format = '''<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">
<HTML><HEAD>
<TITLE>MOSS - MP3/OGG Streaming Server</TITLE>
</HEAD><BODY>
<H1>Error response</H1>
<P>Error code %(code)d</P>
<P>Message: %(message)s</P>
<P>Error code explanation: %(explain)s</P>
</BODY></HTML>
'''


class SocketWriter:
	"""Proxy for a writable file object that reports client disconnects.

	write() turns the I/O errors meaning "the peer closed the connection"
	into ClientAbortedException; every other attribute is taken from the
	wrapped object unchanged.
	"""

	def __init__(self, wfile):
		"""Wrap wfile, the file object all calls are forwarded to."""
		self.wfile = wfile
	
	def __getattr__(self, name):
		"""Delegate attributes not found on the proxy to the wrapped object."""
		# Without this guard a proxy whose wfile was never set would
		# recurse endlessly while looking up self.wfile.
		if name == "wfile":
			raise AttributeError(name)

		return getattr(self.wfile, name)

	def write(self, buf):
		"""Write str(buf) and return the wrapped write()'s result.

		Raises ClientAbortedException when the write fails with EPIPE or
		ECONNRESET; any other IOError is re-raised unchanged.
		"""
		# Symbolic errno names: the numeric values differ per platform.
		import errno

		try:
			return self.wfile.write(str(buf))
		except IOError:
			# sys.exc_info() is used so the handler needs no
			# version-specific "except ..., name" syntax.
			error = sys.exc_info()[1]

			if error.errno in (errno.EPIPE, errno.ECONNRESET):
				raise ClientAbortedException

			raise

class ClientAbortedException(Exception):
	"""Raised by SocketWriter.write() when the client has disconnected."""

# Lower-case file extension -> Content-Type for media files that are streamed.
music_extensions = {
	# audio
	".mp3" : "audio/mpeg",
	".ogg" : "application/ogg",
	".mid" : "audio/mid",
	".wmx" : "audio/x-ms-wmx",
	".wma" : "audio/x-ms-wma",
	# sent with a video Content-Type
	".mp2" : "video/mpeg",
	".mpg" : "video/mpeg",
	".asf" : "video/x-ms-asf",
	".wmv" : "video/x-ms-wmv",
	}

# Lower-case file extension -> Content-Type for playlist formats.
playlist_extensions = {
	".m3u" : "audio/x-mpegurl",
	".pls" : "audio/x-scpls",
	".asx" : "video/x-ms-asx",
	}

# Lower-case file extension -> Content-Type for images, which are served
# without being counted as a streaming connection.
image_extensions = {
	".gif" : "image/gif",
	".jpeg" : "image/jpeg",
	".jpg" : "image/jpeg",
	".png" : "image/png",
	}

# Every extension whose file content is sent as-is (music plus images).
supported_extensions = dict(music_extensions)
supported_extensions.update(image_extensions)

# When true the server logs to stdout and debug messages are emitted.
debug_mode = False
