#
# Copyright 2009 Canonical Ltd.
#
# Written by:
#     Gustavo Niemeyer <gustavo.niemeyer@canonical.com>
#     Sidnei da Silva <sidnei.da.silva@canonical.com>
#
# This file is part of the Image Store Proxy.
#
# This program is free software: you can redistribute it and/or modify it 
# under the terms of the GNU General Public License version 3, as published 
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but 
# WITHOUT ANY WARRANTY; without even the implied warranties of 
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR 
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along 
# with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import pycurl
import os

from imagestore.lib.tests import TestCase
from imagestore.lib.fetch import (
    fetch, fetch_async, HTTPCodeError, PyCurlError, LocalFileError)


class CurlStub(object):
    """Minimal stand-in for a pycurl.Curl object.

    It records every option set through setopt(), reports a fixed HTTP
    code through getinfo(), and on perform() either raises the configured
    C{error} or delivers C{result} to whichever write target was set.
    """

    def __init__(self, result=None, http_code=200, error=None):
        self.result = result
        self._http_code = http_code
        self.options = {}
        self.performed = False
        self.error = error

    def getinfo(self, what):
        """Return the stubbed HTTP code; any other info is unsupported."""
        if what == pycurl.HTTP_CODE:
            return self._http_code
        raise RuntimeError("Stub doesn't know about %d info" % what)

    def setopt(self, option, value):
        """Record C{value} for C{option}; refused once perform() has run."""
        if self.performed:
            raise AssertionError("setopt() can't be called after perform()")
        self.options[option] = value

    def perform(self):
        """Simulate the transfer.

        If C{error} was given it is raised. Otherwise, when a WRITEFUNCTION
        option is set, the whole result is passed to it in a single call;
        if not, the result is written to the WRITEDATA file object,
        skipping the first RESUME_FROM_LARGE characters. When that option
        was never set, the write starts at offset 0 (a non-resumed
        download).
        """
        if self.error:
            raise self.error
        if self.performed:
            raise AssertionError("Can't perform twice")
        if pycurl.WRITEFUNCTION in self.options:
            self.options[pycurl.WRITEFUNCTION](self.result)
        else:
            resume_index = self.options.get(pycurl.RESUME_FROM_LARGE, 0)
            self.options[pycurl.WRITEDATA].write(self.result[resume_index:])
        self.performed = True


class Any(object):
    """Placeholder that compares equal to every object.

    Used as a value in expected dictionaries when the actual value (for
    instance a callback or file object created internally) can't be
    predicted and only the key's presence should be checked.
    """

    def __eq__(self, other):
        return True

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this method
        # Any() != x would fall back to identity and return True.
        return False

    def __repr__(self):
        # Keeps assertion failure messages readable instead of showing
        # an object address.
        return "<Any>"


class FetchTest(TestCase):
    """Tests for fetch() and fetch_async().

    Most tests inject a CurlStub through the C{curl} argument and check the
    options fetch() sets on it; the C{test_real_*} tests run real pycurl
    against C{file://} URLs pointing at files created with makeFile().
    """

    def getContent(self, filename):
        """Return the whole content of C{filename}."""
        stream = open(filename)
        try:
            return stream.read()
        finally:
            stream.close()

    def getExpectedOptions(self, url, extra):
        """Return the curl options fetch() is expected to set for C{url}.

        The result is the baseline option set shared by these tests,
        updated with C{extra}. C{extra} supplies the write target
        (WRITEFUNCTION or WRITEDATA) plus any test-specific option, and may
        override a baseline value.
        """
        options = {pycurl.URL: url,
                   pycurl.FOLLOWLOCATION: True,
                   pycurl.MAXREDIRS: 5,
                   pycurl.CONNECTTIMEOUT: 30,
                   pycurl.LOW_SPEED_LIMIT: 1,
                   pycurl.LOW_SPEED_TIME: 600,
                   pycurl.NOSIGNAL: 1}
        options.update(extra)
        return options

    def test_basic(self):
        curl = CurlStub("result")
        result = fetch("http://example.com", curl=curl)
        self.assertEquals(result, "result")
        self.assertEquals(curl.options,
                          self.getExpectedOptions(
                              "http://example.com",
                              {pycurl.WRITEFUNCTION: Any()}))

    def test_post(self):
        curl = CurlStub("result")
        result = fetch("http://example.com", post=True, curl=curl)
        self.assertEquals(result, "result")
        self.assertEquals(curl.options,
                          self.getExpectedOptions(
                              "http://example.com",
                              {pycurl.WRITEFUNCTION: Any(),
                               pycurl.POSTFIELDSIZE: 0,
                               pycurl.READFUNCTION: Any(),
                               pycurl.POST: True}))

    def test_post_data(self):
        curl = CurlStub("result")
        result = fetch("http://example.com", post=True, data="data", curl=curl)
        self.assertEquals(result, "result")
        # The read callback must hand the POST body to curl.
        self.assertEquals(curl.options[pycurl.READFUNCTION](), "data")
        self.assertEquals(curl.options,
                          self.getExpectedOptions(
                              "http://example.com",
                              {pycurl.WRITEFUNCTION: Any(),
                               pycurl.POST: True,
                               pycurl.POSTFIELDSIZE: 4,
                               pycurl.READFUNCTION: Any()}))

    def test_cainfo(self):
        curl = CurlStub("result")
        result = fetch("https://example.com", cainfo="cainfo", curl=curl)
        self.assertEquals(result, "result")
        self.assertEquals(curl.options,
                          self.getExpectedOptions(
                              "https://example.com",
                              {pycurl.WRITEFUNCTION: Any(),
                               pycurl.CAINFO: "cainfo"}))

    def test_cainfo_on_http(self):
        curl = CurlStub("result")
        result = fetch("http://example.com", cainfo="cainfo", curl=curl)
        self.assertEquals(result, "result")
        # CA information only applies to https URLs.
        self.assertFalse(pycurl.CAINFO in curl.options)

    def test_headers(self):
        curl = CurlStub("result")
        result = fetch("http://example.com", headers={"a":"1", "b":"2"},
                       curl=curl)
        self.assertEquals(result, "result")
        self.assertEquals(curl.options,
                          self.getExpectedOptions(
                              "http://example.com",
                              {pycurl.WRITEFUNCTION: Any(),
                               pycurl.HTTPHEADER: ["a: 1", "b: 2"]}))

    def test_timeouts(self):
        curl = CurlStub("result")
        result = fetch("http://example.com", connect_timeout=5, total_timeout=30,
                       curl=curl)
        self.assertEquals(result, "result")
        # total_timeout is expected to map onto LOW_SPEED_TIME.
        self.assertEquals(curl.options,
                          self.getExpectedOptions(
                              "http://example.com",
                              {pycurl.CONNECTTIMEOUT: 5,
                               pycurl.LOW_SPEED_TIME: 30,
                               pycurl.WRITEFUNCTION: Any()}))

    def test_non_200_result(self):
        curl = CurlStub("result", http_code=404)
        try:
            fetch("http://example.com", curl=curl)
        except HTTPCodeError, error:
            self.assertEquals(error.http_code, 404)
            self.assertEquals(error.body, "result")
        else:
            self.fail("HTTPCodeError not raised")

    def test_http_error_str(self):
        self.assertEquals(str(HTTPCodeError(501, "")),
                          "Server returned HTTP code 501")

    def test_http_error_repr(self):
        self.assertEquals(repr(HTTPCodeError(501, "")),
                          "<HTTPCodeError http_code=501>")

    def test_pycurl_error(self):
        curl = CurlStub(result=None, http_code=None,
                        error=pycurl.error(60, "pycurl error"))
        try:
            fetch("http://example.com", curl=curl)
        except PyCurlError, error:
            self.assertEquals(error.error_code, 60)
            self.assertEquals(error.message, "pycurl error")
        else:
            self.fail("PyCurlError not raised")

    def test_pycurl_error_str(self):
        self.assertEquals(str(PyCurlError(60, "pycurl error")),
                          "Error 60: pycurl error")

    def test_pycurl_error_repr(self):
        self.assertEquals(repr(PyCurlError(60, "pycurl error")),
                          "<PyCurlError args=(60, 'pycurl error')>")

    def test_create_curl(self):
        curls = []
        def pycurl_Curl():
            curl = CurlStub("result")
            curls.append(curl)
            return curl
        # Temporarily replace pycurl.Curl so that fetch(), called without
        # a curl argument, builds a stub we can inspect afterwards.
        Curl = pycurl.Curl
        try:
            pycurl.Curl = pycurl_Curl
            result = fetch("http://example.com")
            curl = curls[0]
            self.assertEquals(result, "result")
            self.assertEquals(curl.options,
                              self.getExpectedOptions(
                                  "http://example.com",
                                  {pycurl.WRITEFUNCTION: Any()}))
        finally:
            pycurl.Curl = Curl

    def test_async_fetch(self):
        curl = CurlStub("result")
        d = fetch_async("http://example.com/", curl=curl)
        def got_result(result):
            self.assertEquals(result, "result")
        return d.addCallback(got_result)

    def test_async_fetch_with_error(self):
        curl = CurlStub("result", http_code=501)
        d = fetch_async("http://example.com/", curl=curl)
        def got_error(failure):
            self.assertEquals(failure.value.http_code, 501)
            self.assertEquals(failure.value.body, "result")
            # Propagate the failure so assertFailure can check its type.
            return failure
        d.addErrback(got_error)
        self.assertFailure(d, HTTPCodeError)
        return d

    def test_real_fetch_from_local_file(self):
        filename = self.makeFile("content")
        result = fetch("file://"+filename)
        self.assertEquals(result, "content")

    def test_real_fetch_from_local_file_into_file(self):
        remote = self.makeFile("content")
        local = self.makeFile()
        result = fetch("file://"+remote, local_path=local)
        # When writing to local_path nothing is returned in memory.
        self.assertEquals(result, None)
        self.assertTrue(os.path.isfile(local))
        self.assertEquals(self.getContent(local), "content")

    def test_real_fetch_with_local_file_error(self):
        remote = self.makeFile("content")
        local = "/non-existent/filename"
        self.assertRaises(LocalFileError,
                          fetch, "file://"+remote, local_path=local)

    def test_resume_partial_download(self):
        total_content = "result"
        partial_content = "resu"
        curl = CurlStub(total_content, http_code=206)
        local = self.makeFile(partial_content)
        fetch("https://example.com", cainfo="cainfo", curl=curl,
              local_path=local, resume=True)
        # The download resumes right after the bytes already on disk.
        self.assertEquals(curl.options,
                          self.getExpectedOptions(
                              "https://example.com",
                              {pycurl.WRITEDATA: Any(),
                               pycurl.RESUME_FROM_LARGE: len(partial_content),
                               pycurl.CAINFO: "cainfo"}))
        self.assertEquals(curl.options[pycurl.WRITEDATA].name, local)
        self.assertEquals(curl.options[pycurl.WRITEDATA].closed, True)
        self.assertEquals(self.getContent(local), total_content)

    def test_resume_partial_download_with_larger_local(self):
        total_content = "result"
        partial_content = "result which is larger"
        curl = CurlStub(total_content, http_code=206)
        local = self.makeFile(partial_content)
        fetch("https://example.com", cainfo="cainfo", curl=curl,
              local_path=local, resume=True, size=len(total_content))
        # A local file larger than the expected size can't be a valid
        # prefix, so the download must restart from offset 0.
        self.assertEquals(curl.options,
                          self.getExpectedOptions(
                              "https://example.com",
                              {pycurl.WRITEDATA: Any(),
                               pycurl.RESUME_FROM_LARGE: 0,
                               pycurl.CAINFO: "cainfo"}))
        self.assertEquals(curl.options[pycurl.WRITEDATA].name, local)
        self.assertEquals(curl.options[pycurl.WRITEDATA].closed, True)
        self.assertEquals(self.getContent(local), total_content)

    def test_real_fetch_with_progress(self):
        data = "0123456789" * 1024 * 100 # 1,024,000 bytes (~1MB)
        remote = self.makeFile(data)
        local = self.makeFile()
        called = []
        def progress(partial_size, total_size):
            def update(down_total, down_current, up_total, up_current):
                called.append((down_total, down_current, up_total, up_current))
            return update

        fetch("file://"+remote, local_path=local, progress=progress)

        self.assertTrue(len(called) > 2)
        self.assertEquals(called[-1][0], len(data))
        self.assertEquals(called[0][2:], (0, 0))
        # Downloaded byte counts must grow between progress reports.
        self.assertTrue(called[0][1] < called[1][1] < called[2][1])

    def test_real_fetch_with_progress_in_memory(self):
        data = "0123456789" * 1024 * 100 # 1,024,000 bytes (~1MB)
        remote = self.makeFile(data)
        called = []
        def progress(partial_size, total_size):
            def update(down_total, down_current, up_total, up_current):
                called.append((down_total, down_current, up_total, up_current))
            return update

        result = fetch("file://"+remote, progress=progress)

        # Reporting progress must not interfere with the returned content.
        self.assertEquals(result, data)
        self.assertTrue(len(called) > 2)
        self.assertEquals(called[-1][0], len(data))
        self.assertEquals(called[0][2:], (0, 0))
        # Downloaded byte counts must grow between progress reports.
        self.assertTrue(called[0][1] < called[1][1] < called[2][1])
