import argparse
import os
import re
import tarfile
import tempfile
import time
import zipfile
from functools import wraps
from typing import Any, Callable, Dict, List, Optional

import gitlab

TR = Callable[..., Any]


def retry(func: TR) -> TR:
    """
    Decorator that retries the wrapped method when it fails with an exception
    that usually indicates a transient network issue: IOError, EOFError, or a
    GitLab server error (HTTP 500). Other GitLab errors are raised immediately.
    """
    @wraps(func)
    def wrapper(self: 'Gitlab', *args: Any, **kwargs: Any) -> Any:
        retried = 0
        while True:
            try:
                res = func(self, *args, **kwargs)
            except (IOError, EOFError, gitlab.exceptions.GitlabError) as e:
                if isinstance(e, gitlab.exceptions.GitlabError) and e.response_code != 500:
                    # only retry on HTTP 500 (internal server error)
                    raise e
                retried += 1
                if retried > self.DOWNLOAD_ERROR_MAX_RETRIES:
                    raise e  # max retries exceeded, give up
                print('Network failure in {}, retrying ({})'.format(getattr(func, '__name__', '(unknown callable)'), retried))
                time.sleep(2 ** retried)  # exponential backoff: wait longer after each retry
            else:
                break
        return res
    return wrapper


class Gitlab(object):
    JOB_NAME_PATTERN = re.compile(r'(\w+)(\s+(\d+)/(\d+))?')

    DOWNLOAD_ERROR_MAX_RETRIES = 3
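    # Note: with 3 retries, the @retry decorator above sleeps 2 ** retried seconds
    # between attempts (2 s, 4 s, 8 s), so a persistently failing call is attempted
    # 4 times in total before the exception is re-raised.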

    def __init__(self, project_id: Optional[int] = None):
        config_data_from_env = os.getenv('PYTHON_GITLAB_CONFIG')
        if config_data_from_env:
            # prefer to load config from the environment variable
            with tempfile.NamedTemporaryFile('w', delete=False) as temp_file:
                temp_file.write(config_data_from_env)
            config_files = [temp_file.name]  # type: Optional[List[str]]
        else:
            # otherwise fall back to the config file on the local filesystem
            config_files = None
        self._init_gitlab_inst(project_id, config_files)

    @retry
    def _init_gitlab_inst(self, project_id: Optional[int], config_files: Optional[List[str]]) -> None:
        gitlab_id = os.getenv('LOCAL_GITLAB_HTTPS_HOST')  # if None, the default gitlab server from the config is used
        self.gitlab_inst = gitlab.Gitlab.from_config(gitlab_id=gitlab_id, config_files=config_files)
        self.gitlab_inst.auth()
        if project_id:
            self.project = self.gitlab_inst.projects.get(project_id)
        else:
            self.project = None

    @retry
    def get_project_id(self, name: str, namespace: Optional[str] = None) -> int:
        """
        Search for a project ID by name.

        :param name: project name
        :param namespace: namespace to match when there are multiple projects with the same name
        :return: project ID
        """
        projects = self.gitlab_inst.projects.list(search=name)
        res = []
        for project in projects:
            if namespace is None:
                if len(projects) == 1:
                    res.append(project.id)
                    break
                continue  # cannot disambiguate multiple results without a namespace

            if project.namespace['path'] == namespace:
                if project.name == name:
                    res.insert(0, project.id)
                else:
                    res.append(project.id)

        if not res:
            raise ValueError("Can't find project {}".format(name))
        return int(res[0])
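    # Example (sketch; 'my-project' and 'my-group' are placeholder names):
    #   project_id = Gitlab().get_project_id('my-project', namespace='my-group')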

    @retry
    def download_artifacts(self, job_id: int, destination: str) -> None:
        """
        Download the full artifact archive of a job and extract it to the destination path.

        :param job_id: Gitlab CI job ID
        :param destination: path to extract the artifacts to
        """
        job = self.project.jobs.get(job_id)

        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            job.artifacts(streamed=True, action=temp_file.write)

        with zipfile.ZipFile(temp_file.name, 'r') as archive_file:
            archive_file.extractall(destination)

    @retry
    def download_artifact(self, job_id: int, artifact_path: List[str], destination: Optional[str] = None) -> List[bytes]:
        """
        Download specific files from a job's artifacts and optionally save them to destination.

        :param job_id: Gitlab CI job ID
        :param artifact_path: list of paths inside the artifacts (relative to the artifact root)
        :param destination: destination directory. Do not save to file if destination is None
        :return: a list with the raw data of each artifact file
        """
        job = self.project.jobs.get(job_id)

        raw_data_list = []

        for a_path in artifact_path:
            try:
                data = job.artifact(a_path)  # type: bytes
            except gitlab.GitlabGetError as e:
                print("Failed to download '{}' from job {}".format(a_path, job_id))
                raise e
            raw_data_list.append(data)
            if destination:
                file_path = os.path.join(destination, a_path)
                try:
                    os.makedirs(os.path.dirname(file_path))
                except OSError:
                    # directory already exists
                    pass
                with open(file_path, 'wb') as f:
                    f.write(data)

        return raw_data_list
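    # Example (sketch; the job ID and file names are placeholders). Note that
    # artifact_path is a list, so several files can be fetched in one call:
    #   logs = gitlab_inst.download_artifact(1234, ['build.log', 'size_info.txt'], destination='artifacts/')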

    @retry
    def find_job_id(self, job_name: str, pipeline_id: Optional[str] = None, job_status: str = 'success') -> List[Dict]:
        """
        Get the job IDs that match a job name in a specific pipeline.

        :param job_name: job name
        :param pipeline_id: if None, the pipeline ID is taken from the CI predefined variable CI_PIPELINE_ID
        :param job_status: status of the job. One pipeline could have multiple jobs with the same name after
                           retries; job_status is used to filter these jobs.
        :return: a list of dicts with job ID and parallel number (a parallel job generates multiple jobs)
        """
        job_id_list = []
        if pipeline_id is None:
            pipeline_id = os.getenv('CI_PIPELINE_ID')
        pipeline = self.project.pipelines.get(pipeline_id)
        jobs = pipeline.jobs.list(all=True)
        for job in jobs:
            match = self.JOB_NAME_PATTERN.match(job.name)
            if match and match.group(1) == job_name and job.status == job_status:
                job_id_list.append({'id': job.id, 'parallel_num': match.group(3)})
        return job_id_list
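    # Illustration of JOB_NAME_PATTERN: for a parallel job named 'example_test 2/3'
    # the pattern yields group(1) == 'example_test' and group(3) == '2', so
    # find_job_id('example_test') returns entries like {'id': <job id>, 'parallel_num': '2'};
    # for a non-parallel job, parallel_num is None.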

    @retry
    def download_archive(self, ref: str, destination: str, project_id: Optional[int] = None) -> str:
        """
        Download the archive of a certain commit of a repository and extract it to the destination path.

        :param ref: commit SHA or branch name
        :param destination: destination path for the extracted archive
        :param project_id: download from the project of the current instance if project_id is None
        :return: root directory name of the archive
        """
        if project_id is None:
            project = self.project
        else:
            project = self.gitlab_inst.projects.get(project_id)

        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            try:
                project.repository_archive(sha=ref, streamed=True, action=temp_file.write)
            except gitlab.GitlabGetError as e:
                print('Failed to download archive from project {}'.format(project_id))
                raise e

        print('archive size: {:.03f}MB'.format(float(os.path.getsize(temp_file.name)) / (1024 * 1024)))

        with tarfile.open(temp_file.name, 'r') as archive_file:
            root_name = archive_file.getnames()[0]
            archive_file.extractall(destination)

        return os.path.join(os.path.realpath(destination), root_name)
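
# Example (sketch): typical programmatic use of the helper above, assuming a valid
# python-gitlab configuration (via PYTHON_GITLAB_CONFIG or a local config file);
# the project and pipeline IDs below are placeholders:
#
#   gitlab_inst = Gitlab(project_id=1234)
#   for job in gitlab_inst.find_job_id('build_examples', pipeline_id='5678'):
#       gitlab_inst.download_artifacts(job['id'], 'artifacts/')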


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument('action')
    parser.add_argument('project_id', type=int)
    parser.add_argument('--pipeline_id', '-i', type=int, default=None)
    parser.add_argument('--ref', '-r', default='master')
    parser.add_argument('--job_id', '-j', type=int, default=None)
    parser.add_argument('--job_name', '-n', default=None)
    parser.add_argument('--project_name', '-m', default=None)
    parser.add_argument('--destination', '-d', default=None)
    parser.add_argument('--artifact_path', '-a', nargs='*', default=None)
    args = parser.parse_args()
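    # Example invocations (sketch; the script name and IDs are placeholders):
    #   python gitlab_api.py find_job_id 1234 --job_name build_examples --pipeline_id 5678
    #   python gitlab_api.py download_artifacts 1234 --job_id 910 --destination out/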

    gitlab_inst = Gitlab(args.project_id)
    if args.action == 'download_artifacts':
        gitlab_inst.download_artifacts(args.job_id, args.destination)
    elif args.action == 'download_artifact':
        gitlab_inst.download_artifact(args.job_id, args.artifact_path, args.destination)
    elif args.action == 'find_job_id':
        job_ids = gitlab_inst.find_job_id(args.job_name, args.pipeline_id)
        # parallel_num is None for non-parallel jobs; print an empty string in that case
        print(';'.join([','.join([str(j['id']), j['parallel_num'] or '']) for j in job_ids]))
    elif args.action == 'download_archive':
        gitlab_inst.download_archive(args.ref, args.destination)
    elif args.action == 'get_project_id':
        ret = gitlab_inst.get_project_id(args.project_name)
        print('project id: {}'.format(ret))


if __name__ == '__main__':
    main()