applied black

Branch: pull/173/head
Author: Catherine
Date: 2020-05-02 10:58:20 -04:00
Parent: ba21a75616
Commit: 2af7e344fb

10 changed files with 486 additions and 333 deletions

View File

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -16,7 +16,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {

View File

@ -4,4 +4,6 @@ pytest
wheel wheel
twine twine
readme-renderer readme-renderer
black
isort
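The two dev requirements added above (black and isort) account for every hunk below: quote normalization, wrapping of long calls, and re-sorted import blocks. A minimal sketch of driving both tools from Python, assuming `pip install black isort` and that the `black.format_str`/`isort.code` entry points behave as documented (the project would normally just run the command-line tools):

# Minimal sketch: reproduce the kind of rewrite this commit shows, programmatically.
import black
import isort

SRC = (
    "import sqlalchemy\n"
    "import os\n"
    "import re\n"
    "name = ''\n"
)

sorted_src = isort.code(SRC)  # sorts the import block: os, re, then sqlalchemy
formatted = black.format_str(  # normalizes quotes, wraps long lines at 88 columns
    sorted_src, mode=black.FileMode()
)
print(formatted)  # name = '' becomes name = ""

The import hunks in connection.py and parse.py below are this sorting behaviour; the rest is black's quote and wrapping rules.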

View File

@ -4,26 +4,30 @@ makes guesses about the role of each column for plotting purposes
(X values, Y values, and text labels). (X values, Y values, and text labels).
""" """
class Column(list): class Column(list):
'Store a column of tabular data; record its name and whether it is numeric' "Store a column of tabular data; record its name and whether it is numeric"
is_quantity = True is_quantity = True
name = '' name = ""
def __init__(self, *arg, **kwarg): def __init__(self, *arg, **kwarg):
pass pass
def is_quantity(val): def is_quantity(val):
"""Is ``val`` a quantity (int, float, datetime, etc) (not str, bool)? """Is ``val`` a quantity (int, float, datetime, etc) (not str, bool)?
Relies on presence of __sub__. Relies on presence of __sub__.
""" """
return hasattr(val, '__sub__') return hasattr(val, "__sub__")
class ColumnGuesserMixin(object): class ColumnGuesserMixin(object):
""" """
plot: [x, y, y...], y plot: [x, y, y...], y
pie: ... y pie: ... y
""" """
def _build_columns(self): def _build_columns(self):
self.columns = [Column() for col in self.keys] self.columns = [Column() for col in self.keys]
for row in self: for row in self:
@ -32,39 +36,40 @@ class ColumnGuesserMixin(object):
col.append(col_val) col.append(col_val)
if (col_val is not None) and (not is_quantity(col_val)): if (col_val is not None) and (not is_quantity(col_val)):
col.is_quantity = False col.is_quantity = False
for (idx, key_name) in enumerate(self.keys): for (idx, key_name) in enumerate(self.keys):
self.columns[idx].name = key_name self.columns[idx].name = key_name
self.x = Column() self.x = Column()
self.ys = [] self.ys = []
def _get_y(self): def _get_y(self):
for idx in range(len(self.columns)-1,-1,-1): for idx in range(len(self.columns) - 1, -1, -1):
if self.columns[idx].is_quantity: if self.columns[idx].is_quantity:
self.ys.insert(0, self.columns.pop(idx)) self.ys.insert(0, self.columns.pop(idx))
return True return True
def _get_x(self): def _get_x(self):
for idx in range(len(self.columns)): for idx in range(len(self.columns)):
if self.columns[idx].is_quantity: if self.columns[idx].is_quantity:
self.x = self.columns.pop(idx) self.x = self.columns.pop(idx)
return True return True
def _get_xlabel(self, xlabel_sep=" "): def _get_xlabel(self, xlabel_sep=" "):
self.xlabels = [] self.xlabels = []
if self.columns: if self.columns:
for row_idx in range(len(self.columns[0])): for row_idx in range(len(self.columns[0])):
- self.xlabels.append(xlabel_sep.join(
- str(c[row_idx]) for c in self.columns))
+ self.xlabels.append(
+ xlabel_sep.join(str(c[row_idx]) for c in self.columns)
+ )
self.xlabel = ", ".join(c.name for c in self.columns) self.xlabel = ", ".join(c.name for c in self.columns)
def _guess_columns(self): def _guess_columns(self):
self._build_columns() self._build_columns()
self._get_y() self._get_y()
if not self.ys: if not self.ys:
raise AttributeError("No quantitative columns found for chart") raise AttributeError("No quantitative columns found for chart")
def guess_pie_columns(self, xlabel_sep=" "): def guess_pie_columns(self, xlabel_sep=" "):
""" """
Assigns x, y, and x labels from the data set for a pie chart. Assigns x, y, and x labels from the data set for a pie chart.
@ -75,7 +80,7 @@ class ColumnGuesserMixin(object):
""" """
self._guess_columns() self._guess_columns()
self._get_xlabel(xlabel_sep) self._get_xlabel(xlabel_sep)
def guess_plot_columns(self): def guess_plot_columns(self):
""" """
Assigns ``x`` and ``y`` series from the data set for a plot. Assigns ``x`` and ``y`` series from the data set for a plot.
@ -88,4 +93,4 @@ class ColumnGuesserMixin(object):
self._guess_columns() self._guess_columns()
self._get_x() self._get_x()
while self._get_y(): while self._get_y():
pass pass

View File

@ -1,20 +1,22 @@
- import sqlalchemy
import os
import re
+ import sqlalchemy
class ConnectionError(Exception): class ConnectionError(Exception):
pass pass
def rough_dict_get(dct, sought, default=None): def rough_dict_get(dct, sought, default=None):
''' """
Like dct.get(sought), but any key containing sought will do. Like dct.get(sought), but any key containing sought will do.
If there is a `@` in sought, seek each piece separately. If there is a `@` in sought, seek each piece separately.
This lets `me@server` match `me:***@myserver/db` This lets `me@server` match `me:***@myserver/db`
''' """
sought = sought.split('@') sought = sought.split("@")
for (key, val) in dct.items(): for (key, val) in dct.items():
if not any(s.lower() not in key.lower() for s in sought): if not any(s.lower() not in key.lower() for s in sought):
return val return val
@ -29,15 +31,21 @@ class Connection(object):
def tell_format(cls): def tell_format(cls):
return """Connection info needed in SQLAlchemy format, example: return """Connection info needed in SQLAlchemy format, example:
postgresql://username:password@hostname/dbname postgresql://username:password@hostname/dbname
- or an existing connection: %s""" % str(cls.connections.keys())
+ or an existing connection: %s""" % str(
+ cls.connections.keys()
+ )
def __init__(self, connect_str=None, connect_args={}, creator=None): def __init__(self, connect_str=None, connect_args={}, creator=None):
try: try:
if creator: if creator:
- engine = sqlalchemy.create_engine(connect_str, connect_args=connect_args, creator=creator)
+ engine = sqlalchemy.create_engine(
+ connect_str, connect_args=connect_args, creator=creator
+ )
else: else:
- engine = sqlalchemy.create_engine(connect_str, connect_args=connect_args)
- except: # TODO: bare except; but what's an ArgumentError?
+ engine = sqlalchemy.create_engine(
+ connect_str, connect_args=connect_args
+ )
+ except:  # TODO: bare except; but what's an ArgumentError?
print(self.tell_format()) print(self.tell_format())
raise raise
self.dialect = engine.url.get_dialect() self.dialect = engine.url.get_dialect()
@ -65,42 +73,50 @@ class Connection(object):
if displaycon: if displaycon:
print(cls.connection_list()) print(cls.connection_list())
else: else:
if os.getenv('DATABASE_URL'): if os.getenv("DATABASE_URL"):
- cls.current = Connection(os.getenv('DATABASE_URL'), connect_args, creator)
+ cls.current = Connection(
+ os.getenv("DATABASE_URL"), connect_args, creator
+ )
else: else:
- raise ConnectionError('Environment variable $DATABASE_URL not set, and no connect string given.')
+ raise ConnectionError(
+ "Environment variable $DATABASE_URL not set, and no connect string given."
+ )
return cls.current return cls.current
@classmethod @classmethod
def assign_name(cls, engine): def assign_name(cls, engine):
name = '%s@%s' % (engine.url.username or '', engine.url.database) name = "%s@%s" % (engine.url.username or "", engine.url.database)
return name return name
@classmethod @classmethod
def connection_list(cls): def connection_list(cls):
result = [] result = []
for key in sorted(cls.connections): for key in sorted(cls.connections):
- engine_url = cls.connections[key].metadata.bind.url # type: sqlalchemy.engine.url.URL
+ engine_url = cls.connections[
+ key
+ ].metadata.bind.url  # type: sqlalchemy.engine.url.URL
if cls.connections[key] == cls.current: if cls.connections[key] == cls.current:
template = ' * {}' template = " * {}"
else: else:
template = ' {}' template = " {}"
result.append(template.format(engine_url.__repr__())) result.append(template.format(engine_url.__repr__()))
return '\n'.join(result) return "\n".join(result)
def _close(cls, descriptor): def _close(cls, descriptor):
if isinstance(descriptor, Connection): if isinstance(descriptor, Connection):
conn = descriptor conn = descriptor
else: else:
- conn = cls.connections.get(descriptor) or \
- cls.connections.get(descriptor.lower())
+ conn = cls.connections.get(descriptor) or cls.connections.get(
+ descriptor.lower()
+ )
if not conn: if not conn:
raise Exception("Could not close connection because it was not found amongst these: %s" \ raise Exception(
%str(cls.connections.keys())) "Could not close connection because it was not found amongst these: %s"
% str(cls.connections.keys())
)
cls.connections.pop(conn.name) cls.connections.pop(conn.name)
cls.connections.pop(str(conn.metadata.bind.url)) cls.connections.pop(str(conn.metadata.bind.url))
conn.session.close() conn.session.close()
def close(self): def close(self):
self.__class__._close(self) self.__class__._close(self)

View File

@ -1,26 +1,30 @@
import json import json
import re import re
from string import Formatter from string import Formatter
- from IPython.core.magic import Magics, magics_class, cell_magic, line_magic, needs_local_scope
+ from IPython.core.magic import (Magics, cell_magic, line_magic, magics_class,
+ needs_local_scope)
+ from IPython.core.magic_arguments import (argument, magic_arguments,
+ parse_argstring)
from IPython.display import display_javascript
+ from sqlalchemy.exc import OperationalError, ProgrammingError
+ import sql.connection
+ import sql.parse
+ import sql.run
try:
from traitlets.config.configurable import Configurable
from traitlets import Bool, Int, Unicode
except ImportError:
from IPython.config.configurable import Configurable
from IPython.utils.traitlets import Bool, Int, Unicode
- from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
try:
from pandas.core.frame import DataFrame, Series
except ImportError:
DataFrame = None
Series = None
- from sqlalchemy.exc import ProgrammingError, OperationalError
- import sql.connection
- import sql.parse
- import sql.run
@magics_class @magics_class
class SqlMagic(Magics, Configurable): class SqlMagic(Magics, Configurable):
@ -29,20 +33,47 @@ class SqlMagic(Magics, Configurable):
Provides the %%sql magic.""" Provides the %%sql magic."""
displaycon = Bool(True, config=True, help="Show connection string after execute") displaycon = Bool(True, config=True, help="Show connection string after execute")
- autolimit = Int(0, config=True, allow_none=True, help="Automatically limit the size of the returned result sets")
- style = Unicode('DEFAULT', config=True, help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)")
- short_errors = Bool(True, config=True, help="Don't display the full traceback on SQL Programming Error")
- displaylimit = Int(None, config=True, allow_none=True, help="Automatically limit the number of rows displayed (full result set is still stored)")
- autopandas = Bool(False, config=True, help="Return Pandas DataFrames instead of regular result sets")
- column_local_vars = Bool(False, config=True, help="Return data into local variables from column names")
+ autolimit = Int(
+ 0,
+ config=True,
+ allow_none=True,
+ help="Automatically limit the size of the returned result sets",
+ )
+ style = Unicode(
+ "DEFAULT",
+ config=True,
+ help="Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)",
+ )
+ short_errors = Bool(
+ True,
+ config=True,
+ help="Don't display the full traceback on SQL Programming Error",
+ )
+ displaylimit = Int(
+ None,
+ config=True,
+ allow_none=True,
+ help="Automatically limit the number of rows displayed (full result set is still stored)",
+ )
+ autopandas = Bool(
+ False,
+ config=True,
+ help="Return Pandas DataFrames instead of regular result sets",
+ )
+ column_local_vars = Bool(
+ False, config=True, help="Return data into local variables from column names"
+ )
feedback = Bool(True, config=True, help="Print number of rows affected by DML") feedback = Bool(True, config=True, help="Print number of rows affected by DML")
- dsn_filename = Unicode('odbc.ini', config=True, help="Path to DSN file. "
- "When the first argument is of the form [section], "
- "a sqlalchemy connection string is formed from the "
- "matching section in the DSN file.")
+ dsn_filename = Unicode(
+ "odbc.ini",
+ config=True,
+ help="Path to DSN file. "
+ "When the first argument is of the form [section], "
+ "a sqlalchemy connection string is formed from the "
+ "matching section in the DSN file.",
+ )
autocommit = Bool(True, config=True, help="Set autocommit mode") autocommit = Bool(True, config=True, help="Set autocommit mode")
def __init__(self, shell): def __init__(self, shell):
Configurable.__init__(self, config=shell.config) Configurable.__init__(self, config=shell.config)
Magics.__init__(self, shell=shell) Magics.__init__(self, shell=shell)
@ -51,19 +82,42 @@ class SqlMagic(Magics, Configurable):
self.shell.configurables.append(self) self.shell.configurables.append(self)
@needs_local_scope @needs_local_scope
@line_magic('sql') @line_magic("sql")
@cell_magic('sql') @cell_magic("sql")
@magic_arguments() @magic_arguments()
@argument('line', default='', nargs='*', type=str, help='sql') @argument("line", default="", nargs="*", type=str, help="sql")
- @argument('-l', '--connections', action='store_true', help="list active connections")
- @argument('-x', '--close', type=str, help="close a session by name")
- @argument('-c', '--creator', type=str, help="specify creator function for new connection")
- @argument('-s', '--section', type=str, help="section of dsn_file to be used for generating a connection string")
- @argument('-p', '--persist', action='store_true', help="create a table name in the database from the named DataFrame")
- @argument('--append', action='store_true', help="create, or append to, a table name in the database from the named DataFrame")
- @argument('-a', '--connection_arguments', type=str, help="specify dictionary of connection arguments to pass to SQL driver")
- @argument('-f', '--file', type=str, help="Run SQL from file at this path")
- def execute(self, line='', cell='', local_ns={}):
+ @argument(
+ "-l", "--connections", action="store_true", help="list active connections"
+ )
+ @argument("-x", "--close", type=str, help="close a session by name")
+ @argument(
+ "-c", "--creator", type=str, help="specify creator function for new connection"
+ )
+ @argument(
+ "-s",
+ "--section",
+ type=str,
+ help="section of dsn_file to be used for generating a connection string",
+ )
+ @argument(
+ "-p",
+ "--persist",
+ action="store_true",
+ help="create a table name in the database from the named DataFrame",
+ )
+ @argument(
+ "--append",
+ action="store_true",
+ help="create, or append to, a table name in the database from the named DataFrame",
+ )
+ @argument(
+ "-a",
+ "--connection_arguments",
+ type=str,
+ help="specify dictionary of connection arguments to pass to SQL driver",
+ )
+ @argument("-f", "--file", type=str, help="Run SQL from file at this path")
+ def execute(self, line="", cell="", local_ns={}):
"""Runs SQL statement against a database, specified by SQLAlchemy connect string. """Runs SQL statement against a database, specified by SQLAlchemy connect string.
If no database connection has been established, first word If no database connection has been established, first word
@ -89,7 +143,9 @@ class SqlMagic(Magics, Configurable):
""" """
# Parse variables (words wrapped in {}) for %%sql magic (for %sql this is done automatically) # Parse variables (words wrapped in {}) for %%sql magic (for %sql this is done automatically)
- cell_variables = [fn for _, fn, _, _ in Formatter().parse(cell) if fn is not None]
+ cell_variables = [
+ fn for _, fn, _, _ in Formatter().parse(cell) if fn is not None
+ ]
cell_params = {} cell_params = {}
for variable in cell_variables: for variable in cell_variables:
cell_params[variable] = local_ns[variable] cell_params[variable] = local_ns[variable]
@ -105,15 +161,15 @@ class SqlMagic(Magics, Configurable):
user_ns = self.shell.user_ns.copy() user_ns = self.shell.user_ns.copy()
user_ns.update(local_ns) user_ns.update(local_ns)
command_text = ' '.join(args.line) + '\n' + cell command_text = " ".join(args.line) + "\n" + cell
if args.file: if args.file:
with open(args.file, 'r') as infile: with open(args.file, "r") as infile:
command_text = infile.read() + "\n" + command_text command_text = infile.read() + "\n" + command_text
parsed = sql.parse.parse(command_text, self) parsed = sql.parse.parse(command_text, self)
connect_str = parsed['connection'] connect_str = parsed["connection"]
if args.section: if args.section:
connect_str = sql.parse.connection_from_dsn_section(args.section, self) connect_str = sql.parse.connection_from_dsn_section(args.section, self)
@ -137,27 +193,36 @@ class SqlMagic(Magics, Configurable):
args.creator = user_ns[args.creator] args.creator = user_ns[args.creator]
try: try:
- conn = sql.connection.Connection.set(parsed['connection'], displaycon=self.displaycon, connect_args=args.connection_arguments, creator=args.creator)
+ conn = sql.connection.Connection.set(
+ parsed["connection"],
+ displaycon=self.displaycon,
+ connect_args=args.connection_arguments,
+ creator=args.creator,
+ )
except Exception as e: except Exception as e:
print(e) print(e)
print(sql.connection.Connection.tell_format()) print(sql.connection.Connection.tell_format())
return None return None
if args.persist: if args.persist:
return self._persist_dataframe(parsed['sql'], conn, user_ns, append=False) return self._persist_dataframe(parsed["sql"], conn, user_ns, append=False)
if args.append: if args.append:
return self._persist_dataframe(parsed['sql'], conn, user_ns, append=True) return self._persist_dataframe(parsed["sql"], conn, user_ns, append=True)
if not parsed['sql']: if not parsed["sql"]:
return return
try: try:
result = sql.run.run(conn, parsed['sql'], self, user_ns) result = sql.run.run(conn, parsed["sql"], self, user_ns)
- if result is not None and not isinstance(result, str) and self.column_local_vars:
- #Instead of returning values, set variables directly in the
- #users namespace. Variable names given by column names
+ if (
+ result is not None
+ and not isinstance(result, str)
+ and self.column_local_vars
+ ):
+ # Instead of returning values, set variables directly in the
+ # users namespace. Variable names given by column names
if self.autopandas: if self.autopandas:
keys = result.keys() keys = result.keys()
@ -166,21 +231,22 @@ class SqlMagic(Magics, Configurable):
result = result.dict() result = result.dict()
if self.feedback: if self.feedback:
- print('Returning data to local variables [{}]'.format(
- ', '.join(keys)))
+ print(
+ "Returning data to local variables [{}]".format(", ".join(keys))
+ )
self.shell.user_ns.update(result) self.shell.user_ns.update(result)
return None return None
else: else:
if parsed['result_var']: if parsed["result_var"]:
result_var = parsed['result_var'] result_var = parsed["result_var"]
print("Returning data to local variable {}".format(result_var)) print("Returning data to local variable {}".format(result_var))
self.shell.user_ns.update({result_var: result}) self.shell.user_ns.update({result_var: result})
return None return None
#Return results into the default ipython _ variable # Return results into the default ipython _ variable
return result return result
except (ProgrammingError, OperationalError) as e: except (ProgrammingError, OperationalError) as e:
@ -190,28 +256,29 @@ class SqlMagic(Magics, Configurable):
else: else:
raise raise
legal_sql_identifier = re.compile(r'^[A-Za-z0-9#_$]+') legal_sql_identifier = re.compile(r"^[A-Za-z0-9#_$]+")
def _persist_dataframe(self, raw, conn, user_ns, append=False): def _persist_dataframe(self, raw, conn, user_ns, append=False):
"""Implements PERSIST, which writes a DataFrame to the RDBMS""" """Implements PERSIST, which writes a DataFrame to the RDBMS"""
if not DataFrame: if not DataFrame:
raise ImportError("Must `pip install pandas` to use DataFrames") raise ImportError("Must `pip install pandas` to use DataFrames")
frame_name = raw.strip(';') frame_name = raw.strip(";")
# Get the DataFrame from the user namespace # Get the DataFrame from the user namespace
if not frame_name: if not frame_name:
raise SyntaxError('Syntax: %sql PERSIST <name_of_data_frame>') raise SyntaxError("Syntax: %sql PERSIST <name_of_data_frame>")
frame = eval(frame_name, user_ns) frame = eval(frame_name, user_ns)
if not isinstance(frame, DataFrame) and not isinstance(frame, Series): if not isinstance(frame, DataFrame) and not isinstance(frame, Series):
raise TypeError('%s is not a Pandas DataFrame or Series' % frame_name) raise TypeError("%s is not a Pandas DataFrame or Series" % frame_name)
# Make a suitable name for the resulting database table # Make a suitable name for the resulting database table
table_name = frame_name.lower() table_name = frame_name.lower()
table_name = self.legal_sql_identifier.search(table_name).group(0) table_name = self.legal_sql_identifier.search(table_name).group(0)
if_exists = 'append' if append else 'fail' if_exists = "append" if append else "fail"
frame.to_sql(table_name, conn.session.engine, if_exists=if_exists) frame.to_sql(table_name, conn.session.engine, if_exists=if_exists)
return 'Persisted %s' % table_name return "Persisted %s" % table_name
def load_ipython_extension(ip): def load_ipython_extension(ip):

View File

@ -1,9 +1,11 @@
+ import json
+ import re
from os.path import expandvars
import six
from six.moves import configparser as CP
from sqlalchemy.engine.url import URL
- import json
- import re
def connection_from_dsn_section(section, config): def connection_from_dsn_section(section, config):
parser = CP.ConfigParser() parser = CP.ConfigParser()
@ -11,19 +13,19 @@ def connection_from_dsn_section(section, config):
cfg_dict = dict(parser.items(section)) cfg_dict = dict(parser.items(section))
return str(URL(**cfg_dict)) return str(URL(**cfg_dict))
def _connection_string(s): def _connection_string(s):
s = expandvars(s) # for environment variables s = expandvars(s) # for environment variables
if '@' in s or '://' in s: if "@" in s or "://" in s:
return s return s
if s.startswith('[') and s.endswith(']'): if s.startswith("[") and s.endswith("]"):
section = s.lstrip('[').rstrip(']') section = s.lstrip("[").rstrip("]")
parser = CP.ConfigParser() parser = CP.ConfigParser()
parser.read(config.dsn_filename) parser.read(config.dsn_filename)
cfg_dict = dict(parser.items(section)) cfg_dict = dict(parser.items(section))
return str(URL(**cfg_dict)) return str(URL(**cfg_dict))
return '' return ""
def parse(cell, config): def parse(cell, config):
@ -37,18 +39,18 @@ def parse(cell, config):
connection string and `<<` operator in. connection string and `<<` operator in.
""" """
result = {'connection': '', 'sql': '', 'result_var': None} result = {"connection": "", "sql": "", "result_var": None}
pieces = cell.split(None, 3) pieces = cell.split(None, 3)
if not pieces: if not pieces:
return result return result
result['connection'] = _connection_string(pieces[0]) result["connection"] = _connection_string(pieces[0])
if result['connection']: if result["connection"]:
pieces.pop(0) pieces.pop(0)
if len(pieces) > 1 and pieces[1] == '<<': if len(pieces) > 1 and pieces[1] == "<<":
result['result_var'] = pieces.pop(0) result["result_var"] = pieces.pop(0)
pieces.pop(0) # discard << operator pieces.pop(0) # discard << operator
result['sql'] = (' '.join(pieces)).strip() result["sql"] = (" ".join(pieces)).strip()
return result return result

View File

@ -24,9 +24,9 @@ def unduplicate_field_names(field_names):
for k in field_names: for k in field_names:
if k in res: if k in res:
i = 1 i = 1
while k + '_' + str(i) in res: while k + "_" + str(i) in res:
i += 1 i += 1
k += '_' + str(i) k += "_" + str(i)
res.append(k) res.append(k)
return res return res
@ -46,8 +46,7 @@ class UnicodeWriter(object):
def writerow(self, row): def writerow(self, row):
if six.PY2: if six.PY2:
- _row = [s.encode("utf-8") if hasattr(s, "encode") else s
- for s in row]
+ _row = [s.encode("utf-8") if hasattr(s, "encode") else s for s in row]
else: else:
_row = row _row = row
self.writer.writerow(_row) self.writer.writerow(_row)
@ -75,12 +74,12 @@ class CsvResultDescriptor(object):
self.file_path = file_path self.file_path = file_path
def __repr__(self): def __repr__(self):
- return 'CSV results at %s' % os.path.join(
- os.path.abspath('.'), self.file_path)
+ return "CSV results at %s" % os.path.join(os.path.abspath("."), self.file_path)
def _repr_html_(self): def _repr_html_(self):
- return '<a href="%s">CSV results</a>' % os.path.join('.', 'files',
- self.file_path)
+ return '<a href="%s">CSV results</a>' % os.path.join(
+ ".", "files", self.file_path
+ )
def _nonbreaking_spaces(match_obj): def _nonbreaking_spaces(match_obj):
@ -90,11 +89,11 @@ def _nonbreaking_spaces(match_obj):
Call with a ``re`` match object. Retain group 1, replace group 2 Call with a ``re`` match object. Retain group 1, replace group 2
with nonbreaking speaces. with nonbreaking speaces.
""" """
spaces = '&nbsp;' * len(match_obj.group(2)) spaces = "&nbsp;" * len(match_obj.group(2))
return '%s%s' % (match_obj.group(1), spaces) return "%s%s" % (match_obj.group(1), spaces)
_cell_with_spaces_pattern = re.compile(r'(<td>)( {2,})') _cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")
class ResultSet(list, ColumnGuesserMixin): class ResultSet(list, ColumnGuesserMixin):
@ -124,22 +123,23 @@ class ResultSet(list, ColumnGuesserMixin):
self.pretty = None self.pretty = None
def _repr_html_(self): def _repr_html_(self):
_cell_with_spaces_pattern = re.compile(r'(<td>)( {2,})') _cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")
if self.pretty: if self.pretty:
self.pretty.add_rows(self) self.pretty.add_rows(self)
result = self.pretty.get_html_string() result = self.pretty.get_html_string()
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result) result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
- if self.config.displaylimit and len(
- self) > self.config.displaylimit:
- result = '%s\n<span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>' % (
- result, len(self), self.config.displaylimit)
+ if self.config.displaylimit and len(self) > self.config.displaylimit:
+ result = (
+ '%s\n<span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>'
+ % (result, len(self), self.config.displaylimit)
+ )
return result return result
else: else:
return None return None
def __str__(self, *arg, **kwarg): def __str__(self, *arg, **kwarg):
self.pretty.add_rows(self) self.pretty.add_rows(self)
return str(self.pretty or '') return str(self.pretty or "")
def __getitem__(self, key): def __getitem__(self, key):
""" """
@ -170,6 +170,7 @@ class ResultSet(list, ColumnGuesserMixin):
def DataFrame(self): def DataFrame(self):
"Returns a Pandas DataFrame instance built from the result set." "Returns a Pandas DataFrame instance built from the result set."
import pandas as pd import pandas as pd
frame = pd.DataFrame(self, columns=(self and self.keys) or []) frame = pd.DataFrame(self, columns=(self and self.keys) or [])
return frame return frame
@ -196,6 +197,7 @@ class ResultSet(list, ColumnGuesserMixin):
""" """
self.guess_pie_columns(xlabel_sep=key_word_sep) self.guess_pie_columns(xlabel_sep=key_word_sep)
import matplotlib.pylab as plt import matplotlib.pylab as plt
pie = plt.pie(self.ys[0], labels=self.xlabels, **kwargs) pie = plt.pie(self.ys[0], labels=self.xlabels, **kwargs)
plt.title(title or self.ys[0].name) plt.title(title or self.ys[0].name)
return pie return pie
@ -219,11 +221,12 @@ class ResultSet(list, ColumnGuesserMixin):
through to ``matplotlib.pylab.plot``. through to ``matplotlib.pylab.plot``.
""" """
import matplotlib.pylab as plt import matplotlib.pylab as plt
self.guess_plot_columns() self.guess_plot_columns()
self.x = self.x or range(len(self.ys[0])) self.x = self.x or range(len(self.ys[0]))
coords = reduce(operator.add, [(self.x, y) for y in self.ys]) coords = reduce(operator.add, [(self.x, y) for y in self.ys])
plot = plt.plot(*coords, **kwargs) plot = plt.plot(*coords, **kwargs)
if hasattr(self.x, 'name'): if hasattr(self.x, "name"):
plt.xlabel(self.x.name) plt.xlabel(self.x.name)
ylabel = ", ".join(y.name for y in self.ys) ylabel = ", ".join(y.name for y in self.ys)
plt.title(title or ylabel) plt.title(title or ylabel)
@ -251,6 +254,7 @@ class ResultSet(list, ColumnGuesserMixin):
through to ``matplotlib.pylab.bar``. through to ``matplotlib.pylab.bar``.
""" """
import matplotlib.pylab as plt import matplotlib.pylab as plt
self.guess_pie_columns(xlabel_sep=key_word_sep) self.guess_pie_columns(xlabel_sep=key_word_sep)
plot = plt.bar(range(len(self.ys[0])), self.ys[0], **kwargs) plot = plt.bar(range(len(self.ys[0])), self.ys[0], **kwargs)
if self.xlabels: if self.xlabels:
@ -266,11 +270,11 @@ class ResultSet(list, ColumnGuesserMixin):
return None # no results return None # no results
self.pretty.add_rows(self) self.pretty.add_rows(self)
if filename: if filename:
encoding = format_params.get('encoding', 'utf-8') encoding = format_params.get("encoding", "utf-8")
if six.PY2: if six.PY2:
outfile = open(filename, 'wb') outfile = open(filename, "wb")
else: else:
outfile = open(filename, 'w', newline='', encoding=encoding) outfile = open(filename, "w", newline="", encoding=encoding)
else: else:
outfile = six.StringIO() outfile = six.StringIO()
writer = UnicodeWriter(outfile, **format_params) writer = UnicodeWriter(outfile, **format_params)
@ -286,9 +290,9 @@ class ResultSet(list, ColumnGuesserMixin):
def interpret_rowcount(rowcount): def interpret_rowcount(rowcount):
if rowcount < 0: if rowcount < 0:
result = 'Done.' result = "Done."
else: else:
result = '%d rows affected.' % rowcount result = "%d rows affected." % rowcount
return result return result
@ -313,34 +317,33 @@ class FakeResultProxy(object):
def from_list(self, source_list): def from_list(self, source_list):
"Simulates SQLA ResultProxy from a list." "Simulates SQLA ResultProxy from a list."
self.fetchall = lambda: source_list self.fetchall = lambda: source_list
self.rowcount = len(source_list) self.rowcount = len(source_list)
def fetchmany(size): def fetchmany(size):
pos = 0 pos = 0
while pos < len(source_list): while pos < len(source_list):
yield source_list[pos:pos+size] yield source_list[pos : pos + size]
pos += size pos += size
self.fetchmany = fetchmany self.fetchmany = fetchmany
# some dialects have autocommit # some dialects have autocommit
# specific dialects break when commit is used: # specific dialects break when commit is used:
_COMMIT_BLACKLIST_DIALECTS = ('mssql', 'clickhouse', 'teradata', 'athena') _COMMIT_BLACKLIST_DIALECTS = ("mssql", "clickhouse", "teradata", "athena")
def _commit(conn, config): def _commit(conn, config):
"""Issues a commit, if appropriate for current config and dialect""" """Issues a commit, if appropriate for current config and dialect"""
_should_commit = config.autocommit and all( _should_commit = config.autocommit and all(
- dialect not in str(conn.dialect)
- for dialect in _COMMIT_BLACKLIST_DIALECTS)
+ dialect not in str(conn.dialect) for dialect in _COMMIT_BLACKLIST_DIALECTS
+ )
if _should_commit: if _should_commit:
try: try:
conn.session.execute('commit') conn.session.execute("commit")
except sqlalchemy.exc.OperationalError: except sqlalchemy.exc.OperationalError:
pass # not all engines can commit pass # not all engines can commit
@ -349,14 +352,15 @@ def run(conn, sql, config, user_namespace):
if sql.strip(): if sql.strip():
for statement in sqlparse.split(sql): for statement in sqlparse.split(sql):
first_word = sql.strip().split()[0].lower() first_word = sql.strip().split()[0].lower()
if first_word == 'begin': if first_word == "begin":
raise Exception("ipython_sql does not support transactions") raise Exception("ipython_sql does not support transactions")
if first_word.startswith('\\') and 'postgres' in str(conn.dialect): if first_word.startswith("\\") and "postgres" in str(conn.dialect):
if not PGSpecial: if not PGSpecial:
raise ImportError('pgspecial not installed') raise ImportError("pgspecial not installed")
pgspecial = PGSpecial() pgspecial = PGSpecial()
_, cur, headers, _ = pgspecial.execute( _, cur, headers, _ = pgspecial.execute(
- conn.session.connection.cursor(), statement)[0]
+ conn.session.connection.cursor(), statement
+ )[0]
result = FakeResultProxy(cur, headers) result = FakeResultProxy(cur, headers)
else: else:
txt = sqlalchemy.sql.text(statement) txt = sqlalchemy.sql.text(statement)
@ -369,9 +373,9 @@ def run(conn, sql, config, user_namespace):
return resultset.DataFrame() return resultset.DataFrame()
else: else:
return resultset return resultset
#returning only last result, intentionally # returning only last result, intentionally
else: else:
return 'Connected: %s' % conn.name return "Connected: %s" % conn.name
class PrettyTable(prettytable.PrettyTable): class PrettyTable(prettytable.PrettyTable):
@ -391,5 +395,5 @@ class PrettyTable(prettytable.PrettyTable):
self.row_count = len(data) self.row_count = len(data)
else: else:
self.row_count = min(len(data), self.displaylimit) self.row_count = min(len(data), self.displaylimit)
for row in data[:self.displaylimit]: for row in data[: self.displaylimit]:
self.add_row(row) self.add_row(row)

View File

@ -13,10 +13,10 @@ class SqlEnv(object):
self.connectstr = connectstr self.connectstr = connectstr
def query(self, txt): def query(self, txt):
return ip.run_line_magic('sql', "%s %s" % (self.connectstr, txt)) return ip.run_line_magic("sql", "%s %s" % (self.connectstr, txt))
sql_env = SqlEnv('sqlite://') sql_env = SqlEnv("sqlite://")
@pytest.fixture @pytest.fixture
@ -54,14 +54,14 @@ class TestOneNum(Harness):
assert results.ys == [[1.01, 2.01, 3.01]] assert results.ys == [[1.01, 2.01, 3.01]]
assert results.x == [] assert results.x == []
assert results.xlabels == [] assert results.xlabels == []
assert results.xlabel == '' assert results.xlabel == ""
def test_plot(self, tbl): def test_plot(self, tbl):
results = self.run_query() results = self.run_query()
results.guess_plot_columns() results.guess_plot_columns()
assert results.ys == [[1.01, 2.01, 3.01]] assert results.ys == [[1.01, 2.01, 3.01]]
assert results.x == [] assert results.x == []
assert results.x.name == '' assert results.x.name == ""
class TestOneStrOneNum(Harness): class TestOneStrOneNum(Harness):
@ -72,8 +72,8 @@ class TestOneStrOneNum(Harness):
results.guess_pie_columns(xlabel_sep="//") results.guess_pie_columns(xlabel_sep="//")
assert results.ys[0].is_quantity assert results.ys[0].is_quantity
assert results.ys == [[1.01, 2.01, 3.01]] assert results.ys == [[1.01, 2.01, 3.01]]
assert results.xlabels == ['r1-txt1', 'r2-txt1', 'r3-txt1'] assert results.xlabels == ["r1-txt1", "r2-txt1", "r3-txt1"]
assert results.xlabel == 'name' assert results.xlabel == "name"
def test_plot(self, tbl): def test_plot(self, tbl):
results = self.run_query() results = self.run_query()
@ -91,10 +91,11 @@ class TestTwoStrTwoNum(Harness):
assert results.ys[0].is_quantity assert results.ys[0].is_quantity
assert results.ys == [[1.01, 2.01, 3.01]] assert results.ys == [[1.01, 2.01, 3.01]]
assert results.xlabels == [ assert results.xlabels == [
- 'r1-txt2//1.04//r1-txt1', 'r2-txt2//2.04//r2-txt1',
- 'r3-txt2//3.04//r3-txt1'
+ "r1-txt2//1.04//r1-txt1",
+ "r2-txt2//2.04//r2-txt1",
+ "r3-txt2//3.04//r3-txt1",
] ]
assert results.xlabel == 'name2, y3, name' assert results.xlabel == "name2, y3, name"
def test_plot(self, tbl): def test_plot(self, tbl):
results = self.run_query() results = self.run_query()
@ -112,8 +113,9 @@ class TestTwoStrThreeNum(Harness):
assert results.ys[0].is_quantity assert results.ys[0].is_quantity
assert results.ys == [[1.04, 2.04, 3.04]] assert results.ys == [[1.04, 2.04, 3.04]]
assert results.xlabels == [ assert results.xlabels == [
- 'r1-txt1//1.01//r1-txt2//1.02', 'r2-txt1//2.01//r2-txt2//2.02',
- 'r3-txt1//3.01//r3-txt2//3.02'
+ "r1-txt1//1.01//r1-txt2//1.02",
+ "r2-txt1//2.01//r2-txt2//2.02",
+ "r3-txt1//3.01//r3-txt2//3.02",
] ]
def test_plot(self, tbl): def test_plot(self, tbl):

View File

@ -14,7 +14,7 @@ def runsql(ip_session, statements):
statements, statements,
] ]
for statement in statements: for statement in statements:
result = ip_session.run_line_magic('sql', 'sqlite:// %s' % statement) result = ip_session.run_line_magic("sql", "sqlite:// %s" % statement)
return result # returns only last result return result # returns only last result
@ -23,87 +23,108 @@ def ip():
"""Provides an IPython session in which tables have been created""" """Provides an IPython session in which tables have been created"""
ip_session = get_ipython() ip_session = get_ipython()
- runsql(ip_session, [
- "CREATE TABLE test (n INT, name TEXT)",
- "INSERT INTO test VALUES (1, 'foo')",
- "INSERT INTO test VALUES (2, 'bar')",
- "CREATE TABLE author (first_name, last_name, year_of_death)",
- "INSERT INTO author VALUES ('William', 'Shakespeare', 1616)",
- "INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)"
- ])
+ runsql(
+ ip_session,
+ [
+ "CREATE TABLE test (n INT, name TEXT)",
+ "INSERT INTO test VALUES (1, 'foo')",
+ "INSERT INTO test VALUES (2, 'bar')",
+ "CREATE TABLE author (first_name, last_name, year_of_death)",
+ "INSERT INTO author VALUES ('William', 'Shakespeare', 1616)",
+ "INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)",
+ ],
+ )
yield ip_session yield ip_session
runsql(ip_session, 'DROP TABLE test') runsql(ip_session, "DROP TABLE test")
runsql(ip_session, 'DROP TABLE author') runsql(ip_session, "DROP TABLE author")
def test_memory_db(ip): def test_memory_db(ip):
assert runsql(ip, "SELECT * FROM test;")[0][0] == 1 assert runsql(ip, "SELECT * FROM test;")[0][0] == 1
assert runsql(ip, "SELECT * FROM test;")[1]['name'] == 'bar' assert runsql(ip, "SELECT * FROM test;")[1]["name"] == "bar"
def test_html(ip): def test_html(ip):
result = runsql(ip, "SELECT * FROM test;") result = runsql(ip, "SELECT * FROM test;")
assert '<td>foo</td>' in result._repr_html_().lower() assert "<td>foo</td>" in result._repr_html_().lower()
def test_print(ip): def test_print(ip):
result = runsql(ip, "SELECT * FROM test;") result = runsql(ip, "SELECT * FROM test;")
assert re.search(r'1\s+\|\s+foo', str(result)) assert re.search(r"1\s+\|\s+foo", str(result))
def test_plain_style(ip): def test_plain_style(ip):
ip.run_line_magic('config', "SqlMagic.style = 'PLAIN_COLUMNS'") ip.run_line_magic("config", "SqlMagic.style = 'PLAIN_COLUMNS'")
result = runsql(ip, "SELECT * FROM test;") result = runsql(ip, "SELECT * FROM test;")
assert re.search(r'1\s+\|\s+foo', str(result)) assert re.search(r"1\s+\|\s+foo", str(result))
@pytest.mark.skip @pytest.mark.skip
def test_multi_sql(ip): def test_multi_sql(ip):
- result = ip.run_cell_magic('sql', '', """
+ result = ip.run_cell_magic(
+ "sql",
+ "",
+ """
sqlite:// sqlite://
SELECT last_name FROM author; SELECT last_name FROM author;
""") """,
assert 'Shakespeare' in str(result) and 'Brecht' in str(result) )
assert "Shakespeare" in str(result) and "Brecht" in str(result)
def test_result_var(ip): def test_result_var(ip):
- ip.run_cell_magic('sql', '', """
+ ip.run_cell_magic(
+ "sql",
+ "",
+ """
sqlite:// sqlite://
x << x <<
SELECT last_name FROM author; SELECT last_name FROM author;
""") """,
result = ip.user_global_ns['x'] )
assert 'Shakespeare' in str(result) and 'Brecht' in str(result) result = ip.user_global_ns["x"]
assert "Shakespeare" in str(result) and "Brecht" in str(result)
def test_result_var_multiline_shovel(ip): def test_result_var_multiline_shovel(ip):
- ip.run_cell_magic('sql', '', """
+ ip.run_cell_magic(
+ "sql",
+ "",
+ """
sqlite:// x << SELECT last_name sqlite:// x << SELECT last_name
FROM author; FROM author;
""") """,
result = ip.user_global_ns['x'] )
assert 'Shakespeare' in str(result) and 'Brecht' in str(result) result = ip.user_global_ns["x"]
assert "Shakespeare" in str(result) and "Brecht" in str(result)
def test_access_results_by_keys(ip): def test_access_results_by_keys(ip):
- assert runsql(ip,
- "SELECT * FROM author;")['William'] == (u'William',
- u'Shakespeare', 1616)
+ assert runsql(ip, "SELECT * FROM author;")["William"] == (
+ u"William",
+ u"Shakespeare",
+ 1616,
+ )
def test_duplicate_column_names_accepted(ip): def test_duplicate_column_names_accepted(ip):
- result = ip.run_cell_magic('sql', '', """
+ result = ip.run_cell_magic(
+ "sql",
+ "",
+ """
sqlite:// sqlite://
SELECT last_name, last_name FROM author; SELECT last_name, last_name FROM author;
""") """,
assert (u'Brecht', u'Brecht') in result )
assert (u"Brecht", u"Brecht") in result
def test_autolimit(ip): def test_autolimit(ip):
ip.run_line_magic('config', "SqlMagic.autolimit = 0") ip.run_line_magic("config", "SqlMagic.autolimit = 0")
result = runsql(ip, "SELECT * FROM test;") result = runsql(ip, "SELECT * FROM test;")
assert len(result) == 2 assert len(result) == 2
ip.run_line_magic('config', "SqlMagic.autolimit = 1") ip.run_line_magic("config", "SqlMagic.autolimit = 1")
result = runsql(ip, "SELECT * FROM test;") result = runsql(ip, "SELECT * FROM test;")
assert len(result) == 1 assert len(result) == 1
@ -113,8 +134,8 @@ def test_persist(ip):
ip.run_cell("results = %sql SELECT * FROM test;") ip.run_cell("results = %sql SELECT * FROM test;")
ip.run_cell("results_dframe = results.DataFrame()") ip.run_cell("results_dframe = results.DataFrame()")
ip.run_cell("%sql --persist sqlite:// results_dframe") ip.run_cell("%sql --persist sqlite:// results_dframe")
persisted = runsql(ip, 'SELECT * FROM results_dframe') persisted = runsql(ip, "SELECT * FROM results_dframe")
assert 'foo' in str(persisted) assert "foo" in str(persisted)
def test_append(ip): def test_append(ip):
@ -122,9 +143,9 @@ def test_append(ip):
ip.run_cell("results = %sql SELECT * FROM test;") ip.run_cell("results = %sql SELECT * FROM test;")
ip.run_cell("results_dframe = results.DataFrame()") ip.run_cell("results_dframe = results.DataFrame()")
ip.run_cell("%sql --persist sqlite:// results_dframe") ip.run_cell("%sql --persist sqlite:// results_dframe")
persisted = runsql(ip, 'SELECT COUNT(*) FROM results_dframe') persisted = runsql(ip, "SELECT COUNT(*) FROM results_dframe")
ip.run_cell("%sql --append sqlite:// results_dframe") ip.run_cell("%sql --append sqlite:// results_dframe")
appended = runsql(ip, 'SELECT COUNT(*) FROM results_dframe') appended = runsql(ip, "SELECT COUNT(*) FROM results_dframe")
assert appended[0][0] == persisted[0][0] * 2 assert appended[0][0] == persisted[0][0] * 2
@ -149,28 +170,32 @@ def test_persist_bare(ip):
def test_persist_frame_at_its_creation(ip): def test_persist_frame_at_its_creation(ip):
ip.run_cell("results = %sql SELECT * FROM author;") ip.run_cell("results = %sql SELECT * FROM author;")
ip.run_cell("%sql --persist sqlite:// results.DataFrame()") ip.run_cell("%sql --persist sqlite:// results.DataFrame()")
persisted = runsql(ip, 'SELECT * FROM results') persisted = runsql(ip, "SELECT * FROM results")
assert 'Shakespeare' in str(persisted) assert "Shakespeare" in str(persisted)
def test_connection_args_enforce_json(ip): def test_connection_args_enforce_json(ip):
result = ip.run_cell("%sql --connection_arguments {\"badlyformed\":true") result = ip.run_cell('%sql --connection_arguments {"badlyformed":true')
assert result.error_in_exec assert result.error_in_exec
def test_connection_args_in_connection(ip): def test_connection_args_in_connection(ip):
ip.run_cell("%sql --connection_arguments {\"timeout\":10} sqlite:///:memory:") ip.run_cell('%sql --connection_arguments {"timeout":10} sqlite:///:memory:')
result = ip.run_cell("%sql --connections") result = ip.run_cell("%sql --connections")
assert 'timeout' in result.result['sqlite:///:memory:'].connect_args assert "timeout" in result.result["sqlite:///:memory:"].connect_args
def test_connection_args_single_quotes(ip): def test_connection_args_single_quotes(ip):
ip.run_cell("%sql --connection_arguments '{\"timeout\": 10}' sqlite:///:memory:") ip.run_cell("%sql --connection_arguments '{\"timeout\": 10}' sqlite:///:memory:")
result = ip.run_cell("%sql --connections") result = ip.run_cell("%sql --connections")
assert 'timeout' in result.result['sqlite:///:memory:'].connect_args assert "timeout" in result.result["sqlite:///:memory:"].connect_args
def test_connection_args_double_quotes(ip): def test_connection_args_double_quotes(ip):
ip.run_cell('%sql --connection_arguments \"{\\\"timeout\\\": 10}\" sqlite:///:memory:') ip.run_cell('%sql --connection_arguments "{\\"timeout\\": 10}" sqlite:///:memory:')
result = ip.run_cell("%sql --connections") result = ip.run_cell("%sql --connections")
assert 'timeout' in result.result['sqlite:///:memory:'].connect_args assert "timeout" in result.result["sqlite:///:memory:"].connect_args
# TODO: support # TODO: support
# @with_setup(_setup_author, _teardown_author) # @with_setup(_setup_author, _teardown_author)
@ -182,162 +207,168 @@ def test_connection_args_double_quotes(ip):
def test_displaylimit(ip): def test_displaylimit(ip):
ip.run_line_magic('config', "SqlMagic.autolimit = None") ip.run_line_magic("config", "SqlMagic.autolimit = None")
ip.run_line_magic('config', "SqlMagic.displaylimit = None") ip.run_line_magic("config", "SqlMagic.displaylimit = None")
result = runsql( result = runsql(
ip, ip,
"SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;" "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;",
) )
assert 'apple' in result._repr_html_() assert "apple" in result._repr_html_()
assert 'banana' in result._repr_html_() assert "banana" in result._repr_html_()
assert 'cherry' in result._repr_html_() assert "cherry" in result._repr_html_()
ip.run_line_magic('config', "SqlMagic.displaylimit = 1") ip.run_line_magic("config", "SqlMagic.displaylimit = 1")
result = runsql( result = runsql(
ip, ip,
"SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;" "SELECT * FROM (VALUES ('apple'), ('banana'), ('cherry')) AS Result ORDER BY 1;",
) )
assert 'apple' in result._repr_html_() assert "apple" in result._repr_html_()
assert 'cherry' not in result._repr_html_() assert "cherry" not in result._repr_html_()
def test_column_local_vars(ip): def test_column_local_vars(ip):
ip.run_line_magic('config', "SqlMagic.column_local_vars = True") ip.run_line_magic("config", "SqlMagic.column_local_vars = True")
result = runsql(ip, "SELECT * FROM author;") result = runsql(ip, "SELECT * FROM author;")
assert result is None assert result is None
assert 'William' in ip.user_global_ns['first_name'] assert "William" in ip.user_global_ns["first_name"]
assert 'Shakespeare' in ip.user_global_ns['last_name'] assert "Shakespeare" in ip.user_global_ns["last_name"]
assert len(ip.user_global_ns['first_name']) == 2 assert len(ip.user_global_ns["first_name"]) == 2
ip.run_line_magic('config', "SqlMagic.column_local_vars = False") ip.run_line_magic("config", "SqlMagic.column_local_vars = False")
def test_userns_not_changed(ip): def test_userns_not_changed(ip):
ip.run_cell( ip.run_cell(
dedent(""" dedent(
"""
def function(): def function():
local_var = 'local_val' local_var = 'local_val'
%sql sqlite:// INSERT INTO test VALUES (2, 'bar'); %sql sqlite:// INSERT INTO test VALUES (2, 'bar');
function()""")) function()"""
assert 'local_var' not in ip.user_ns )
)
assert "local_var" not in ip.user_ns
def test_bind_vars(ip): def test_bind_vars(ip):
ip.user_global_ns['x'] = 22 ip.user_global_ns["x"] = 22
result = runsql(ip, "SELECT :x") result = runsql(ip, "SELECT :x")
assert result[0][0] == 22 assert result[0][0] == 22
def test_autopandas(ip): def test_autopandas(ip):
ip.run_line_magic('config', "SqlMagic.autopandas = True") ip.run_line_magic("config", "SqlMagic.autopandas = True")
dframe = runsql(ip, "SELECT * FROM test;") dframe = runsql(ip, "SELECT * FROM test;")
assert not dframe.empty assert not dframe.empty
assert dframe.ndim == 2 assert dframe.ndim == 2
assert dframe.name[0] == 'foo' assert dframe.name[0] == "foo"
def test_csv(ip): def test_csv(ip):
ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh ip.run_line_magic("config", "SqlMagic.autopandas = False") # uh-oh
result = runsql(ip, "SELECT * FROM test;") result = runsql(ip, "SELECT * FROM test;")
result = result.csv() result = result.csv()
for row in result.splitlines(): for row in result.splitlines():
assert row.count(',') == 1 assert row.count(",") == 1
assert len(result.splitlines()) == 3 assert len(result.splitlines()) == 3
def test_csv_to_file(ip): def test_csv_to_file(ip):
ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh ip.run_line_magic("config", "SqlMagic.autopandas = False") # uh-oh
result = runsql(ip, "SELECT * FROM test;") result = runsql(ip, "SELECT * FROM test;")
with tempfile.TemporaryDirectory() as tempdir: with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, 'test.csv') fname = os.path.join(tempdir, "test.csv")
output = result.csv(fname) output = result.csv(fname)
assert os.path.exists(output.file_path) assert os.path.exists(output.file_path)
with open(output.file_path) as csvfile: with open(output.file_path) as csvfile:
content = csvfile.read() content = csvfile.read()
for row in content.splitlines(): for row in content.splitlines():
assert row.count(',') == 1 assert row.count(",") == 1
assert len(content.splitlines()) == 3 assert len(content.splitlines()) == 3
def test_sql_from_file(ip): def test_sql_from_file(ip):
ip.run_line_magic('config', "SqlMagic.autopandas = False") ip.run_line_magic("config", "SqlMagic.autopandas = False")
with tempfile.TemporaryDirectory() as tempdir: with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, 'test.sql') fname = os.path.join(tempdir, "test.sql")
with open(fname, 'w') as tempf: with open(fname, "w") as tempf:
tempf.write("SELECT * FROM test;") tempf.write("SELECT * FROM test;")
result = ip.run_cell("%sql --file " + fname) result = ip.run_cell("%sql --file " + fname)
assert result.result == [(1, 'foo'), (2, 'bar')] assert result.result == [(1, "foo"), (2, "bar")]
def test_sql_from_nonexistent_file(ip): def test_sql_from_nonexistent_file(ip):
ip.run_line_magic('config', "SqlMagic.autopandas = False") ip.run_line_magic("config", "SqlMagic.autopandas = False")
with tempfile.TemporaryDirectory() as tempdir: with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, 'nonexistent.sql') fname = os.path.join(tempdir, "nonexistent.sql")
result = ip.run_cell("%sql --file " + fname) result = ip.run_cell("%sql --file " + fname)
assert isinstance(result.error_in_exec, FileNotFoundError) assert isinstance(result.error_in_exec, FileNotFoundError)
def test_dict(ip): def test_dict(ip):
result = runsql(ip, "SELECT * FROM author;") result = runsql(ip, "SELECT * FROM author;")
result = result.dict() result = result.dict()
assert isinstance(result, dict) assert isinstance(result, dict)
assert 'first_name' in result assert "first_name" in result
assert 'last_name' in result assert "last_name" in result
assert 'year_of_death' in result assert "year_of_death" in result
assert len(result['last_name']) == 2 assert len(result["last_name"]) == 2
def test_dicts(ip): def test_dicts(ip):
result = runsql(ip, "SELECT * FROM author;") result = runsql(ip, "SELECT * FROM author;")
for row in result.dicts(): for row in result.dicts():
assert isinstance(row, dict) assert isinstance(row, dict)
assert 'first_name' in row assert "first_name" in row
assert 'last_name' in row assert "last_name" in row
assert 'year_of_death' in row assert "year_of_death" in row
def test_bracket_var_substitution(ip): def test_bracket_var_substitution(ip):
ip.user_global_ns['col'] = 'first_name' ip.user_global_ns["col"] = "first_name"
- assert runsql(ip,
- "SELECT * FROM author"
- " WHERE {col} = 'William' ")[0] == (u'William',
- u'Shakespeare', 1616)
- ip.user_global_ns['col'] = 'last_name'
- result = runsql(ip,
- "SELECT * FROM author"
- " WHERE {col} = 'William' ")
- assert not result
+ assert runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")[0] == (
+ u"William",
+ u"Shakespeare",
+ 1616,
+ )
+ ip.user_global_ns["col"] = "last_name"
+ result = runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")
+ assert not result
def test_multiline_bracket_var_substitution(ip): def test_multiline_bracket_var_substitution(ip):
ip.user_global_ns['col'] = 'first_name' ip.user_global_ns["col"] = "first_name"
- assert runsql(ip,
- "SELECT * FROM author\n"
- " WHERE {col} = 'William' ")[0] == (u'William',
- u'Shakespeare', 1616)
- ip.user_global_ns['col'] = 'last_name'
- result = runsql(ip,
- "SELECT * FROM author"
- " WHERE {col} = 'William' ")
- assert not result
+ assert runsql(ip, "SELECT * FROM author\n" " WHERE {col} = 'William' ")[0] == (
+ u"William",
+ u"Shakespeare",
+ 1616,
+ )
+ ip.user_global_ns["col"] = "last_name"
+ result = runsql(ip, "SELECT * FROM author" " WHERE {col} = 'William' ")
+ assert not result
def test_multiline_bracket_var_substitution(ip): def test_multiline_bracket_var_substitution(ip):
ip.user_global_ns['col'] = 'first_name' ip.user_global_ns["col"] = "first_name"
- result = ip.run_cell_magic('sql', '', """
+ result = ip.run_cell_magic(
+ "sql",
+ "",
+ """
sqlite:// SELECT * FROM author sqlite:// SELECT * FROM author
WHERE {col} = 'William' WHERE {col} = 'William'
""") """,
assert (u'William', u'Shakespeare', 1616) in result )
assert (u"William", u"Shakespeare", 1616) in result
ip.user_global_ns['col'] = 'last_name' ip.user_global_ns["col"] = "last_name"
- result = ip.run_cell_magic('sql', '', """
+ result = ip.run_cell_magic(
+ "sql",
+ "",
+ """
sqlite:// SELECT * FROM author sqlite:// SELECT * FROM author
WHERE {col} = 'William' WHERE {col} = 'William'
""") """,
assert not result )
assert not result

View File

@ -1,79 +1,103 @@
- from pathlib import Path
+ import json
import os
- from sql.parse import parse, connection_from_dsn_section
+ from pathlib import Path
from six.moves import configparser
+ from sql.parse import connection_from_dsn_section, parse
try: try:
from traitlets.config.configurable import Configurable from traitlets.config.configurable import Configurable
except ImportError: except ImportError:
from IPython.config.configurable import Configurable from IPython.config.configurable import Configurable
- import json
empty_config = Configurable() empty_config = Configurable()
default_connect_args = {'options': '-csearch_path=test'} default_connect_args = {"options": "-csearch_path=test"}
def test_parse_no_sql(): def test_parse_no_sql():
assert parse("will:longliveliz@localhost/shakes", empty_config) == \ assert parse("will:longliveliz@localhost/shakes", empty_config) == {
{'connection': "will:longliveliz@localhost/shakes", "connection": "will:longliveliz@localhost/shakes",
'sql': '', "sql": "",
'result_var': None} "result_var": None,
}
def test_parse_with_sql(): def test_parse_with_sql():
assert parse("postgresql://will:longliveliz@localhost/shakes SELECT * FROM work", assert parse(
empty_config) == \ "postgresql://will:longliveliz@localhost/shakes SELECT * FROM work",
{'connection': "postgresql://will:longliveliz@localhost/shakes", empty_config,
'sql': 'SELECT * FROM work', ) == {
'result_var': None} "connection": "postgresql://will:longliveliz@localhost/shakes",
"sql": "SELECT * FROM work",
"result_var": None,
}
def test_parse_sql_only(): def test_parse_sql_only():
assert parse("SELECT * FROM work", empty_config) == \ assert parse("SELECT * FROM work", empty_config) == {
{'connection': "", "connection": "",
'sql': 'SELECT * FROM work', "sql": "SELECT * FROM work",
'result_var': None} "result_var": None,
}
def test_parse_postgresql_socket_connection(): def test_parse_postgresql_socket_connection():
assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == \ assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == {
{'connection': "postgresql:///shakes", "connection": "postgresql:///shakes",
'sql': 'SELECT * FROM work', "sql": "SELECT * FROM work",
'result_var': None} "result_var": None,
}
def test_expand_environment_variables_in_connection(): def test_expand_environment_variables_in_connection():
os.environ['DATABASE_URL'] = 'postgresql:///shakes' os.environ["DATABASE_URL"] = "postgresql:///shakes"
assert parse("$DATABASE_URL SELECT * FROM work", empty_config) == \ assert parse("$DATABASE_URL SELECT * FROM work", empty_config) == {
{'connection': "postgresql:///shakes", "connection": "postgresql:///shakes",
'sql': 'SELECT * FROM work', "sql": "SELECT * FROM work",
'result_var': None} "result_var": None,
}
def test_parse_shovel_operator(): def test_parse_shovel_operator():
assert parse("dest << SELECT * FROM work", empty_config) == \ assert parse("dest << SELECT * FROM work", empty_config) == {
{'connection': "", "connection": "",
'sql': 'SELECT * FROM work', "sql": "SELECT * FROM work",
'result_var': "dest"} "result_var": "dest",
}
def test_parse_connect_plus_shovel(): def test_parse_connect_plus_shovel():
assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == \ assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == {
{'connection': "sqlite://", "connection": "sqlite://",
'sql': 'SELECT * FROM work', "sql": "SELECT * FROM work",
'result_var': None} "result_var": None,
}
def test_parse_shovel_operator(): def test_parse_shovel_operator():
assert parse("dest << SELECT * FROM work", empty_config) == \ assert parse("dest << SELECT * FROM work", empty_config) == {
{'connection': "", "connection": "",
'sql': 'SELECT * FROM work', "sql": "SELECT * FROM work",
'result_var': "dest"} "result_var": "dest",
}
def test_parse_connect_plus_shovel(): def test_parse_connect_plus_shovel():
assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == \ assert parse("sqlite:// dest << SELECT * FROM work", empty_config) == {
{'connection': "sqlite://", "connection": "sqlite://",
'sql': 'SELECT * FROM work', "sql": "SELECT * FROM work",
'result_var': "dest"} "result_var": "dest",
}
class DummyConfig: class DummyConfig:
dsn_filename = Path('src/tests/test_dsn_config.ini') dsn_filename = Path("src/tests/test_dsn_config.ini")
def test_connection_from_dsn_section(): def test_connection_from_dsn_section():
- result = connection_from_dsn_section(section='DB_CONFIG_1',
- config = DummyConfig())
- assert result == 'postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain'
- result = connection_from_dsn_section(section='DB_CONFIG_2',
- config = DummyConfig())
- assert result == 'mysql://thefin:fishputsfishonthetable@127.0.0.1/dolfin'
+ result = connection_from_dsn_section(section="DB_CONFIG_1", config=DummyConfig())
+ assert result == "postgres://goesto11:seentheelephant@my.remote.host:5432/pgmain"
+ result = connection_from_dsn_section(section="DB_CONFIG_2", config=DummyConfig())
+ assert result == "mysql://thefin:fishputsfishonthetable@127.0.0.1/dolfin"