"""
My standard utilities. Intended to be included in all projects
Obviously everything included here needs to be in the standard library (or numpy)
"""
import contextlib
import fractions
import operator
import os
import re
import shutil
import string
import sys
import tempfile
import threading
from cStringIO import StringIO
from uuid import uuid4
import webcolors
def make_none():
    """Take no arguments and always return ``None``."""
    return None
@contextlib.contextmanager
def recursionlimit_atleast(n=1000):
    """Context manager for temporarily raising the interpreter's maximum
    call stack size (misleadingly called the ``recursion limit``).

    Args:
        n (int): minimum recursion limit required inside the context

    Notes:
        This will explicitly reset the recursion limit when we exit the context
        (also when the body raises); any intermediate recursion limit changes
        will be lost.
        This will not lower the limit if ``n`` is less than the current
        recursion limit.
    """
    current_limit = sys.getrecursionlimit()
    if n > current_limit:
        sys.setrecursionlimit(n)
    try:
        yield
    finally:
        # restore even if the managed block raised
        sys.setrecursionlimit(current_limit)
def if_not_none(item, default):
    """Return ``item`` unless it is ``None``, in which case return ``default``.

    Only ``None`` triggers the fallback; other falsy values (0, '', [])
    are returned unchanged.
    """
    return default if item is None else item
def printflush(s, newline=True):
    """Write ``s`` to standard output and flush it immediately.

    Args:
        s: object to print; anything that is not already text is converted
            with ``str``
        newline (bool): terminate the output with a newline (the default).
            If false, nothing is appended, so the next call continues on the
            same line.
    """
    # ``sys.stdout.write`` is used instead of the ``print`` statement: the
    # statement is a syntax error on python 3, and its trailing-comma form
    # makes the *next* print emit a separating space ("working... done"
    # instead of "working...done").
    text = s if isinstance(s, (str, type(u''))) else str(s)
    if newline:
        text += '\n'
    sys.stdout.write(text)
    sys.stdout.flush()
class methodcaller:
    """Pickleable, pure-python implementation of the standard library's
    ``operator.methodcaller``.

    The code follows the pure-python version in CPython:
    https://github.com/python/cpython/blob/065990fa5bd30fb3ca61b90adebc7d8cb3f16b5a/Lib/operator.py
    The c-extension version is not pickleable (see
    https://bugs.python.org/issue22955), so a python copy is kept here.

    Original documentation:
    Return a callable object that calls the given method on its operand.
    After f = methodcaller('name'), the call f(r) returns r.name().
    After g = methodcaller('name', 'date', foo=1), the call g(r) returns
    r.name('date', foo=1).
    """
    __slots__ = ('_name', '_args', '_kwargs')

    def __init__(*args, **kwargs):
        # ``self`` is taken from ``args`` instead of being a named parameter,
        # so "self" stays available as a keyword argument for the target method
        if len(args) < 2:
            raise TypeError("methodcaller needs at least one argument, the method name")
        self, name = args[0], args[1]
        self._name = name
        if not isinstance(name, str):
            raise TypeError('method name must be a string')
        self._args = args[2:]
        self._kwargs = kwargs

    def __call__(self, obj):
        method = getattr(obj, self._name)
        return method(*self._args, **self._kwargs)

    def __repr__(self):
        pieces = [repr(self._name)]
        pieces += [repr(arg) for arg in self._args]
        pieces += ['%s=%r' % (key, value) for key, value in self._kwargs.items()]
        cls = self.__class__
        return '%s.%s(%s)' % (cls.__module__, cls.__name__, ', '.join(pieces))

    def __reduce__(self):
        if self._kwargs:
            # keyword arguments cannot go in the positional-args tuple that
            # pickle passes on, so they are bound into the constructor instead
            from functools import partial
            return partial(self.__class__, self._name, **self._kwargs), self._args
        return self.__class__, (self._name,) + self._args
class textnotify(object):
    """ Print a single, immediately flushed line to log the execution of a block.

    The start message is printed (without a newline) when the block is entered;
    'done' is appended when it finishes, or 'ERROR' if it raised. The exception
    itself is not suppressed.

    Examples:
        >>> import time
        >>> with textnotify('starting to sleep'):
        >>>     time.sleep(3)
        starting to sleep...done
        >>> with textnotify('raising an exception...'):
        >>>     raise ValueError()
        raising an exception...ERROR
        ValueError [...]
    """
    def __init__(self, startmsg):
        trimmed = startmsg.strip()
        # messages that already end in an ellipsis are kept exactly as given
        if not trimmed.endswith('...'):
            startmsg = trimmed + '...'
        self.startmsg = startmsg

    def __enter__(self):
        printflush(self.startmsg, newline=False)

    def __exit__(self, exc_type, exc_val, exc_tb):
        outcome = 'done' if exc_type is None else 'ERROR'
        printflush(outcome)
class progressbar(object):
    """ Create a progress bar for a calculation

    The context manager provides a callback which needs to be called as
    set_progress(percent), where percent is a number between 0 and 100.
    When the block exits, the bar is filled and coloured 'success', or
    'danger' if the block raised.

    If the widget cannot be constructed (``traitlets.TraitError``), no bar is
    shown and every operation silently does nothing.

    Examples:
        >>> import time
        >>> with progressbar('count to 100') as set_progress:
        >>>     for i in range(100):
        >>>         time.sleep(0.5)
        >>>         set_progress(i+1)
    """
    def __init__(self, description):
        # imported lazily so the module can be used without ipywidgets
        import ipywidgets as ipy
        import traitlets
        try:
            self.progress_bar = ipy.FloatProgress(0, min=0, max=100, description=description)
        except traitlets.TraitError:
            self.progress_bar = None

    def __enter__(self):
        from IPython.display import display
        if self.progress_bar is not None:
            display(self.progress_bar)
        return self.set_progress

    def set_progress(self, percent):
        """Move the bar to ``percent`` (0-100); a no-op without a widget."""
        if self.progress_bar is not None:
            self.progress_bar.value = percent

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.progress_bar is None:
            return
        # fill the *widget* (this used to set an unrelated ``self.value``)
        self.progress_bar.value = 100.0
        if exc_type is not None:
            self.progress_bar.bar_style = 'danger'
        else:
            self.progress_bar.bar_style = 'success'
class PipedFile(object):
    """
    Allows us to pass data by filesystem path without ever writing it to disk

    A named pipe (FIFO) is created in a fresh temporary directory. To prevent
    deadlock, we spawn a thread to write to the pipe; the consumer opens the
    returned path for reading. The temporary directory is removed on exit.

    Args:
        fileobj: the data, either as a string or as an object with ``read()``
        filename (str): name of the pipe inside the temporary directory;
            must not contain a directory component

    Call it as a context manager:
    >>> with PipedFile('file contents',filename='contents.txt') as pipepath:
    >>>     data = open(pipepath,'r').read()
    """
    def __init__(self, fileobj, filename='pipe'):
        # ``type(u'')`` is ``unicode`` on python 2 and ``str`` on python 3
        if isinstance(fileobj, (str, type(u''))):
            self.fileobj = StringIO(fileobj)
        else:
            self.fileobj = fileobj
        self.tempdir = None
        assert '/' not in filename, "Filename must not include directory"
        self.filename = filename

    def __enter__(self):
        self.tempdir = tempfile.mkdtemp()
        self.pipe_path = os.path.join(self.tempdir, self.filename)
        try:
            os.mkfifo(self.pipe_path)
        except Exception:
            # __exit__ is not called when __enter__ fails, so clean up here
            shutil.rmtree(self.tempdir)
            self.tempdir = None
            raise
        self.pipe_thread = threading.Thread(target=self._write_to_pipe)
        # The writer blocks in open() until somebody opens the pipe for
        # reading. As a daemon it cannot keep the interpreter from exiting
        # when the pipe is never read.
        self.pipe_thread.daemon = True
        self.pipe_thread.start()
        return self.pipe_path

    def _write_to_pipe(self):
        with open(self.pipe_path, 'w') as pipe:
            pipe.write(self.fileobj.read())

    def __exit__(self, type, value, traceback):
        if self.tempdir is not None:
            shutil.rmtree(self.tempdir)
            self.tempdir = None
def remove_directories(list_of_paths):
    """
    Removes non-leafs from a list of directory paths

    A path is dropped if it is a proper '/'-separated prefix of any path in
    the list (i.e. it is a parent directory), if it ends with '/', or if it
    is blank. Surrounding whitespace is ignored for these comparisons, but
    the surviving entries are returned exactly as they were passed in, in
    their original order.

    Args:
        list_of_paths: iterable of path strings

    Returns:
        list: the entries that are leaves
    """
    originals = list(list_of_paths)
    stripped = [path.strip() for path in originals]

    found_dirs = set(['/'])
    for path in stripped:
        parts = path.split('/')
        # every proper prefix is a directory; starting at 1 (not 2) also
        # catches the first component of relative paths ('a' for 'a/b')
        for i in range(1, len(parts)):
            found_dirs.add('/'.join(parts[:i]))

    return [path for path, clean in zip(originals, stripped)
            if clean and clean not in found_dirs and not clean.endswith('/')]
def make_local_temp_dir():
    """Create a new directory under ``/tmp`` named after a random UUID.

    Returns:
        str: path of the created directory (nothing here removes it again)
    """
    path = '/tmp/%s' % uuid4()
    os.mkdir(path)
    return path
class BaseTable(object):
    """Base class for tables that collect rows under a fixed list of categories.

    Subclasses implement ``writeline`` (called once for every added row) and
    ``getstring``.

    Args:
        categories: the column names, in order
        fileobj: stored as ``self.fileobj`` for use by subclasses
    """
    def __init__(self, categories, fileobj=None):
        self.categories = categories
        self.lines = []
        self.fileobj = fileobj

    def add_line(self, obj):
        """Store a row and pass it to ``writeline``.

        Args:
            obj: either a mapping (anything with ``keys``), from which one value
                per category is taken, with '' for missing ones, or a sequence
                with exactly one entry per category
        """
        if not hasattr(obj, 'keys'):
            assert len(obj) == len(self.categories)
            row = obj
        else:
            row = [obj.get(name, '') for name in self.categories]
        self.lines.append(row)
        self.writeline(row)

    def writeline(self, newline):
        """Handle a newly added row; must be provided by subclasses."""
        raise NotImplementedError()

    def getstring(self):
        """Return the table as a string; must be provided by subclasses."""
        raise NotImplementedError()
class PrintTable(BaseTable):
    """Table that writes every row through a ``str.format`` template as soon
    as it is added.

    The categories are the *named* replacement fields of the template, in
    order of first appearance (``'{name:10s} {energy:.3f}'`` gives the
    categories ``['name', 'energy']``). A header line, with the category
    names laid out by the same template, is written before the first row.

    Args:
        formatstr (str): format template using named fields
        fileobj: object with a ``write`` method that receives the output;
            if ``None`` nothing is written (rows are still stored and
            available through ``getstring``)
    """
    def __init__(self, formatstr, fileobj=sys.stdout):
        self.format = formatstr
        self.formatstr = formatstr  # alias; both names were used historically
        self._wrote_header = False
        categories = []
        # Formatter.parse yields (literal_text, field_name, format_spec,
        # conversion) tuples; field_name is None for trailing literal text
        for _literal, field, _spec, _conversion in string.Formatter().parse(formatstr):
            if field is None:
                continue
            key = self._category(field)
            if key not in categories:
                categories.append(key)
        super(PrintTable, self).__init__(categories, fileobj=fileobj)

    @staticmethod
    def _category(field):
        """Strip attribute / index access from a field name ('a.b[0]' -> 'a')."""
        return re.split(r'[.\[]', field, 1)[0]

    def _header(self):
        """Lay the category names out with the row template."""
        pieces = []
        for literal, field, spec, _conversion in string.Formatter().parse(self.format):
            pieces.append(literal)
            if field is None:
                continue
            name = self._category(field)
            try:
                pieces.append(format(name, spec))
            except ValueError:
                # e.g. a numeric spec such as '.3f' cannot format a name
                pieces.append(name)
        return ''.join(pieces)

    def _format_row(self, line):
        """Format one row, given as a mapping or as a sequence in category order."""
        if not hasattr(line, 'keys'):
            line = dict(zip(self.categories, line))
        return self.format.format(**line)

    def writeline(self, line):
        if self.fileobj is None:
            return
        if not self._wrote_header:
            self.fileobj.write(self._header() + '\n')
            self._wrote_header = True
        self.fileobj.write(self._format_row(line) + '\n')

    def getstring(self):
        """Return the header and all stored rows, one per line."""
        rows = [self._header()]
        rows.extend(self._format_row(line) for line in self.lines)
        return '\n'.join(rows) + '\n'
class MarkdownTable(BaseTable):
    """Table that is rendered as a markdown (pipe) table.

    Args:
        *categories (str): the column headings
    """
    def __init__(self, *categories):
        super(MarkdownTable, self).__init__(categories)

    def markdown(self, replace=None):
        """Render the stored rows as markdown.

        Args:
            replace (dict): optional mapping of cell value -> replacement;
                cells equal to a key are shown as the corresponding value

        Returns:
            str: header row, delimiter row and one line per stored row
        """
        if replace is None:
            replace = {}
        # the delimiter row needs exactly one cell per column, otherwise
        # renderers do not recognise the block as a table
        outlines = ['| ' + ' | '.join(self.categories) + ' |',
                    '|' + ' --- |' * len(self.categories)]
        for line in self.lines:
            cells = [str(replace.get(val, val)) for val in line]
            outlines.append('| ' + ' | '.join(cells) + ' |')
        return '\n'.join(outlines)

    def writeline(self, newline):
        """Nothing is written incrementally; see ``markdown``."""
        pass

    def getstring(self):
        return self.markdown()
def binomial_coefficient(n, k):
    """Return the binomial coefficient "n choose k" as an exact integer.

    Args:
        n (int): size of the set
        k (int): number of elements chosen

    Returns:
        int: ``n! / (k! (n-k)!)``; 0 if ``k`` is negative or larger than a
            non-negative ``n``

    Raises:
        TypeError: if ``n`` or ``k`` is not an integer
    """
    n = operator.index(n)
    k = operator.index(k)
    if k < 0:
        return 0
    result = 1
    for i in range(k):
        # after this step ``result`` equals C(n, i+1), so the division is exact
        result = result * (n - i) // (i + 1)
    return result
def pairwise_displacements(a):
    """
    Displacement vectors between all pairs of rows of ``a``

    The result has one row per unordered pair, grouped by index offset:
    first ``a[i] - a[i+1]`` for all i, then ``a[i] - a[i+2]``, and so on.

    :type a: numpy.array
    from http://stackoverflow.com/questions/22390418/pairwise-displacement-vectors-among-set-of-points
    """
    import numpy as np

    num_points, dim = a.shape[0], a.shape[1]
    num_pairs = num_points * (num_points - 1) // 2  # "num_points choose 2"
    out = np.zeros((num_pairs, dim))

    start = 0
    for offset in range(1, num_points):  # no point1 - point1!
        count = num_points - offset
        out[start:start + count] = a[:count] - a[offset:]
        start += count
    return out
def is_printable(s):
    """Return True if every character of ``s`` is in ``string.printable``
    (which makes an empty ``s`` printable), False otherwise."""
    import string
    return all(c in string.printable for c in s)
class _RedirectStream(object):
    """Base class for context managers that temporarily replace one of the
    ``sys`` streams. From python3.4 stdlib

    Subclasses set ``_stream`` to the name of the ``sys`` attribute to replace.
    """
    _stream = None

    def __init__(self, new_target):
        # We use a list of old targets to make this CM re-entrant
        self._old_targets = []
        self._new_target = new_target

    def __enter__(self):
        previous = getattr(sys, self._stream)
        self._old_targets.append(previous)
        setattr(sys, self._stream, self._new_target)
        return self._new_target

    def __exit__(self, exctype, excinst, exctb):
        restored = self._old_targets.pop()
        setattr(sys, self._stream, restored)


class redirect_stdout(_RedirectStream):
    """Context manager that temporarily points ``sys.stdout`` at another
    target. From python3.4 stdlib"""
    _stream = "stdout"


class redirect_stderr(_RedirectStream):
    """Context manager that temporarily points ``sys.stderr`` at another
    target. From python3.4 stdlib"""
    _stream = "stderr"
# Matches numbers, e.g. 1, -2.0, 3.5e50, 0.001e-10. Both the fractional part
# and the exponent are optional (with a mandatory exponent group, plain
# numbers such as "1" or "-2.0" would never match).
GETFLOAT = re.compile(r'-?\d+(\.\d+)?(e[-+]?\d+)?')
def is_color(s):
    """ Do our best to determine if "s" is a color spec that can be converted to hex

    Accepted forms are an integer between 0 and 0xFFFFFF, a CSS3 color name
    known to ``webcolors``, or a hexadecimal string written as '#...',
    '0x...' or exactly six hex digits.

    :param s: candidate color spec
    :return: True if ``s`` is recognized as a color, False otherwise
    :rtype: bool
    """
    def in_range(i):
        return 0 <= i <= 0xFFFFFF

    if isinstance(s, bool):  # bool is a subclass of int, but not a color
        return False
    if isinstance(s, int):
        return in_range(s)
    # ``type(u'')`` is ``unicode`` on python 2 and ``str`` on python 3
    if not isinstance(s, (str, type(u''))) or not s:
        return False
    if s in webcolors.css3_names_to_hex:
        return True
    try:
        if s[0] == '#':
            return in_range(int('0x' + s[1:], 0))
        if s[0:2] == '0x':
            return in_range(int(s, 0))
        if len(s) == 6:
            return in_range(int('0x' + s, 0))
    except ValueError:
        # not a valid hexadecimal number
        pass
    return False
def from_filepath(func, filelike):
    """Run func on a temporary *path* assigned to filelike

    Args:
        func (callable): called with a single argument, a filesystem path
        filelike: either a path (string), which is passed to ``func``
            unchanged, or an object with a ``read()`` method, whose content
            is copied to a temporary file for the duration of the call

    Returns:
        whatever ``func`` returns
    """
    # ``type(u'')`` is ``unicode`` on python 2 and ``str`` on python 3
    if isinstance(filelike, (str, type(u''))):
        return func(filelike)

    data = filelike.read()
    if not isinstance(data, bytes):
        # the temporary file is opened in binary mode
        data = data.encode('utf-8')
    with tempfile.NamedTemporaryFile() as outfile:
        outfile.write(data)
        outfile.flush()  # make sure func sees the complete content
        return func(outfile.name)