#!/usr/bin/env python3
#
# Plot CSV files with matplotlib.
#
# Example:
# ./scripts/plotmpl.py bench.csv -xSIZE -ybench_read -obench.svg
#
# Copyright (c) 2022, The littlefs authors.
# SPDX-License-Identifier: BSD-3-Clause
#

import codecs
import collections as co
import csv
import io
import itertools as it
import logging
import math as m
import numpy as np
import os
import shlex
import shutil
import sys
import time

import matplotlib as mpl
import matplotlib.pyplot as plt

# some nicer colors borrowed from Seaborn
# note these include a non-opaque alpha
COLORS = [
    '#4c72b0bf', # blue
    '#dd8452bf', # orange
    '#55a868bf', # green
    '#c44e52bf', # red
    '#8172b3bf', # purple
    '#937860bf', # brown
    '#da8bc3bf', # pink
    '#8c8c8cbf', # gray
    '#ccb974bf', # yellow
    '#64b5cdbf', # cyan
]
COLORS_DARK = [
    '#a1c9f4bf', # blue
    '#ffb482bf', # orange
    '#8de5a1bf', # green
    '#ff9f9bbf', # red
    '#d0bbffbf', # purple
    '#debb9bbf', # brown
    '#fab0e4bf', # pink
    '#cfcfcfbf', # gray
    '#fffea3bf', # yellow
    '#b9f2f0bf', # cyan
]
ALPHAS = [0.75]
FORMATS = ['-']
FORMATS_POINTS = ['.']
FORMATS_POINTS_AND_LINES = ['.-']

WIDTH = 750
HEIGHT = 350
FONT_SIZE = 11

SI_PREFIXES = {
    18:  'E',
    15:  'P',
    12:  'T',
    9:   'G',
    6:   'M',
    3:   'K',
    0:   '',
    -3:  'm',
    -6:  'u',
    -9:  'n',
    -12: 'p',
    -15: 'f',
    -18: 'a',
}

SI2_PREFIXES = {
    60:  'Ei',
    50:  'Pi',
    40:  'Ti',
    30:  'Gi',
    20:  'Mi',
    10:  'Ki',
    0:   '',
    -10: 'mi',
    -20: 'ui',
    -30: 'ni',
    -40: 'pi',
    -50: 'fi',
    -60: 'ai',
}


# shared implementation of si/si2
#
# radix**p is the scale, where p is a multiple of step chosen so the scaled
# value lands in [1, radix**step), clamped to the exponents in prefixes
def _siformat(x, radix, step, prefixes):
    if x == 0:
        return '0'
    # figure out prefix and scale
    #
    # note we floor rather than truncate toward zero, otherwise values in
    # (0, 1) would get no prefix and lose their digits (0.005 -> 0)
    lo, hi = min(prefixes), max(prefixes)
    p = step*m.floor(m.log(abs(x), radix) / step)
    p = min(hi, max(lo, p))
    # log is inexact (log(1000, 10) == 2.999...), so nudge p if we landed
    # on the wrong side of an exact power
    if p < hi and abs(x) >= float(radix)**(p+step):
        p += step
    elif p > lo and abs(x) < float(radix)**p:
        p -= step
    # format with 3 digits of precision
    s = '%.3f' % (abs(x) / float(radix)**p)
    s = s[:3+1]
    # truncate but only digits that follow the dot
    if '.' in s:
        s = s.rstrip('0')
        s = s.rstrip('.')
    return '%s%s%s' % ('-' if x < 0 else '', s, prefixes[p])

# formatter for matplotlib, decimal SI prefixes (powers of 1000)
def si(x):
    return _siformat(x, 10, 3, SI_PREFIXES)

# formatter for matplotlib, binary prefixes (powers of 1024)
def si2(x):
    return _siformat(x, 2, 10, SI2_PREFIXES)

# parse escape strings
def escape(s):
    return codecs.escape_decode(s.encode('utf8'))[0].decode('utf8')

# we want to use MaxNLocator, but since MaxNLocator forces multiples of 10
# to be an option, we can't really...
class AutoMultipleLocator(mpl.ticker.MultipleLocator):
    def __init__(self, base, nbins=None):
        # note base needs to be floats to avoid integer pow issues
        self.base = float(base)
        self.nbins = nbins
        super().__init__(self.base)

    def __call__(self):
        # find best tick count, conveniently matplotlib has a function for this
        vmin, vmax = self.axis.get_view_interval()
        vmin, vmax = mpl.transforms.nonsingular(vmin, vmax, 1e-12, 1e-13)
        if self.nbins is not None:
            nbins = self.nbins
        else:
            nbins = np.clip(self.axis.get_tick_space(), 1, 9)

        # find the best power, use this as our locator's actual base
        scale = self.base ** (m.ceil(m.log((vmax-vmin) / (nbins+1), self.base)))
        self.set_params(scale)

        return super().__call__()


# open a file, allowing '-' for stdin (any read mode) or stdout (otherwise)
def openio(path, mode='r', buffering=-1):
    # allow '-' for stdin/stdout
    if path == '-':
        # check for 'r' anywhere so 'rb' also reads stdin
        if 'r' in mode:
            return os.fdopen(os.dup(sys.stdin.fileno()), mode, buffering)
        else:
            return os.fdopen(os.dup(sys.stdout.fileno()), mode, buffering)
    else:
        return open(path, mode, buffering)


# parse different data representations
#
# returns an int (any base accepted by int(x, 0)) or a finite float, raises
# ValueError otherwise
def dat(x):
    # allow the first part of an a/b fraction
    if '/' in x:
        x, _ = x.split('/', 1)

    # first try as int
    try:
        return int(x, 0)
    except ValueError:
        pass

    # then try as float
    try:
        v = float(x)
    except ValueError:
        pass
    else:
        # just don't allow infinity or nan
        if not (m.isinf(v) or m.isnan(v)):
            return v

    # else give up
    raise ValueError("invalid dat %r" % x)

# collect rows (as dicts) from CSV files, missing files are skipped,
# renames is a list of (new_name, old_name) pairs copied into each row
def collect(csv_paths, renames=[]):
    # collect results from CSV files
    results = []
    for path in csv_paths:
        try:
            with openio(path) as f:
                reader = csv.DictReader(f, restval='')
                for r in reader:
                    results.append(r)
        except FileNotFoundError:
            pass

    if renames:
        for r in results:
            # make a copy so renames can overlap
            r_ = {}
            for new_k, old_k in renames:
                if old_k in r:
                    r_[new_k] = r[old_k]
            r.update(r_)

    return results

# build a single {x: y} dataset from results, rows with the same x have
# their y values summed, if x is None the row index is used instead
def dataset(results, x=None, y=None, define=[]):
    # organize by 'by', x, and y
    data = {}
    i = 0
    for r in results:
        # filter results by matching defines
        if not all(k in r and r[k] in vs for k, vs in define):
            continue

        # find xs
        if x is not None:
            if x not in r:
                continue
            try:
                x_ = dat(r[x])
            except ValueError:
                continue
        else:
            x_ = i
            i += 1

        # find ys
        if y is not None:
            if y not in r:
                continue
            try:
                y_ = dat(r[y])
            except ValueError:
                continue
        else:
            y_ = None

        if y_ is not None:
            data[x_] = y_ + data.get(x_, 0)
        else:
            data[x_] = y_ or data.get(x_, None)

    return data

# build every dataset, keyed by (*by values, x name, y name), with x/y
# names blanked when they would be redundant in a legend
def datasets(results, by=None, x=None, y=None, define=[]):
    # filter results by matching defines
    results_ = []
    for r in results:
        if all(k in r and r[k] in vs for k, vs in define):
            results_.append(r)
    results = results_

    # if y not specified, try to guess from data
    #
    # note we exclude both 'by' and x fields, plotting x against itself
    # is never useful
    if y is None:
        y = co.OrderedDict()
        for r in results:
            for k, v in r.items():
                if ((by is None or k not in by)
                        and (x is None or k not in x)
                        and v.strip()):
                    try:
                        dat(v)
                        y[k] = True
                    except ValueError:
                        y[k] = False
        y = list(k for k, v in y.items() if v)

    if by is not None:
        # find all 'by' values
        ks = set()
        for r in results:
            ks.add(tuple(r.get(k, '') for k in by))
        ks = sorted(ks)

    # collect all datasets
    datasets_ = co.OrderedDict()
    for ks_ in (ks if by is not None else [()]):
        for x_ in (x if x is not None else [None]):
            for y_ in y:
                # hide x/y if there is only one field
                k_x = x_ if len(x or []) > 1 else ''
                k_y = y_ if len(y or []) > 1 or (not ks_ and not k_x) else ''

                datasets_[ks_ + (k_x, k_y)] = dataset(
                    results,
                    x_,
                    y_,
                    [(by_, {k_}) for by_, k_ in zip(by, ks_)]
                        if by is not None else [])

    return datasets_


# some classes for organizing subplots into a grid
class Subplot:
    def __init__(self, **args):
        self.x = 0
        self.y = 0
        self.xspan = 1
        self.yspan = 1
        self.args = args

class Grid:
    def __init__(self, subplot, width=1.0, height=1.0):
        self.xweights = [width]
        self.yweights = [height]
        self.map = {(0,0): subplot}
        self.subplots = [subplot]

    def __repr__(self):
        return 'Grid(%r, %r)' % (self.xweights, self.yweights)

    @property
    def width(self):
        return len(self.xweights)

    @property
    def height(self):
        return len(self.yweights)

    def __iter__(self):
        return iter(self.subplots)

    def __getitem__(self, i):
        x, y = i
        if x < 0:
            x += len(self.xweights)
        if y < 0:
            y += len(self.yweights)

        return self.map[(x,y)]

    # merge another grid into this one in the direction dir
    # (above/below/right/left), splitting columns/rows as needed so that
    # both grids' cell boundaries line up
    def merge(self, other, dir):
        if dir in ['above', 'below']:
            # first scale the two grids so they line up
            self_xweights = self.xweights
            other_xweights = other.xweights
            self_w = sum(self_xweights)
            other_w = sum(other_xweights)
            ratio = self_w / other_w
            other_xweights = [s*ratio for s in other_xweights]

            # now interleave xweights as needed
            new_xweights = []
            self_map = {}
            other_map = {}
            self_i = 0
            other_i = 0
            self_xweight = (self_xweights[self_i]
                if self_i < len(self_xweights) else m.inf)
            other_xweight = (other_xweights[other_i]
                if other_i < len(other_xweights) else m.inf)
            while self_i < len(self_xweights) and other_i < len(other_xweights):
                if other_xweight - self_xweight > 0.0000001:
                    new_xweights.append(self_xweight)
                    other_xweight -= self_xweight

                    new_i = len(new_xweights)-1
                    for j in range(len(self.yweights)):
                        self_map[(new_i, j)] = self.map[(self_i, j)]
                    for j in range(len(other.yweights)):
                        other_map[(new_i, j)] = other.map[(other_i, j)]
                    # other's column is being split, so any subplot covering
                    # it (not just ones ending there) grows, later ones shift
                    for s in other.subplots:
                        if s.x <= new_i < s.x+s.xspan:
                            s.xspan += 1
                        elif s.x > new_i:
                            s.x += 1

                    self_i += 1
                    self_xweight = (self_xweights[self_i]
                        if self_i < len(self_xweights) else m.inf)
                elif self_xweight - other_xweight > 0.0000001:
                    new_xweights.append(other_xweight)
                    self_xweight -= other_xweight

                    new_i = len(new_xweights)-1
                    for j in range(len(other.yweights)):
                        other_map[(new_i, j)] = other.map[(other_i, j)]
                    for j in range(len(self.yweights)):
                        self_map[(new_i, j)] = self.map[(self_i, j)]
                    # self's column is being split, so any subplot covering
                    # it (not just ones ending there) grows, later ones shift
                    for s in self.subplots:
                        if s.x <= new_i < s.x+s.xspan:
                            s.xspan += 1
                        elif s.x > new_i:
                            s.x += 1

                    other_i += 1
                    other_xweight = (other_xweights[other_i]
                        if other_i < len(other_xweights) else m.inf)
                else:
                    new_xweights.append(self_xweight)

                    new_i = len(new_xweights)-1
                    for j in range(len(self.yweights)):
                        self_map[(new_i, j)] = self.map[(self_i, j)]
                    for j in range(len(other.yweights)):
                        other_map[(new_i, j)] = other.map[(other_i, j)]

                    self_i += 1
                    self_xweight = (self_xweights[self_i]
                        if self_i < len(self_xweights) else m.inf)
                    other_i += 1
                    other_xweight = (other_xweights[other_i]
                        if other_i < len(other_xweights) else m.inf)

            # squish so ratios are preserved
            self_h = sum(self.yweights)
            other_h = sum(other.yweights)
            ratio = (self_h-other_h) / self_h
            self_yweights = [s*ratio for s in self.yweights]

            # finally concatenate the two grids
            if dir == 'above':
                for s in other.subplots:
                    s.y += len(self_yweights)
                self.subplots.extend(other.subplots)

                self.xweights = new_xweights
                self.yweights = self_yweights + other.yweights
                self.map = self_map | {(x, y+len(self_yweights)): s
                    for (x, y), s in other_map.items()}
            else:
                for s in self.subplots:
                    s.y += len(other.yweights)
                self.subplots.extend(other.subplots)

                self.xweights = new_xweights
                self.yweights = other.yweights + self_yweights
                self.map = other_map | {(x, y+len(other.yweights)): s
                    for (x, y), s in self_map.items()}

        if dir in ['right', 'left']:
            # first scale the two grids so they line up
            self_yweights = self.yweights
            other_yweights = other.yweights
            self_h = sum(self_yweights)
            other_h = sum(other_yweights)
            ratio = self_h / other_h
            other_yweights = [s*ratio for s in other_yweights]

            # now interleave yweights as needed
            new_yweights = []
            self_map = {}
            other_map = {}
            self_i = 0
            other_i = 0
            self_yweight = (self_yweights[self_i]
                if self_i < len(self_yweights) else m.inf)
            other_yweight = (other_yweights[other_i]
                if other_i < len(other_yweights) else m.inf)
            while self_i < len(self_yweights) and other_i < len(other_yweights):
                if other_yweight - self_yweight > 0.0000001:
                    new_yweights.append(self_yweight)
                    other_yweight -= self_yweight

                    new_i = len(new_yweights)-1
                    for j in range(len(self.xweights)):
                        self_map[(j, new_i)] = self.map[(j, self_i)]
                    for j in range(len(other.xweights)):
                        other_map[(j, new_i)] = other.map[(j, other_i)]
                    # other's row is being split, so any subplot covering
                    # it (not just ones ending there) grows, later ones shift
                    for s in other.subplots:
                        if s.y <= new_i < s.y+s.yspan:
                            s.yspan += 1
                        elif s.y > new_i:
                            s.y += 1

                    self_i += 1
                    self_yweight = (self_yweights[self_i]
                        if self_i < len(self_yweights) else m.inf)
                elif self_yweight - other_yweight > 0.0000001:
                    new_yweights.append(other_yweight)
                    self_yweight -= other_yweight

                    new_i = len(new_yweights)-1
                    for j in range(len(other.xweights)):
                        other_map[(j, new_i)] = other.map[(j, other_i)]
                    for j in range(len(self.xweights)):
                        self_map[(j, new_i)] = self.map[(j, self_i)]
                    # self's row is being split, so any subplot covering
                    # it (not just ones ending there) grows, later ones shift
                    for s in self.subplots:
                        if s.y <= new_i < s.y+s.yspan:
                            s.yspan += 1
                        elif s.y > new_i:
                            s.y += 1

                    other_i += 1
                    other_yweight = (other_yweights[other_i]
                        if other_i < len(other_yweights) else m.inf)
                else:
                    new_yweights.append(self_yweight)

                    new_i = len(new_yweights)-1
                    for j in range(len(self.xweights)):
                        self_map[(j, new_i)] = self.map[(j, self_i)]
                    for j in range(len(other.xweights)):
                        other_map[(j, new_i)] = other.map[(j, other_i)]

                    self_i += 1
                    self_yweight = (self_yweights[self_i]
                        if self_i < len(self_yweights) else m.inf)
                    other_i += 1
                    other_yweight = (other_yweights[other_i]
                        if other_i < len(other_yweights) else m.inf)

            # squish so ratios are preserved
            self_w = sum(self.xweights)
            other_w = sum(other.xweights)
            ratio = (self_w-other_w) / self_w
            self_xweights = [s*ratio for s in self.xweights]

            # finally concatenate the two grids
            if dir == 'right':
                for s in other.subplots:
                    s.x += len(self_xweights)
                self.subplots.extend(other.subplots)

                self.xweights = self_xweights + other.xweights
                self.yweights = new_yweights
                self.map = self_map | {(x+len(self_xweights), y): s
                    for (x, y), s in other_map.items()}
            else:
                for s in self.subplots:
                    s.x += len(other.xweights)
                self.subplots.extend(other.subplots)

                self.xweights = other.xweights + self_xweights
                self.yweights = new_yweights
                self.map = other_map | {(x+len(other.xweights), y): s
                    for (x, y), s in self_map.items()}


    def scale(self, width, height):
        self.xweights = [s*width for s in self.xweights]
        self.yweights = [s*height for s in self.yweights]

    # build a grid from parsed args, subplots is a list of (dir, args)
    # pairs which are recursively built and merged in order
    @classmethod
    def fromargs(cls, width=1.0, height=1.0, *,
            subplots=[],
            **args):
        grid = cls(Subplot(**args))

        for dir, subargs in subplots:
            # copy so popping width/height doesn't mutate the caller's args
            subargs = dict(subargs)
            subgrid = cls.fromargs(
                width=subargs.pop('width',
                    0.5 if dir in ['right', 'left'] else width),
                height=subargs.pop('height',
                    0.5 if dir in ['above', 'below'] else height),
                **subargs)
            grid.merge(subgrid, dir)

        grid.scale(width, height)
        return grid


def main(csv_paths, output, *,
        svg=False,
        png=False,
        quiet=False,
        by=None,
        x=None,
        y=None,
        define=[],
        points=False,
        points_and_lines=False,
        colors=None,
        formats=None,
        width=WIDTH,
        height=HEIGHT,
        xlim=(None,None),
        ylim=(None,None),
        xlog=False,
        ylog=False,
        x2=False,
        y2=False,
        xticks=None,
        yticks=None,
        xunits=None,
        yunits=None,
        xlabel=None,
        ylabel=None,
        xticklabels=None,
        yticklabels=None,
        title=None,
        legend_right=False,
        legend_above=False,
        legend_below=False,
        dark=False,
        ggplot=False,
        xkcd=False,
        github=False,
        font=None,
        font_size=FONT_SIZE,
        font_color=None,
        foreground=None,
        background=None,
        subplot={},
        subplots=[],
        **args):
    # guess the output format
    if not png and not svg:
        if output.endswith('.png'):
            png = True
        else:
            svg = True

    # some shortcuts for color schemes
    if github:
        ggplot = True
        if font_color is None:
            if dark:
                font_color = '#c9d1d9'
            else:
                font_color = '#24292f'
        if foreground is None:
            if dark:
                foreground = '#343942'
            else:
                foreground = '#eff1f3'
        if background is None:
            if dark:
                background = '#0d1117'
            else:
                background = '#ffffff'

    # what colors/alphas/formats to use?
    if colors is not None:
        colors_ = colors
    elif dark:
        colors_ = COLORS_DARK
    else:
        colors_ = COLORS

    if formats is not None:
        formats_ = formats
    elif points_and_lines:
        formats_ = FORMATS_POINTS_AND_LINES
    elif points:
        formats_ = FORMATS_POINTS
    else:
        formats_ = FORMATS

    if font_color is not None:
        font_color_ = font_color
    elif dark:
        font_color_ = '#ffffff'
    else:
        font_color_ = '#000000'

    if foreground is not None:
        foreground_ = foreground
    elif dark:
        foreground_ = '#333333'
    else:
        foreground_ = '#e5e5e5'

    if background is not None:
        background_ = background
    elif dark:
        background_ = '#000000'
    else:
        background_ = '#ffffff'

    # configure some matplotlib settings
    if xkcd:
        # the font search here prints a bunch of unhelpful warnings
        logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)
        plt.xkcd()
        # turn off the white outline, this breaks some things
        plt.rc('path', effects=[])
    if ggplot:
        plt.style.use('ggplot')
        plt.rc('patch', linewidth=0)
        plt.rc('axes', facecolor=foreground_, edgecolor=background_)
        plt.rc('grid', color=background_)
        # fix the gridlines when ggplot+xkcd
        if xkcd:
            plt.rc('grid', linewidth=1)
            plt.rc('axes.spines', bottom=False, left=False)
    if dark:
        plt.style.use('dark_background')
        plt.rc('savefig', facecolor='auto', edgecolor='auto')
        # fix ggplot when dark
        if ggplot:
            plt.rc('axes',
                facecolor=foreground_,
                edgecolor=background_)
            plt.rc('grid', color=background_)

    if font is not None:
        plt.rc('font', family=font)
    plt.rc('font', size=font_size)
    plt.rc('text', color=font_color_)
    plt.rc('figure',
        titlesize='medium',
        labelsize='small')
    plt.rc('axes',
        titlesize='small',
        labelsize='small',
        labelcolor=font_color_)
    if not ggplot:
        plt.rc('axes', edgecolor=font_color_)
    plt.rc('xtick', labelsize='small', color=font_color_)
    plt.rc('ytick', labelsize='small', color=font_color_)
    plt.rc('legend',
        fontsize='small',
        fancybox=False,
        framealpha=None,
        edgecolor=foreground_,
        borderaxespad=0)
    plt.rc('axes.spines', top=False, right=False)

    plt.rc('figure', facecolor=background_, edgecolor=background_)
    if not ggplot:
        plt.rc('axes', facecolor='#00000000')

    # I think the svg backend just ignores DPI, but seems to use something
    # equivalent to 96, maybe this is the default for SVG rendering?
    plt.rc('figure', dpi=96)

    # separate out renames
    renames = list(it.chain.from_iterable(
        ((k, v) for v in vs)
        for k, vs in it.chain(by or [], x or [], y or [])))
    if by is not None:
        by = [k for k, _ in by]
    if x is not None:
        x = [k for k, _ in x]
    if y is not None:
        y = [k for k, _ in y]

    # first collect results from CSV files
    results = collect(csv_paths, renames)

    # then extract the requested datasets
    datasets_ = datasets(results, by, x, y, define)

    # figure out formats/colors here so that subplot defines
    # don't change them later, that'd be bad
    dataformats_ = {
        name: formats_[i % len(formats_)]
        for i, name in enumerate(datasets_.keys())}
    datacolors_ = {
        name: colors_[i % len(colors_)]
        for i, name in enumerate(datasets_.keys())}

    # create a grid of subplots
    #
    # copy subplot so popping 'subplots' doesn't mutate the caller's dict
    subplot = dict(subplot)
    grid = Grid.fromargs(
        subplots=subplots + subplot.pop('subplots', []),
        **subplot)

    # create a matplotlib plot
    fig = plt.figure(figsize=(
            width/plt.rcParams['figure.dpi'],
            height/plt.rcParams['figure.dpi']),
        layout='constrained',
        # we need a linewidth to keep xkcd mode happy
        linewidth=8 if xkcd else 0)

    gs = fig.add_gridspec(
        grid.height
            + (1 if legend_above else 0)
            + (1 if legend_below else 0),
        grid.width
            + (1 if legend_right else 0),
        height_ratios=([0.001] if legend_above else [])
            + [max(s, 0.01) for s in reversed(grid.yweights)]
            + ([0.001] if legend_below else []),
        width_ratios=[max(s, 0.01) for s in grid.xweights]
            + ([0.001] if legend_right else []))

    # first create axes so that plots can interact with each other
    for s in grid:
        s.ax = fig.add_subplot(gs[
            grid.height-(s.y+s.yspan) + (1 if legend_above else 0)
                : grid.height-s.y + (1 if legend_above else 0),
            s.x
                : s.x+s.xspan])

    # now plot each subplot
    for s in grid:
        # allow subplot params to override global params
        define_ = define + s.args.get('define', [])
        xlim_ = s.args.get('xlim', xlim)
        ylim_ = s.args.get('ylim', ylim)
        xlog_ = s.args.get('xlog', False) or xlog
        ylog_ = s.args.get('ylog', False) or ylog
        x2_ = s.args.get('x2', False) or x2
        y2_ = s.args.get('y2', False) or y2
        xticks_ = s.args.get('xticks', xticks)
        yticks_ = s.args.get('yticks', yticks)
        xunits_ = s.args.get('xunits', xunits)
        yunits_ = s.args.get('yunits', yunits)
        xticklabels_ = s.args.get('xticklabels', xticklabels)
        yticklabels_ = s.args.get('yticklabels', yticklabels)

        # label/titles are handled a bit differently in subplots
        subtitle = s.args.get('title')
        xsublabel = s.args.get('xlabel')
        ysublabel = s.args.get('ylabel')

        # allow shortened ranges
        if len(xlim_) == 1:
            xlim_ = (0, xlim_[0])
        if len(ylim_) == 1:
            ylim_ = (0, ylim_[0])

        # data can be constrained by subplot-specific defines,
        # so re-extract for each plot
        subdatasets = datasets(results, by, x, y, define_)

        # plot!
        ax = s.ax
        for name, data in subdatasets.items():
            dats = sorted(data.items())
            ax.plot([k for k, _ in dats], [v for _, v in dats],
                dataformats_[name],
                color=datacolors_[name],
                label=','.join(k for k in name if k))

        # axes scaling
        if xlog_:
            ax.set_xscale('symlog')
            ax.xaxis.set_minor_locator(mpl.ticker.NullLocator())
        if ylog_:
            ax.set_yscale('symlog')
            ax.yaxis.set_minor_locator(mpl.ticker.NullLocator())
        # axes limits, defaulting to the data's range (always including 0)
        xs = [k
            for r in subdatasets.values()
            for k, v in r.items()
            if v is not None]
        ys = [v
            for r in subdatasets.values()
            for v in r.values()
            if v is not None]
        ax.set_xlim(
            xlim_[0] if xlim_[0] is not None else min(it.chain([0], xs)),
            xlim_[1] if xlim_[1] is not None else max(it.chain([0], xs)))
        ax.set_ylim(
            ylim_[0] if ylim_[0] is not None else min(it.chain([0], ys)),
            ylim_[1] if ylim_[1] is not None else max(it.chain([0], ys)))
        # axes ticks
        #
        # note formatters run lazily at draw time, so the units must be
        # bound now (as default args), otherwise every subplot would pick
        # up whatever the last subplot's units were
        if x2_:
            ax.xaxis.set_major_formatter(lambda x, pos, units=xunits_:
                si2(x)+(units if units else ''))
            if xticklabels_ is not None:
                ax.xaxis.set_ticklabels(xticklabels_)
            if xticks_ is None:
                ax.xaxis.set_major_locator(AutoMultipleLocator(2))
            elif isinstance(xticks_, list):
                ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks_))
            elif xticks_ != 0:
                ax.xaxis.set_major_locator(AutoMultipleLocator(2, xticks_-1))
            else:
                ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
        else:
            ax.xaxis.set_major_formatter(lambda x, pos, units=xunits_:
                si(x)+(units if units else ''))
            if xticklabels_ is not None:
                ax.xaxis.set_ticklabels(xticklabels_)
            if xticks_ is None:
                ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
            elif isinstance(xticks_, list):
                ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks_))
            elif xticks_ != 0:
                ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(xticks_-1))
            else:
                ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
        if y2_:
            ax.yaxis.set_major_formatter(lambda x, pos, units=yunits_:
                si2(x)+(units if units else ''))
            if yticklabels_ is not None:
                ax.yaxis.set_ticklabels(yticklabels_)
            if yticks_ is None:
                ax.yaxis.set_major_locator(AutoMultipleLocator(2))
            elif isinstance(yticks_, list):
                ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks_))
            elif yticks_ != 0:
                ax.yaxis.set_major_locator(AutoMultipleLocator(2, yticks_-1))
            else:
                ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
        else:
            ax.yaxis.set_major_formatter(lambda x, pos, units=yunits_:
                si(x)+(units if units else ''))
            if yticklabels_ is not None:
                ax.yaxis.set_ticklabels(yticklabels_)
            if yticks_ is None:
                ax.yaxis.set_major_locator(mpl.ticker.AutoLocator())
            elif isinstance(yticks_, list):
                ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks_))
            elif yticks_ != 0:
                ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(yticks_-1))
            else:
                ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
        if ggplot:
            ax.grid(sketch_params=None)

        # axes subplot labels
        if xsublabel is not None:
            ax.set_xlabel(escape(xsublabel))
        if ysublabel is not None:
            ax.set_ylabel(escape(ysublabel))
        if subtitle is not None:
            ax.set_title(escape(subtitle))

    # add a legend? a bit tricky with matplotlib
    #
    # the best solution I've found is a dedicated, invisible axes for the
    # legend, hacky, but it works.
    #
    # note this was written before constrained_layout supported legend
    # collisions, hopefully this is added in the future
    labels = co.OrderedDict()
    for s in grid:
        for h, l in zip(*s.ax.get_legend_handles_labels()):
            labels[l] = h

    if legend_right:
        ax = fig.add_subplot(gs[(1 if legend_above else 0):,-1])
        ax.set_axis_off()
        ax.legend(
            labels.values(),
            labels.keys(),
            loc='upper left',
            fancybox=False,
            borderaxespad=0)

    if legend_above:
        ax = fig.add_subplot(gs[0, :grid.width])
        ax.set_axis_off()

        # try different column counts until we fit in the axes
        for ncol in reversed(range(1, len(labels)+1)):
            legend_ = ax.legend(
                labels.values(),
                labels.keys(),
                loc='upper center',
                ncol=ncol,
                fancybox=False,
                borderaxespad=0)

            if (legend_.get_window_extent().width
                    <= ax.get_window_extent().width):
                break

    if legend_below:
        ax = fig.add_subplot(gs[-1, :grid.width])
        ax.set_axis_off()

        # big hack to get xlabel above the legend! but hey this
        # works really well actually
        if xlabel:
            ax.set_title(escape(xlabel),
                size=plt.rcParams['axes.labelsize'],
                weight=plt.rcParams['axes.labelweight'])

        # try different column counts until we fit in the axes
        for ncol in reversed(range(1, len(labels)+1)):
            legend_ = ax.legend(
                labels.values(),
                labels.keys(),
                loc='upper center',
                ncol=ncol,
                fancybox=False,
                borderaxespad=0)

            if (legend_.get_window_extent().width
                    <= ax.get_window_extent().width):
                break


    # axes labels, NOTE we reposition these below
    if xlabel is not None and not legend_below:
        fig.supxlabel(escape(xlabel))
    if ylabel is not None:
        fig.supylabel(escape(ylabel))
    if title is not None:
        fig.suptitle(escape(title))

    # precompute constrained layout and find midpoints to adjust things
    # that should be centered so they are actually centered
    fig.canvas.draw()
    xmid = (grid[0,0].ax.get_position().x0 + grid[-1,0].ax.get_position().x1)/2
    ymid = (grid[0,0].ax.get_position().y0 + grid[0,-1].ax.get_position().y1)/2

    if xlabel is not None and not legend_below:
        fig.supxlabel(escape(xlabel), x=xmid)
    if ylabel is not None:
        fig.supylabel(escape(ylabel), y=ymid)
    if title is not None:
        fig.suptitle(escape(title), x=xmid)


    # write the figure!
    plt.savefig(output, format='png' if png else 'svg')

    # some stats
    if not quiet:
        print('updated %s, %s datasets, %s points' % (
            output,
            len(datasets_),
            sum(len(d) for d in datasets_.values())))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description="Plot CSV files with matplotlib.",
        allow_abbrev=False)
    parser.add_argument(
        'csv_paths',
        nargs='*',
        help="Input *.csv files.")
    output_rule = parser.add_argument(
        '-o', '--output',
        required=True,
        help="Output *.svg/*.png file.")
    parser.add_argument(
        '--svg',
        action='store_true',
        help="Output an svg file. By default this is infered.")
    parser.add_argument(
        '--png',
        action='store_true',
        help="Output a png file. By default this is infered.")
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
        help="Don't print info.")
    parser.add_argument(
        '-b', '--by',
        action='append',
        type=lambda x: (
            lambda k,v=None: (k, v.split(',') if v is not None else ())
            )(*x.split('=', 1)),
        help="Group by this field. Can rename fields with new_name=old_name.")
    parser.add_argument(
        '-x',
        action='append',
        type=lambda x: (
            lambda k,v=None: (k, v.split(',') if v is not None else ())
            )(*x.split('=', 1)),
        help="Field to use for the x-axis. Can rename fields with "
            "new_name=old_name.")
    parser.add_argument(
        '-y',
        action='append',
        type=lambda x: (
            lambda k,v=None: (k, v.split(',') if v is not None else ())
            )(*x.split('=', 1)),
        help="Field to use for the y-axis. Can rename fields with "
            "new_name=old_name.")
    parser.add_argument(
        '-D', '--define',
        type=lambda x: (lambda k,v: (k, set(v.split(','))))(*x.split('=', 1)),
        action='append',
        help="Only include results where this field is this value. May include "
            "comma-separated options.")
    parser.add_argument(
        '-.', '--points',
        action='store_true',
        help="Only draw data points.")
    parser.add_argument(
        '-!', '--points-and-lines',
        action='store_true',
        help="Draw data points and lines.")
    parser.add_argument(
        '--colors',
        type=lambda x: [x.strip() for x in x.split(',')],
        help="Comma-separated hex colors to use.")
    parser.add_argument(
        '--formats',
        type=lambda x: [x.strip().replace('0',',') for x in x.split(',')],
        help="Comma-separated matplotlib formats to use. Allows '0' as an "
            "alternative for ','.")
    parser.add_argument(
        '-W', '--width',
        type=lambda x: int(x, 0),
        help="Width in pixels. Defaults to %r." % WIDTH)
    parser.add_argument(
        '-H', '--height',
        type=lambda x: int(x, 0),
        help="Height in pixels. Defaults to %r." % HEIGHT)
    parser.add_argument(
        '-X', '--xlim',
        type=lambda x: tuple(
            dat(x) if x.strip() else None
            for x in x.split(',')),
        help="Range for the x-axis.")
    parser.add_argument(
        '-Y', '--ylim',
        type=lambda x: tuple(
            dat(x) if x.strip() else None
            for x in x.split(',')),
        help="Range for the y-axis.")
    parser.add_argument(
        '--xlog',
        action='store_true',
        help="Use a logarithmic x-axis.")
    parser.add_argument(
        '--ylog',
        action='store_true',
        help="Use a logarithmic y-axis.")
    parser.add_argument(
        '--x2',
        action='store_true',
        help="Use base-2 prefixes for the x-axis.")
    parser.add_argument(
        '--y2',
        action='store_true',
        help="Use base-2 prefixes for the y-axis.")
    parser.add_argument(
        '--xticks',
        type=lambda x: int(x, 0) if ',' not in x
            else [dat(x) for x in x.split(',')],
        help="Ticks for the x-axis. This can be explicit comma-separated "
            "ticks, the number of ticks, or 0 to disable.")
    parser.add_argument(
        '--yticks',
        type=lambda x: int(x, 0) if ',' not in x
            else [dat(x) for x in x.split(',')],
        help="Ticks for the y-axis. This can be explicit comma-separated "
            "ticks, the number of ticks, or 0 to disable.")
    parser.add_argument(
        '--xunits',
        help="Units for the x-axis.")
    parser.add_argument(
        '--yunits',
        help="Units for the y-axis.")
    parser.add_argument(
        '--xlabel',
        help="Add a label to the x-axis.")
    parser.add_argument(
        '--ylabel',
        help="Add a label to the y-axis.")
    parser.add_argument(
        '--xticklabels',
        type=lambda x:
            [x.strip() for x in x.split(',')]
            if x.strip() else [],
        help="Comma separated xticklabels.")
    parser.add_argument(
        '--yticklabels',
        type=lambda x:
            [x.strip() for x in x.split(',')]
            if x.strip() else [],
        help="Comma separated yticklabels.")
    parser.add_argument(
        '-t', '--title',
        help="Add a title.")
    parser.add_argument(
        '-l', '--legend-right',
        action='store_true',
        help="Place a legend to the right.")
    parser.add_argument(
        '--legend-above',
        action='store_true',
        help="Place a legend above.")
    parser.add_argument(
        '--legend-below',
        action='store_true',
        help="Place a legend below.")
    parser.add_argument(
        '--dark',
        action='store_true',
        help="Use the dark style.")
    parser.add_argument(
        '--ggplot',
        action='store_true',
        help="Use the ggplot style.")
    parser.add_argument(
        '--xkcd',
        action='store_true',
        help="Use the xkcd style.")
    parser.add_argument(
        '--github',
        action='store_true',
        help="Use the ggplot style with GitHub colors.")
    parser.add_argument(
        '--font',
        type=lambda x: [x.strip() for x in x.split(',')],
        help="Font family for matplotlib.")
    # parse as a number so matplotlib gets a real font size
    parser.add_argument(
        '--font-size',
        type=float,
        help="Font size for matplotlib. Defaults to %r." % FONT_SIZE)
    parser.add_argument(
        '--font-color',
        help="Color for the font and other line elements.")
    parser.add_argument(
        '--foreground',
        help="Foreground color to use.")
    parser.add_argument(
        '--background',
        help="Background color to use.")
    class AppendSubplot(argparse.Action):
        # parse a subplot's arg string with a copy of the main parser,
        # where -o is optional and -W/-H are relative floats
        @staticmethod
        def parse(value):
            import copy
            subparser = copy.deepcopy(parser)
            next(a for a in subparser._actions
                if '--output' in a.option_strings).required = False
            next(a for a in subparser._actions
                if '--width' in a.option_strings).type = float
            next(a for a in subparser._actions
                if '--height' in a.option_strings).type = float
            return subparser.parse_intermixed_args(shlex.split(value or ""))
        def __call__(self, parser, namespace, value, option):
            if not hasattr(namespace, 'subplots'):
                namespace.subplots = []
            namespace.subplots.append((
                option.split('-')[-1],
                self.__class__.parse(value)))
    parser.add_argument(
        '--subplot-above',
        action=AppendSubplot,
        help="Add subplot above with the same dataset. Takes an arg string to "
            "control the subplot which supports most (but not all) of the "
            "parameters listed here. The relative dimensions of the subplot "
            "can be controlled with -W/-H which now take a percentage.")
    parser.add_argument(
        '--subplot-below',
        action=AppendSubplot,
        help="Add subplot below with the same dataset.")
    parser.add_argument(
        '--subplot-left',
        action=AppendSubplot,
        help="Add subplot left with the same dataset.")
    parser.add_argument(
        '--subplot-right',
        action=AppendSubplot,
        help="Add subplot right with the same dataset.")
    parser.add_argument(
        '--subplot',
        type=AppendSubplot.parse,
        help="Add subplot-specific arguments to the main plot.")

    # convert namespaces (including nested subplot namespaces) to dicts,
    # dropping unset (None) options so main's defaults apply
    def dictify(ns):
        if hasattr(ns, 'subplots'):
            ns.subplots = [(dir, dictify(subplot_ns))
                for dir, subplot_ns in ns.subplots]
        if ns.subplot is not None:
            ns.subplot = dictify(ns.subplot)
        return {k: v
            for k, v in vars(ns).items()
            if v is not None}

    sys.exit(main(**dictify(parser.parse_intermixed_args())))