#
# Copyright 2009 Canonical Ltd.
#
# Written by:
#     Gustavo Niemeyer <gustavo.niemeyer@canonical.com>
#     Sidnei da Silva <sidnei.da.silva@canonical.com>
#
# This file is part of the Image Store Proxy.
#
# This program is free software: you can redistribute it and/or modify it 
# under the terms of the GNU General Public License version 3, as published 
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but 
# WITHOUT ANY WARRANTY; without even the implied warranties of 
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR 
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along 
# with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import os
import sys
import signal
import urllib

from optparse import OptionParser
from StringIO import StringIO

import pycurl

from twisted.internet.threads import deferToThread


class FetchError(Exception):
    """Common base class for the errors defined by this module."""


class HTTPCodeError(FetchError):
    """Error carrying an HTTP status code together with a payload.

    @ivar http_code: The numeric HTTP status code given to the constructor.
    @ivar body: The payload given to the constructor alongside the code.
    """

    def __init__(self, http_code, body):
        self.body = body
        self.http_code = http_code

    def __repr__(self):
        return "<HTTPCodeError http_code=%d>" % (self.http_code,)

    def __str__(self):
        return "Server returned HTTP code %d" % (self.http_code,)


class LocalFileError(FetchError):
    """Error describing a problem with a file on the local filesystem.

    @ivar path: The path given to the constructor.
    @ivar message: Human readable description of the problem; this is
                   also what C{str()} returns for the error.
    """

    def __init__(self, path, message):
        self._message = message
        self.path = path

    def _get_message(self):
        return self._message

    def _set_message(self, message):
        self._message = message

    # NOTE(review): presumably "message" is an explicit property so that it
    # does not go through the deprecated BaseException.message attribute of
    # Python 2.6 -- confirm before simplifying.
    message = property(_get_message, _set_message)

    def __repr__(self):
        details = (self.path, self.message)
        return "<LocalFileError args=(%r, '%s')>" % details

    def __str__(self):
        return self.message


class PyCurlError(FetchError):
    """Error carrying a numeric error code and a descriptive message.

    @ivar error_code: The numeric code given to the constructor.
    @ivar message: The textual description given to the constructor.
    """

    def __init__(self, error_code, message):
        self._message = message
        self.error_code = error_code

    def _get_message(self):
        return self._message

    def _set_message(self, message):
        self._message = message

    # NOTE(review): presumably "message" is an explicit property so that it
    # does not go through the deprecated BaseException.message attribute of
    # Python 2.6 -- confirm before simplifying.
    message = property(_get_message, _set_message)

    def __repr__(self):
        details = (self.error_code, self.message)
        return "<PyCurlError args=(%d, '%s')>" % details

    def __str__(self):
        details = (self.error_code, self.message)
        return "Error %d: %s" % details


def fetch(url, post=False, data="", headers={}, cainfo=None, curl=None,
          connect_timeout=30, total_timeout=600, local_path=None,
          resume=False, size=None, progress=None):
    """Retrieve a URL and return the content.

    If a local filename is provided, the data is saved to that file
    with optional resume and the filename is returned. Otherwise, the
    file content is returned as a string.

    @param url: The url to be fetched.
    @param post: If true, the POST method will be used (defaults to GET).
    @param data: Data to be sent to the server as the POST content.
    @param headers: Dictionary of header => value entries to be used
                    on the request.
    @param cainfo: Path to the file with CA certificates.
    @param local_path: Path to the target local file where the
                       download will be saved.

    @param progress: A progress callback factory. It will be
                     instantiated with the partial file size (if
                     resume is true and the file exists) and the
                     expected download size if known. Return any
                     integer to cancel the download.
    """
    output = StringIO(data)

    if curl is None:
        curl = pycurl.Curl()

    if post:
        curl.setopt(pycurl.POST, True)
        curl.setopt(pycurl.POSTFIELDSIZE, len(data))
        curl.setopt(pycurl.READFUNCTION, output.read)

    if cainfo and url.startswith("https:"):
        curl.setopt(pycurl.CAINFO, cainfo)

    if headers:
        curl.setopt(pycurl.HTTPHEADER,
                    ["%s: %s" % pair for pair in sorted(headers.iteritems())])

    # XXX If the URL is unicode, an error occurs.
    curl.setopt(pycurl.URL, str(url))
    curl.setopt(pycurl.FOLLOWLOCATION, True)
    curl.setopt(pycurl.MAXREDIRS, 5)
    curl.setopt(pycurl.CONNECTTIMEOUT, connect_timeout)
    curl.setopt(pycurl.LOW_SPEED_LIMIT, 1)
    curl.setopt(pycurl.LOW_SPEED_TIME, total_timeout)
    curl.setopt(pycurl.NOSIGNAL, 1)

    partial_size = 0

    if local_path is None:
        input = StringIO()
        curl.setopt(pycurl.WRITEFUNCTION, input.write)
    else:
        if resume and os.path.isfile(local_path):
            partial_size = os.path.getsize(local_path)
            if size is not None and partial_size >= size:
                partial_size = 0

        if partial_size:
            open_mode = "a"
            curl.setopt(pycurl.RESUME_FROM_LARGE, long(partial_size))
        else:
            open_mode = "w"
            curl.setopt(pycurl.RESUME_FROM_LARGE, 0L)

        try:
            local = open(local_path, open_mode)
        except (IOError, OSError), e:
            raise LocalFileError(local_path, str(e))

        curl.setopt(pycurl.WRITEDATA, local)

    if progress is not None:
        progress = progress(partial_size, size)
        curl.setopt(pycurl.NOPROGRESS, 0)
        curl.setopt(pycurl.PROGRESSFUNCTION, progress)

    try:
        try:
            curl.perform()
        except pycurl.error, e:
            raise PyCurlError(e.args[0], e.args[1])
    finally:
        if local_path is not None:
            local.close()

    http_code = curl.getinfo(pycurl.HTTP_CODE)

    # 0 for local files (file:///), 206 is for partial content.
    if http_code not in (0, 200, 206):
        if local_path is None:
            msg = input.getvalue()
        else:
            msg = local_path
        raise HTTPCodeError(http_code, msg)

    if local_path is None:
        return input.getvalue()


class Progress(object):
    """Progress callback which prints the download status to stdout.

    Instances are callable with the four progress values
    (download total, downloaded so far, upload total, uploaded so far).
    Creating an instance installs a SIGINT handler which flags the
    download to be aborted on the next progress call.

    NOTE: C{signal.signal()} may only be called from the main thread, so
    instances have to be created there.

    @param partsize: Number of bytes already present locally (0 when
                     not resuming).
    @param size: Expected size in bytes of the whole download, or None
                 if unknown.
    """

    _BYTES_PER_MB = 1024 * 1024

    def __init__(self, partsize, size):
        self._abort = False
        self._partsize = partsize
        self._size = size
        self.trap_signal(signal.SIGINT)

    def __call__(self, downtotal, downcurrent, uptotal, upcurrent):
        """Print the current status and tell whether to abort.

        @return: 1 if an abort was requested through the trapped
                 signal, otherwise None.
        """
        current, total = None, None
        if downtotal:
            # The transfer reports its own total; offset both numbers by
            # what was already on disk.
            current = self._partsize + downcurrent
            total = self._partsize + downtotal
        elif self._size and downcurrent:
            # No total reported by the transfer; fall back to the size
            # provided by the caller.
            current = self._partsize + downcurrent
            total = self._size

        if current is not None and total is not None:
            # Values are byte counts, so convert to real megabytes.
            print("Downloaded %d MB of %d MB" % (
                current // self._BYTES_PER_MB, total // self._BYTES_PER_MB))

        if self._abort:
            # Returning a non-zero integer value cancels the download.
            return 1

    def handle_signal(self, sig, frame):
        """Signal handler: request the download to be aborted."""
        self._abort = True

    def trap_signal(self, sig):
        """Install L{handle_signal} as the handler for C{sig}."""
        signal.signal(sig, self.handle_signal)


def test(args):
    """Command line helper to exercise fetch() by hand.

    Parses C{args} (the command line without the program name), fetches
    the single URL given as positional argument and prints the returned
    content, if any.

    @param args: List of command line arguments.
    """
    parser = OptionParser(usage="%prog [options] URL")
    parser.add_option("--post", action="store_true")
    parser.add_option("--data", default="")
    parser.add_option("--resume", action="store_true")
    parser.add_option("--output-file", default=None)
    parser.add_option("--progress", action="store_true")
    parser.add_option("--cainfo")
    options, positional = parser.parse_args(args)

    # Report a usage error rather than failing with an unpacking traceback.
    if len(positional) != 1:
        parser.error("exactly one URL must be provided")
    url = positional[0]

    if options.progress:
        progress = Progress
    else:
        progress = None

    result = fetch(url, post=options.post, data=options.data,
                   cainfo=options.cainfo, local_path=options.output_file,
                   resume=options.resume, progress=progress)

    # Nothing is returned when the content went to --output-file.
    if result is not None:
        print(result)


def fetch_async(*fetch_args, **fetch_kwargs):
    """Call fetch() through deferToThread() and return its result.

    All positional and keyword arguments are handed over to fetch()
    untouched.
    """
    deferred = deferToThread(fetch, *fetch_args, **fetch_kwargs)
    return deferred


if __name__ == "__main__":
    # Run as a script: hand everything after the program name to test().
    command_line = sys.argv[1:]
    test(command_line)
