1#!/usr/bin/env python3
2#
3# Plot CSV files with matplotlib.
4#
5# Example:
6# ./scripts/plotmpl.py bench.csv -xSIZE -ybench_read -obench.svg
7#
8# Copyright (c) 2022, The littlefs authors.
9# SPDX-License-Identifier: BSD-3-Clause
10#
11
12import codecs
13import collections as co
14import csv
15import io
16import itertools as it
17import logging
18import math as m
19import numpy as np
20import os
21import shlex
22import shutil
23import time
24
25import matplotlib as mpl
26import matplotlib.pyplot as plt
27
# some nicer colors borrowed from Seaborn; every entry gets the same 'bf'
# alpha suffix (0xbf/0xff is roughly 75% opacity), so lines are drawn
# slightly translucent
COLORS = [c + 'bf' for c in [
    '#4c72b0', # blue
    '#dd8452', # orange
    '#55a868', # green
    '#c44e52', # red
    '#8172b3', # purple
    '#937860', # brown
    '#da8bc3', # pink
    '#8c8c8c', # gray
    '#ccb974', # yellow
    '#64b5cd', # cyan
]]
# palette used instead of COLORS when dark mode is enabled
COLORS_DARK = [c + 'bf' for c in [
    '#a1c9f4', # blue
    '#ffb482', # orange
    '#8de5a1', # green
    '#ff9f9b', # red
    '#d0bbff', # purple
    '#debb9b', # brown
    '#fab0e4', # pink
    '#cfcfcf', # gray
    '#fffea3', # yellow
    '#b9f2f0', # cyan
]]
ALPHAS = [0.75]

# matplotlib format strings: plain lines, points only, or points + lines
FORMATS = ['-']
FORMATS_POINTS = ['.']
FORMATS_POINTS_AND_LINES = ['.-']

# default figure size in pixels (divided by the figure dpi when plotting)
WIDTH, HEIGHT = 750, 350
FONT_SIZE = 11
62
# decimal SI prefixes, keyed by their power-of-10 exponent
SI_PREFIXES = {
    18:  'E',
    15:  'P',
    12:  'T',
    9:   'G',
    6:   'M',
    3:   'K',
    0:   '',
    -3:  'm',
    -6:  'u',
    -9:  'n',
    -12: 'p',
    -15: 'f',
    -18: 'a',
}

# binary prefixes, keyed by their power-of-2 exponent
SI2_PREFIXES = {
    60:  'Ei',
    50:  'Pi',
    40:  'Ti',
    30:  'Gi',
    20:  'Mi',
    10:  'Ki',
    0:   '',
    -10: 'mi',
    -20: 'ui',
    -30: 'ni',
    -40: 'pi',
    -50: 'fi',
    -60: 'ai',
}


def _si_format(x, p, base, prefixes):
    # render abs(x)/base**p with '%.3f', keep only the first 4 characters,
    # strip trailing fractional zeros, then re-attach the sign and the
    # prefix for exponent p
    s = '%.3f' % (abs(x) / (float(base)**p))
    s = s[:3+1]
    # truncate but only digits that follow the dot
    if '.' in s:
        s = s.rstrip('0')
        s = s.rstrip('.')
    return '%s%s%s' % ('-' if x < 0 else '', s, prefixes[p])

# formatter for matplotlib
def si(x):
    """Format x with a decimal SI prefix, e.g. 1500 -> '1.5K'.

    The exponent is truncated toward zero, so e.g. 0.5 renders as '0.5'
    rather than '500m'. Exponents are clamped to [-18, 18].
    """
    if x == 0:
        return '0'
    # figure out prefix and scale; m.log10 is used instead of
    # m.log(x, 10**3) because the latter divides two rounded logarithms
    # and can land just below an integer at exact powers (compare
    # m.log(1000, 10) == 2.9999999999999996), which int() would then
    # truncate into the next-smaller prefix
    p = 3*int(m.log10(abs(x)) / 3)
    p = min(18, max(-18, p))
    return _si_format(x, p, 10, SI_PREFIXES)

# formatter for matplotlib
def si2(x):
    """Format x with a binary prefix, e.g. 1536 -> '1.5Ki'.

    The exponent is truncated toward zero and clamped to [-30, 30], so
    values of 2**40 and above render as multiples of 'Gi'.
    """
    if x == 0:
        return '0'
    # figure out prefix and scale; m.log2 avoids the rounding error of
    # m.log(x, 2**10) at exact powers of two (see si)
    p = 10*int(m.log2(abs(x)) / 10)
    p = min(30, max(-30, p))
    return _si_format(x, p, 2, SI2_PREFIXES)
127
# parse escape strings, e.g. a literal backslash-n becomes a newline
def escape(s):
    """Interpret backslash escapes in s via codecs.escape_decode.

    The text is round-tripped through UTF-8, so non-escape characters
    (including non-ASCII ones) pass through unchanged.
    """
    decoded, _ = codecs.escape_decode(bytes(s, 'utf8'))
    return str(decoded, 'utf8')
131
# we want to use MaxNLocator, but since MaxNLocator forces multiples of 10
# to be an option, we can't really...
class AutoMultipleLocator(mpl.ticker.MultipleLocator):
    """Tick locator that places ticks at multiples of a power of base.

    On every call the power is chosen from the current view so that the
    view span is split into at most nbins+1 intervals. If nbins is None,
    it comes from the axis' tick space, clipped to [1, 9].
    """
    def __init__(self, base, nbins=None):
        # keep base as a float to avoid integer pow issues
        self.base = float(base)
        self.nbins = nbins
        super().__init__(self.base)

    def __call__(self):
        lo, hi = self.axis.get_view_interval()
        lo, hi = mpl.transforms.nonsingular(lo, hi, 1e-12, 1e-13)

        # an explicit bin count wins, otherwise ask the axis for room
        nbins = self.nbins
        if nbins is None:
            nbins = np.clip(self.axis.get_tick_space(), 1, 9)

        # smallest power of base whose step splits the span into at most
        # nbins+1 intervals, this becomes the locator's actual base
        exponent = m.ceil(m.log((hi-lo) / (nbins+1), self.base))
        self.set_params(self.base ** exponent)

        return super().__call__()
155
156
def openio(path, mode='r', buffering=-1):
    """Open path like open(), but treat '-' as stdin/stdout.

    For '-', any mode containing 'r' reads from stdin and every other mode
    writes to stdout. The stream's file descriptor is duplicated with
    os.dup, so closing the returned file leaves sys.stdin/sys.stdout open.
    """
    # imported locally so this helper doesn't depend on sys having been
    # imported elsewhere at module scope
    import sys

    # allow '-' for stdin/stdout
    if path == '-':
        # check for 'r' rather than mode == 'r' so 'rb'/'rt' also map to
        # stdin instead of silently falling through to stdout
        if 'r' in mode:
            return os.fdopen(os.dup(sys.stdin.fileno()), mode, buffering)
        else:
            return os.fdopen(os.dup(sys.stdout.fileno()), mode, buffering)
    else:
        return open(path, mode, buffering)
166
167
# parse different data representations
def dat(x):
    """Parse a CSV cell string as a number.

    Tries int first (any Python literal base, e.g. '10', '0x1f'), then a
    finite float. For an 'a/b' fraction only the numerator 'a' is parsed.

    Raises ValueError if x is not a finite number.
    """
    # allow the first part of an a/b fraction
    if '/' in x:
        x, _ = x.split('/', 1)

    # first try as int
    try:
        return int(x, 0)
    except ValueError:
        pass

    # then try as float
    try:
        v = float(x)
    except ValueError:
        pass
    else:
        # just don't allow infinity or nan, checked on the parsed value
        # before returning it
        if not m.isfinite(v):
            raise ValueError("invalid dat %r" % x)
        return v

    # else give up
    raise ValueError("invalid dat %r" % x)
191
def collect(csv_paths, renames=[]):
    """Read every CSV file in csv_paths into one flat list of row dicts.

    Paths that don't exist are skipped. renames is a list of
    (new_field, old_field) pairs; each row gains new_field holding the
    value of old_field whenever old_field is present (old_field is kept).
    """
    rows = []
    for path in csv_paths:
        try:
            with openio(path) as f:
                # rows shorter than the header are padded with ''
                rows.extend(csv.DictReader(f, restval=''))
        except FileNotFoundError:
            # a missing file simply contributes no rows
            continue

    if not renames:
        return rows

    for row in rows:
        # compute every renamed value from the untouched row before
        # applying any of them, so renames can overlap
        row.update({new_k: row[old_k]
            for new_k, old_k in renames
            if old_k in row})

    return rows
214
def dataset(results, x=None, y=None, define=[]):
    """Reduce result rows into a single {x: y} mapping.

    A row is kept only if every (field, allowed_values) pair in define
    matches. x and y name the fields to parse with dat(); rows where a
    named field is missing or unparsable are skipped. If x is None, kept
    rows are numbered 0, 1, 2, ... instead. y values sharing the same x
    are summed. If y is None, every kept x maps to None.
    """
    points = {}
    counter = 0
    for row in results:
        if any(k not in row or row[k] not in vs for k, vs in define):
            continue

        # resolve the x coordinate
        if x is None:
            xv = counter
            counter += 1
        elif x not in row:
            continue
        else:
            try:
                xv = dat(row[x])
            except ValueError:
                continue

        # resolve the y coordinate, accumulating duplicate x's
        if y is None:
            points[xv] = points.get(xv, None)
        elif y in row:
            try:
                yv = dat(row[y])
            except ValueError:
                continue
            points[xv] = yv + points.get(xv, 0)

    return points
253
def datasets(results, by=None, x=None, y=None, define=[]):
    """Split results into named {x: y} datasets.

    One dataset is built per sorted combination of 'by' values, per x
    field, per y field. Keys are the tuple of 'by' values followed by
    (x_label, y_label); a label is left '' when there is only one field
    to tell apart. If y is None it is guessed as the non-'by' fields whose
    last non-blank value parses with dat().
    """
    # keep only the results that match every define
    results = [r for r in results
        if all(k in r and r[k] in vs for k, vs in define)]

    # no y fields given? guess from data
    if y is None:
        parses = co.OrderedDict()
        for r in results:
            for k, v in r.items():
                if by is not None and k in by:
                    continue
                if not v.strip():
                    continue
                try:
                    dat(v)
                except ValueError:
                    parses[k] = False
                else:
                    parses[k] = True
        y = [k for k, ok in parses.items() if ok]

    # each distinct tuple of 'by' values gets its own group of datasets
    if by is not None:
        groups = sorted({tuple(r.get(k, '') for k in by) for r in results})
    else:
        groups = [()]

    out = co.OrderedDict()
    for group in groups:
        # restrict this group's datasets to rows with exactly these values
        group_define = ([(k, {v}) for k, v in zip(by, group)]
            if by is not None else [])
        for x_ in (x if x is not None else [None]):
            for y_ in y:
                # hide x/y if there is only one field
                k_x = x_ if len(x or []) > 1 else ''
                k_y = (y_ if len(y or []) > 1 or (not group and not k_x)
                    else '')
                out[group + (k_x, k_y)] = dataset(
                    results, x_, y_, group_define)

    return out
299
300
# some classes for organizing subplots into a grid
class Subplot:
    """One plot placed in a Grid.

    (x, y) is the grid cell the subplot starts at, xspan/yspan how many
    columns/rows it covers, and args the per-subplot options it was
    created with.
    """
    def __init__(self, **args):
        self.x, self.y = 0, 0
        self.xspan, self.yspan = 1, 1
        self.args = args
309
class Grid:
    """A grid of Subplots laid out with relative column/row weights.

    xweights/yweights are the relative widths/heights of each column/row,
    map maps every (x, y) cell to the subplot occupying it (a subplot
    spanning several cells appears under each of them), and subplots
    lists each subplot once.
    """
    def __init__(self, subplot, width=1.0, height=1.0):
        self.xweights = [width]
        self.yweights = [height]
        self.map = {(0,0): subplot}
        self.subplots = [subplot]

    def __repr__(self):
        return 'Grid(%r, %r)' % (self.xweights, self.yweights)

    @property
    def width(self):
        # number of columns
        return len(self.xweights)

    @property
    def height(self):
        # number of rows
        return len(self.yweights)

    def __iter__(self):
        return iter(self.subplots)

    def __getitem__(self, i):
        # look up the subplot at cell (x, y); negative coordinates count
        # back from the last column/row
        x, y = i
        if x < 0:
            x += len(self.xweights)
        if y < 0:
            y += len(self.yweights)

        return self.map[(x,y)]

    @staticmethod
    def _split_col(subplots, i):
        # column i is being split into columns i and i+1: every subplot
        # covering column i (anywhere in its span, not only at its last
        # column) grows to cover both halves, and subplots entirely to
        # the right shift over by one
        for s in subplots:
            if s.x <= i < s.x+s.xspan:
                s.xspan += 1
            elif s.x > i:
                s.x += 1

    @staticmethod
    def _split_row(subplots, i):
        # row i is being split into rows i and i+1, same rules as
        # _split_col
        for s in subplots:
            if s.y <= i < s.y+s.yspan:
                s.yspan += 1
            elif s.y > i:
                s.y += 1

    def merge(self, other, dir):
        """Merge other into self, placing other on the dir side of self.

        dir is one of 'above', 'below', 'right' or 'left'; any other value
        leaves self unchanged. other is rescaled so its extent along the
        shared edge matches self's, and columns (above/below) or rows
        (right/left) are subdivided wherever the two grids' boundaries
        don't line up; subplots covering a subdivided column/row grow to
        span both parts. self's weights across the shared edge are
        squished so self plus other keep self's original total. other's
        subplots are moved into self and their positions updated in place.
        """
        if dir in ['above', 'below']:
            # first scale the two grids so they line up
            self_xweights = self.xweights
            other_xweights = other.xweights
            self_w = sum(self_xweights)
            other_w = sum(other_xweights)
            ratio = self_w / other_w
            other_xweights = [s*ratio for s in other_xweights]

            # now interleave xweights as needed, walking both grids'
            # columns left to right and splitting whichever is wider
            new_xweights = []
            self_map = {}
            other_map = {}
            self_i = 0
            other_i = 0
            self_xweight = (self_xweights[self_i]
                if self_i < len(self_xweights) else m.inf)
            other_xweight = (other_xweights[other_i]
                if other_i < len(other_xweights) else m.inf)
            while self_i < len(self_xweights) and other_i < len(other_xweights):
                if other_xweight - self_xweight > 0.0000001:
                    # other's column is wider: emit self's column and
                    # split other's, carrying the remainder forward
                    new_xweights.append(self_xweight)
                    other_xweight -= self_xweight

                    new_i = len(new_xweights)-1
                    for j in range(len(self.yweights)):
                        self_map[(new_i, j)] = self.map[(self_i, j)]
                    for j in range(len(other.yweights)):
                        other_map[(new_i, j)] = other.map[(other_i, j)]
                    self._split_col(other.subplots, new_i)

                    self_i += 1
                    self_xweight = (self_xweights[self_i]
                        if self_i < len(self_xweights) else m.inf)
                elif self_xweight - other_xweight > 0.0000001:
                    # self's column is wider: emit other's column and
                    # split self's, carrying the remainder forward
                    new_xweights.append(other_xweight)
                    self_xweight -= other_xweight

                    new_i = len(new_xweights)-1
                    for j in range(len(other.yweights)):
                        other_map[(new_i, j)] = other.map[(other_i, j)]
                    for j in range(len(self.yweights)):
                        self_map[(new_i, j)] = self.map[(self_i, j)]
                    self._split_col(self.subplots, new_i)

                    other_i += 1
                    other_xweight = (other_xweights[other_i]
                        if other_i < len(other_xweights) else m.inf)
                else:
                    # columns line up (within tolerance): emit one column
                    # and advance both grids
                    new_xweights.append(self_xweight)

                    new_i = len(new_xweights)-1
                    for j in range(len(self.yweights)):
                        self_map[(new_i, j)] = self.map[(self_i, j)]
                    for j in range(len(other.yweights)):
                        other_map[(new_i, j)] = other.map[(other_i, j)]

                    self_i += 1
                    self_xweight = (self_xweights[self_i]
                        if self_i < len(self_xweights) else m.inf)
                    other_i += 1
                    other_xweight = (other_xweights[other_i]
                        if other_i < len(other_xweights) else m.inf)

            # squish so ratios are preserved, self's rows shrink so self
            # plus other keep self's original total height
            self_h = sum(self.yweights)
            other_h = sum(other.yweights)
            ratio = (self_h-other_h) / self_h
            self_yweights = [s*ratio for s in self.yweights]

            # finally concatenate the two grids
            if dir == 'above':
                for s in other.subplots:
                    s.y += len(self_yweights)
                self.subplots.extend(other.subplots)

                self.xweights = new_xweights
                self.yweights = self_yweights + other.yweights
                self.map = self_map | {(x, y+len(self_yweights)): s
                    for (x, y), s in other_map.items()}
            else:
                for s in self.subplots:
                    s.y += len(other.yweights)
                self.subplots.extend(other.subplots)

                self.xweights = new_xweights
                self.yweights = other.yweights + self_yweights
                self.map = other_map | {(x, y+len(other.yweights)): s
                    for (x, y), s in self_map.items()}

        if dir in ['right', 'left']:
            # first scale the two grids so they line up
            self_yweights = self.yweights
            other_yweights = other.yweights
            self_h = sum(self_yweights)
            other_h = sum(other_yweights)
            ratio = self_h / other_h
            other_yweights = [s*ratio for s in other_yweights]

            # now interleave yweights as needed, walking both grids'
            # rows in order and splitting whichever is taller
            new_yweights = []
            self_map = {}
            other_map = {}
            self_i = 0
            other_i = 0
            self_yweight = (self_yweights[self_i]
                if self_i < len(self_yweights) else m.inf)
            other_yweight = (other_yweights[other_i]
                if other_i < len(other_yweights) else m.inf)
            while self_i < len(self_yweights) and other_i < len(other_yweights):
                if other_yweight - self_yweight > 0.0000001:
                    # other's row is taller: emit self's row and split
                    # other's, carrying the remainder forward
                    new_yweights.append(self_yweight)
                    other_yweight -= self_yweight

                    new_i = len(new_yweights)-1
                    for j in range(len(self.xweights)):
                        self_map[(j, new_i)] = self.map[(j, self_i)]
                    for j in range(len(other.xweights)):
                        other_map[(j, new_i)] = other.map[(j, other_i)]
                    self._split_row(other.subplots, new_i)

                    self_i += 1
                    self_yweight = (self_yweights[self_i]
                        if self_i < len(self_yweights) else m.inf)
                elif self_yweight - other_yweight > 0.0000001:
                    # self's row is taller: emit other's row and split
                    # self's, carrying the remainder forward
                    new_yweights.append(other_yweight)
                    self_yweight -= other_yweight

                    new_i = len(new_yweights)-1
                    for j in range(len(other.xweights)):
                        other_map[(j, new_i)] = other.map[(j, other_i)]
                    for j in range(len(self.xweights)):
                        self_map[(j, new_i)] = self.map[(j, self_i)]
                    self._split_row(self.subplots, new_i)

                    other_i += 1
                    other_yweight = (other_yweights[other_i]
                        if other_i < len(other_yweights) else m.inf)
                else:
                    # rows line up (within tolerance): emit one row and
                    # advance both grids
                    new_yweights.append(self_yweight)

                    new_i = len(new_yweights)-1
                    for j in range(len(self.xweights)):
                        self_map[(j, new_i)] = self.map[(j, self_i)]
                    for j in range(len(other.xweights)):
                        other_map[(j, new_i)] = other.map[(j, other_i)]

                    self_i += 1
                    self_yweight = (self_yweights[self_i]
                        if self_i < len(self_yweights) else m.inf)
                    other_i += 1
                    other_yweight = (other_yweights[other_i]
                        if other_i < len(other_yweights) else m.inf)

            # squish so ratios are preserved, self's columns shrink so
            # self plus other keep self's original total width
            self_w = sum(self.xweights)
            other_w = sum(other.xweights)
            ratio = (self_w-other_w) / self_w
            self_xweights = [s*ratio for s in self.xweights]

            # finally concatenate the two grids
            if dir == 'right':
                for s in other.subplots:
                    s.x += len(self_xweights)
                self.subplots.extend(other.subplots)

                self.xweights = self_xweights + other.xweights
                self.yweights = new_yweights
                self.map = self_map | {(x+len(self_xweights), y): s
                    for (x, y), s in other_map.items()}
            else:
                for s in self.subplots:
                    s.x += len(other.xweights)
                self.subplots.extend(other.subplots)

                self.xweights = other.xweights + self_xweights
                self.yweights = new_yweights
                self.map = other_map | {(x+len(other.xweights), y): s
                    for (x, y), s in self_map.items()}


    def scale(self, width, height):
        # multiply every column weight by width and row weight by height
        self.xweights = [s*width for s in self.xweights]
        self.yweights = [s*height for s in self.yweights]

    @classmethod
    def fromargs(cls, width=1.0, height=1.0, *,
            subplots=[],
            **args):
        """Build a Grid from a (possibly nested) subplot description.

        args become the main Subplot's args. Each (dir, subargs) pair in
        subplots is built recursively and merged on the dir side; its
        width defaults to 0.5 for 'right'/'left' (else the parent's width)
        and its height to 0.5 for 'above'/'below' (else the parent's
        height). Note 'width'/'height' are popped from the subargs dicts
        in place. The finished grid is scaled by width/height.
        """
        grid = cls(Subplot(**args))

        for dir, subargs in subplots:
            subgrid = cls.fromargs(
                width=subargs.pop('width',
                    0.5 if dir in ['right', 'left'] else width),
                height=subargs.pop('height',
                    0.5 if dir in ['above', 'below'] else height),
                **subargs)
            grid.merge(subgrid, dir)

        grid.scale(width, height)
        return grid
559
560
561def main(csv_paths, output, *,
562        svg=False,
563        png=False,
564        quiet=False,
565        by=None,
566        x=None,
567        y=None,
568        define=[],
569        points=False,
570        points_and_lines=False,
571        colors=None,
572        formats=None,
573        width=WIDTH,
574        height=HEIGHT,
575        xlim=(None,None),
576        ylim=(None,None),
577        xlog=False,
578        ylog=False,
579        x2=False,
580        y2=False,
581        xticks=None,
582        yticks=None,
583        xunits=None,
584        yunits=None,
585        xlabel=None,
586        ylabel=None,
587        xticklabels=None,
588        yticklabels=None,
589        title=None,
590        legend_right=False,
591        legend_above=False,
592        legend_below=False,
593        dark=False,
594        ggplot=False,
595        xkcd=False,
596        github=False,
597        font=None,
598        font_size=FONT_SIZE,
599        font_color=None,
600        foreground=None,
601        background=None,
602        subplot={},
603        subplots=[],
604        **args):
605    # guess the output format
606    if not png and not svg:
607        if output.endswith('.png'):
608            png = True
609        else:
610            svg = True
611
612    # some shortcuts for color schemes
613    if github:
614        ggplot = True
615        if font_color is None:
616            if dark:
617                font_color = '#c9d1d9'
618            else:
619                font_color = '#24292f'
620        if foreground is None:
621            if dark:
622                foreground = '#343942'
623            else:
624                foreground = '#eff1f3'
625        if background is None:
626            if dark:
627                background = '#0d1117'
628            else:
629                background = '#ffffff'
630
631    # what colors/alphas/formats to use?
632    if colors is not None:
633        colors_ = colors
634    elif dark:
635        colors_ = COLORS_DARK
636    else:
637        colors_ = COLORS
638
639    if formats is not None:
640        formats_ = formats
641    elif points_and_lines:
642        formats_ = FORMATS_POINTS_AND_LINES
643    elif points:
644        formats_ = FORMATS_POINTS
645    else:
646        formats_ = FORMATS
647
648    if font_color is not None:
649        font_color_ = font_color
650    elif dark:
651        font_color_ = '#ffffff'
652    else:
653        font_color_ = '#000000'
654
655    if foreground is not None:
656        foreground_ = foreground
657    elif dark:
658        foreground_ = '#333333'
659    else:
660        foreground_ = '#e5e5e5'
661
662    if background is not None:
663        background_ = background
664    elif dark:
665        background_ = '#000000'
666    else:
667        background_ = '#ffffff'
668
669    # configure some matplotlib settings
670    if xkcd:
671        # the font search here prints a bunch of unhelpful warnings
672        logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)
673        plt.xkcd()
674        # turn off the white outline, this breaks some things
675        plt.rc('path', effects=[])
676    if ggplot:
677        plt.style.use('ggplot')
678        plt.rc('patch', linewidth=0)
679        plt.rc('axes', facecolor=foreground_, edgecolor=background_)
680        plt.rc('grid', color=background_)
681        # fix the the gridlines when ggplot+xkcd
682        if xkcd:
683            plt.rc('grid', linewidth=1)
684            plt.rc('axes.spines', bottom=False, left=False)
685    if dark:
686        plt.style.use('dark_background')
687        plt.rc('savefig', facecolor='auto', edgecolor='auto')
688        # fix ggplot when dark
689        if ggplot:
690            plt.rc('axes',
691                facecolor=foreground_,
692                edgecolor=background_)
693            plt.rc('grid', color=background_)
694
695    if font is not None:
696        plt.rc('font', family=font)
697    plt.rc('font', size=font_size)
698    plt.rc('text', color=font_color_)
699    plt.rc('figure',
700        titlesize='medium',
701        labelsize='small')
702    plt.rc('axes',
703        titlesize='small',
704        labelsize='small',
705        labelcolor=font_color_)
706    if not ggplot:
707        plt.rc('axes', edgecolor=font_color_)
708    plt.rc('xtick', labelsize='small', color=font_color_)
709    plt.rc('ytick', labelsize='small', color=font_color_)
710    plt.rc('legend',
711        fontsize='small',
712        fancybox=False,
713        framealpha=None,
714        edgecolor=foreground_,
715        borderaxespad=0)
716    plt.rc('axes.spines', top=False, right=False)
717
718    plt.rc('figure', facecolor=background_, edgecolor=background_)
719    if not ggplot:
720        plt.rc('axes', facecolor='#00000000')
721
722    # I think the svg backend just ignores DPI, but seems to use something
723    # equivalent to 96, maybe this is the default for SVG rendering?
724    plt.rc('figure', dpi=96)
725
726    # separate out renames
727    renames = list(it.chain.from_iterable(
728        ((k, v) for v in vs)
729        for k, vs in it.chain(by or [], x or [], y or [])))
730    if by is not None:
731        by = [k for k, _ in by]
732    if x is not None:
733        x = [k for k, _ in x]
734    if y is not None:
735        y = [k for k, _ in y]
736
737    # first collect results from CSV files
738    results = collect(csv_paths, renames)
739
740    # then extract the requested datasets
741    datasets_ = datasets(results, by, x, y, define)
742
743    # figure out formats/colors here so that subplot defines
744    # don't change them later, that'd be bad
745    dataformats_ = {
746        name: formats_[i % len(formats_)]
747        for i, name in enumerate(datasets_.keys())}
748    datacolors_ = {
749        name: colors_[i % len(colors_)]
750        for i, name in enumerate(datasets_.keys())}
751
752    # create a grid of subplots
753    grid = Grid.fromargs(
754        subplots=subplots + subplot.pop('subplots', []),
755        **subplot)
756
757    # create a matplotlib plot
758    fig = plt.figure(figsize=(
759        width/plt.rcParams['figure.dpi'],
760        height/plt.rcParams['figure.dpi']),
761        layout='constrained',
762        # we need a linewidth to keep xkcd mode happy
763        linewidth=8 if xkcd else 0)
764
765    gs = fig.add_gridspec(
766        grid.height
767            + (1 if legend_above else 0)
768            + (1 if legend_below else 0),
769        grid.width
770            + (1 if legend_right else 0),
771        height_ratios=([0.001] if legend_above else [])
772            + [max(s, 0.01) for s in reversed(grid.yweights)]
773            + ([0.001] if legend_below else []),
774        width_ratios=[max(s, 0.01) for s in grid.xweights]
775            + ([0.001] if legend_right else []))
776
777    # first create axes so that plots can interact with each other
778    for s in grid:
779        s.ax = fig.add_subplot(gs[
780            grid.height-(s.y+s.yspan) + (1 if legend_above else 0)
781                : grid.height-s.y + (1 if legend_above else 0),
782            s.x
783                : s.x+s.xspan])
784
785    # now plot each subplot
786    for s in grid:
787        # allow subplot params to override global params
788        define_ = define + s.args.get('define', [])
789        xlim_ = s.args.get('xlim', xlim)
790        ylim_ = s.args.get('ylim', ylim)
791        xlog_ = s.args.get('xlog', False) or xlog
792        ylog_ = s.args.get('ylog', False) or ylog
793        x2_ = s.args.get('x2', False) or x2
794        y2_ = s.args.get('y2', False) or y2
795        xticks_ = s.args.get('xticks', xticks)
796        yticks_ = s.args.get('yticks', yticks)
797        xunits_ = s.args.get('xunits', xunits)
798        yunits_ = s.args.get('yunits', yunits)
799        xticklabels_ = s.args.get('xticklabels', xticklabels)
800        yticklabels_ = s.args.get('yticklabels', yticklabels)
801
802        # label/titles are handled a bit differently in subplots
803        subtitle = s.args.get('title')
804        xsublabel = s.args.get('xlabel')
805        ysublabel = s.args.get('ylabel')
806
807        # allow shortened ranges
808        if len(xlim_) == 1:
809            xlim_ = (0, xlim_[0])
810        if len(ylim_) == 1:
811            ylim_ = (0, ylim_[0])
812
813        # data can be constrained by subplot-specific defines,
814        # so re-extract for each plot
815        subdatasets = datasets(results, by, x, y, define_)
816
817        # plot!
818        ax = s.ax
819        for name, dataset in subdatasets.items():
820            dats = sorted((x,y) for x,y in dataset.items())
821            ax.plot([x for x,_ in dats], [y for _,y in dats],
822                dataformats_[name],
823                color=datacolors_[name],
824                label=','.join(k for k in name if k))
825
826        # axes scaling
827        if xlog_:
828            ax.set_xscale('symlog')
829            ax.xaxis.set_minor_locator(mpl.ticker.NullLocator())
830        if ylog_:
831            ax.set_yscale('symlog')
832            ax.yaxis.set_minor_locator(mpl.ticker.NullLocator())
833        # axes limits
834        ax.set_xlim(
835            xlim_[0] if xlim_[0] is not None
836                else min(it.chain([0], (k
837                    for r in subdatasets.values()
838                    for k, v in r.items()
839                    if v is not None))),
840            xlim_[1] if xlim_[1] is not None
841                else max(it.chain([0], (k
842                    for r in subdatasets.values()
843                    for k, v in r.items()
844                    if v is not None))))
845        ax.set_ylim(
846            ylim_[0] if ylim_[0] is not None
847                else min(it.chain([0], (v
848                    for r in subdatasets.values()
849                    for _, v in r.items()
850                    if v is not None))),
851            ylim_[1] if ylim_[1] is not None
852                else max(it.chain([0], (v
853                    for r in subdatasets.values()
854                    for _, v in r.items()
855                    if v is not None))))
856        # axes ticks
857        if x2_:
858            ax.xaxis.set_major_formatter(lambda x, pos:
859                si2(x)+(xunits_ if xunits_ else ''))
860            if xticklabels_ is not None:
861                ax.xaxis.set_ticklabels(xticklabels_)
862            if xticks_ is None:
863                ax.xaxis.set_major_locator(AutoMultipleLocator(2))
864            elif isinstance(xticks_, list):
865                ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks_))
866            elif xticks_ != 0:
867                ax.xaxis.set_major_locator(AutoMultipleLocator(2, xticks_-1))
868            else:
869                ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
870        else:
871            ax.xaxis.set_major_formatter(lambda x, pos:
872                si(x)+(xunits_ if xunits_ else ''))
873            if xticklabels_ is not None:
874                ax.xaxis.set_ticklabels(xticklabels_)
875            if xticks_ is None:
876                ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
877            elif isinstance(xticks_, list):
878                ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks_))
879            elif xticks_ != 0:
880                ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(xticks_-1))
881            else:
882                ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
883        if y2_:
884            ax.yaxis.set_major_formatter(lambda x, pos:
885                si2(x)+(yunits_ if yunits_ else ''))
886            if yticklabels_ is not None:
887                ax.yaxis.set_ticklabels(yticklabels_)
888            if yticks_ is None:
889                ax.yaxis.set_major_locator(AutoMultipleLocator(2))
890            elif isinstance(yticks_, list):
891                ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks_))
892            elif yticks_ != 0:
893                ax.yaxis.set_major_locator(AutoMultipleLocator(2, yticks_-1))
894            else:
895                ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
896        else:
897            ax.yaxis.set_major_formatter(lambda x, pos:
898                si(x)+(yunits_ if yunits_ else ''))
899            if yticklabels_ is not None:
900                ax.yaxis.set_ticklabels(yticklabels_)
901            if yticks_ is None:
902                ax.yaxis.set_major_locator(mpl.ticker.AutoLocator())
903            elif isinstance(yticks_, list):
904                ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks_))
905            elif yticks_ != 0:
906                ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(yticks_-1))
907            else:
908                ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
909        if ggplot:
910            ax.grid(sketch_params=None)
911
912        # axes subplot labels
913        if xsublabel is not None:
914            ax.set_xlabel(escape(xsublabel))
915        if ysublabel is not None:
916            ax.set_ylabel(escape(ysublabel))
917        if subtitle is not None:
918            ax.set_title(escape(subtitle))
919
920    # add a legend? a bit tricky with matplotlib
921    #
922    # the best solution I've found is a dedicated, invisible axes for the
923    # legend, hacky, but it works.
924    #
925    # note this was written before constrained_layout supported legend
926    # collisions, hopefully this is added in the future
927    labels = co.OrderedDict()
928    for s in grid:
929        for h, l in zip(*s.ax.get_legend_handles_labels()):
930            labels[l] = h
931
932    if legend_right:
933        ax = fig.add_subplot(gs[(1 if legend_above else 0):,-1])
934        ax.set_axis_off()
935        ax.legend(
936            labels.values(),
937            labels.keys(),
938            loc='upper left',
939            fancybox=False,
940            borderaxespad=0)
941
942    if legend_above:
943        ax = fig.add_subplot(gs[0, :grid.width])
944        ax.set_axis_off()
945
946        # try different column counts until we fit in the axes
947        for ncol in reversed(range(1, len(labels)+1)):
948            legend_ = ax.legend(
949                labels.values(),
950                labels.keys(),
951                loc='upper center',
952                ncol=ncol,
953                fancybox=False,
954                borderaxespad=0)
955
956            if (legend_.get_window_extent().width
957                    <= ax.get_window_extent().width):
958                break
959
960    if legend_below:
961        ax = fig.add_subplot(gs[-1, :grid.width])
962        ax.set_axis_off()
963
964        # big hack to get xlabel above the legend! but hey this
965        # works really well actually
966        if xlabel:
967            ax.set_title(escape(xlabel),
968                size=plt.rcParams['axes.labelsize'],
969                weight=plt.rcParams['axes.labelweight'])
970
971        # try different column counts until we fit in the axes
972        for ncol in reversed(range(1, len(labels)+1)):
973            legend_ = ax.legend(
974                labels.values(),
975                labels.keys(),
976                loc='upper center',
977                ncol=ncol,
978                fancybox=False,
979                borderaxespad=0)
980
981            if (legend_.get_window_extent().width
982                    <= ax.get_window_extent().width):
983                break
984
985
986    # axes labels, NOTE we reposition these below
987    if xlabel is not None and not legend_below:
988        fig.supxlabel(escape(xlabel))
989    if ylabel is not None:
990        fig.supylabel(escape(ylabel))
991    if title is not None:
992        fig.suptitle(escape(title))
993
994    # precompute constrained layout and find midpoints to adjust things
995    # that should be centered so they are actually centered
996    fig.canvas.draw()
997    xmid = (grid[0,0].ax.get_position().x0 + grid[-1,0].ax.get_position().x1)/2
998    ymid = (grid[0,0].ax.get_position().y0 + grid[0,-1].ax.get_position().y1)/2
999
1000    if xlabel is not None and not legend_below:
1001        fig.supxlabel(escape(xlabel), x=xmid)
1002    if ylabel is not None:
1003        fig.supylabel(escape(ylabel), y=ymid)
1004    if title is not None:
1005        fig.suptitle(escape(title), x=xmid)
1006
1007
1008    # write the figure!
1009    plt.savefig(output, format='png' if png else 'svg')
1010
1011    # some stats
1012    if not quiet:
1013        print('updated %s, %s datasets, %s points' % (
1014            output,
1015            len(datasets_),
1016            sum(len(dataset) for dataset in datasets_.values())))
1017
1018
if __name__ == "__main__":
    import sys
    import argparse

    # argparse type converters shared by several options below. Named
    # functions (instead of repeated inline lambdas) keep the option table
    # readable, and argparse's "invalid <type> value" errors name them.

    def arg_int(s):
        # integer with base auto-detection (0x.., 0o.., 0b.., or decimal)
        return int(s, 0)

    def arg_list(s):
        # comma-separated list, whitespace stripped from each item
        return [p.strip() for p in s.split(',')]

    def arg_formats(s):
        # like arg_list, but '0' stands in for ',' since ',' is the separator
        return [p.strip().replace('0', ',') for p in s.split(',')]

    def arg_field(s):
        # "name" -> (name, ()), "name=a,b" -> (name, ['a', 'b'])
        k, sep, v = s.partition('=')
        return (k, v.split(',') if sep else ())

    def arg_define(s):
        # "field=a,b" -> (field, {'a', 'b'}); a missing '=' raises
        # ValueError, which argparse reports as an invalid value
        k, v = s.split('=', 1)
        return (k, set(v.split(',')))

    def arg_lim(s):
        # comma-separated bounds parsed with dat(), blank entries become None
        return tuple(dat(p) if p.strip() else None for p in s.split(','))

    def arg_ticks(s):
        # a bare integer is a tick count (0 disables ticks), a
        # comma-separated list gives explicit tick positions parsed with dat()
        if ',' in s:
            return [dat(p) for p in s.split(',')]
        return int(s, 0)

    def arg_ticklabels(s):
        # comma-separated labels, a blank string gives an empty list
        if not s.strip():
            return []
        return [p.strip() for p in s.split(',')]

    parser = argparse.ArgumentParser(
        description="Plot CSV files with matplotlib.",
        allow_abbrev=False)
    parser.add_argument(
        'csv_paths',
        nargs='*',
        help="Input *.csv files.")
    output_rule = parser.add_argument(
        '-o', '--output',
        required=True,
        help="Output *.svg/*.png file.")
    parser.add_argument(
        '--svg',
        action='store_true',
        help="Output an svg file. By default this is inferred.")
    parser.add_argument(
        '--png',
        action='store_true',
        help="Output a png file. By default this is inferred.")
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
        help="Don't print info.")
    parser.add_argument(
        '-b', '--by',
        action='append',
        type=arg_field,
        help="Group by this field. Can rename fields with new_name=old_name.")
    parser.add_argument(
        '-x',
        action='append',
        type=arg_field,
        help="Field to use for the x-axis. Can rename fields with "
            "new_name=old_name.")
    parser.add_argument(
        '-y',
        action='append',
        type=arg_field,
        help="Field to use for the y-axis. Can rename fields with "
            "new_name=old_name.")
    parser.add_argument(
        '-D', '--define',
        type=arg_define,
        action='append',
        help="Only include results where this field is this value. May include "
            "comma-separated options.")
    parser.add_argument(
        '-.', '--points',
        action='store_true',
        help="Only draw data points.")
    parser.add_argument(
        '-!', '--points-and-lines',
        action='store_true',
        help="Draw data points and lines.")
    parser.add_argument(
        '--colors',
        type=arg_list,
        help="Comma-separated hex colors to use.")
    parser.add_argument(
        '--formats',
        type=arg_formats,
        help="Comma-separated matplotlib formats to use. Allows '0' as an "
            "alternative for ','.")
    parser.add_argument(
        '-W', '--width',
        type=arg_int,
        help="Width in pixels. Defaults to %r." % WIDTH)
    parser.add_argument(
        '-H', '--height',
        type=arg_int,
        help="Height in pixels. Defaults to %r." % HEIGHT)
    parser.add_argument(
        '-X', '--xlim',
        type=arg_lim,
        help="Range for the x-axis.")
    parser.add_argument(
        '-Y', '--ylim',
        type=arg_lim,
        help="Range for the y-axis.")
    parser.add_argument(
        '--xlog',
        action='store_true',
        help="Use a logarithmic x-axis.")
    parser.add_argument(
        '--ylog',
        action='store_true',
        help="Use a logarithmic y-axis.")
    parser.add_argument(
        '--x2',
        action='store_true',
        help="Use base-2 prefixes for the x-axis.")
    parser.add_argument(
        '--y2',
        action='store_true',
        help="Use base-2 prefixes for the y-axis.")
    parser.add_argument(
        '--xticks',
        type=arg_ticks,
        help="Ticks for the x-axis. This can be explicit comma-separated "
            "ticks, the number of ticks, or 0 to disable.")
    parser.add_argument(
        '--yticks',
        type=arg_ticks,
        help="Ticks for the y-axis. This can be explicit comma-separated "
            "ticks, the number of ticks, or 0 to disable.")
    parser.add_argument(
        '--xunits',
        help="Units for the x-axis.")
    parser.add_argument(
        '--yunits',
        help="Units for the y-axis.")
    parser.add_argument(
        '--xlabel',
        help="Add a label to the x-axis.")
    parser.add_argument(
        '--ylabel',
        help="Add a label to the y-axis.")
    parser.add_argument(
        '--xticklabels',
        type=arg_ticklabels,
        help="Comma separated xticklabels.")
    parser.add_argument(
        '--yticklabels',
        type=arg_ticklabels,
        help="Comma separated yticklabels.")
    parser.add_argument(
        '-t', '--title',
        help="Add a title.")
    parser.add_argument(
        '-l', '--legend-right',
        action='store_true',
        help="Place a legend to the right.")
    parser.add_argument(
        '--legend-above',
        action='store_true',
        help="Place a legend above.")
    parser.add_argument(
        '--legend-below',
        action='store_true',
        help="Place a legend below.")
    parser.add_argument(
        '--dark',
        action='store_true',
        help="Use the dark style.")
    parser.add_argument(
        '--ggplot',
        action='store_true',
        help="Use the ggplot style.")
    parser.add_argument(
        '--xkcd',
        action='store_true',
        help="Use the xkcd style.")
    parser.add_argument(
        '--github',
        action='store_true',
        help="Use the ggplot style with GitHub colors.")
    parser.add_argument(
        '--font',
        type=arg_list,
        help="Font family for matplotlib.")
    parser.add_argument(
        '--font-size',
        help="Font size for matplotlib. Defaults to %r." % FONT_SIZE)
    parser.add_argument(
        '--font-color',
        help="Color for the font and other line elements.")
    parser.add_argument(
        '--foreground',
        help="Foreground color to use.")
    parser.add_argument(
        '--background',
        help="Background color to use.")

    class AppendSubplot(argparse.Action):
        """Collect --subplot-<dir> options as (dir, namespace) pairs.

        Each value is an argument string that is re-parsed with a copy of
        the main parser. In that copy --output is optional and -W/-H are
        parsed as floats (relative subplot dimensions). The results are
        appended to ``namespace.subplots``.
        """
        @staticmethod
        def parse(value):
            # also used directly as the argparse type of --subplot
            import copy
            subparser = copy.deepcopy(parser)
            next(a for a in subparser._actions
                if '--output' in a.option_strings).required = False
            next(a for a in subparser._actions
                if '--width' in a.option_strings).type = float
            next(a for a in subparser._actions
                if '--height' in a.option_strings).type = float
            return subparser.parse_intermixed_args(shlex.split(value or ""))

        def __call__(self, parser, namespace, value, option):
            if not hasattr(namespace, 'subplots'):
                namespace.subplots = []
            # '--subplot-above' -> 'above', etc
            namespace.subplots.append((
                option.split('-')[-1],
                self.__class__.parse(value)))

    parser.add_argument(
        '--subplot-above',
        action=AppendSubplot,
        help="Add subplot above with the same dataset. Takes an arg string to "
            "control the subplot which supports most (but not all) of the "
            "parameters listed here. The relative dimensions of the subplot "
            "can be controlled with -W/-H which now take a percentage.")
    parser.add_argument(
        '--subplot-below',
        action=AppendSubplot,
        help="Add subplot below with the same dataset.")
    parser.add_argument(
        '--subplot-left',
        action=AppendSubplot,
        help="Add subplot left with the same dataset.")
    parser.add_argument(
        '--subplot-right',
        action=AppendSubplot,
        help="Add subplot right with the same dataset.")
    parser.add_argument(
        '--subplot',
        type=AppendSubplot.parse,
        help="Add subplot-specific arguments to the main plot.")

    def dictify(ns):
        """Convert a parsed Namespace into keyword arguments for main().

        Nested subplot namespaces are converted recursively, and options
        whose value is None (not given) are dropped so main()'s defaults
        apply.
        """
        if hasattr(ns, 'subplots'):
            ns.subplots = [(direction, dictify(subplot_ns))
                for direction, subplot_ns in ns.subplots]
        if ns.subplot is not None:
            ns.subplot = dictify(ns.subplot)
        return {k: v
            for k, v in vars(ns).items()
            if v is not None}

    sys.exit(main(**dictify(parser.parse_intermixed_args())))
1263