Source code for airflow_plugins.operators.db

import logging
import re

import psycopg2
import psycopg2.extensions
from airflow.hooks.postgres_hook import PostgresHook as PostgresHookBase
from airflow.operators.postgres_operator import \
    PostgresOperator as PostgresOperatorBase
from airflow.utils.decorators import apply_defaults
from airflow_plugins.operators.base import \
    PostgresOperator as PostgresOperatorStatic


class PostgresHook(PostgresHookBase):
    """Tuned PostgreSQL hook which supports running SQL like CREATE DATABASE.

    Supports silent failure: with ``fail_silently=True`` every statement that
    raises is rolled back and logged instead of propagating the error.

    :param database: database name to connect to; when set it takes
        precedence over the schema stored on the Airflow connection.
    :type database: str
    :param fail_silently: log and roll back failing statements instead of
        raising.
    :type fail_silently: bool
    """

    def __init__(self, database=None, fail_silently=False, *args, **kwargs):
        super(PostgresHook, self).__init__(*args, **kwargs)
        self.fail_silently = fail_silently
        self.schema = database

    def get_conn(self):
        """Open a psycopg2 connection for ``self.postgres_conn_id``.

        Host, login, password and port come from the Airflow connection. The
        database is ``self.schema`` if set, otherwise the connection's schema.
        Only the SSL-related keys of the connection extras are forwarded.
        """
        conn = self.get_connection(self.postgres_conn_id)
        conn_args = dict(
            host=conn.host,
            user=conn.login,
            password=conn.password,
            dbname=self.schema or conn.schema,
            port=conn.port)
        # check for ssl parameters in conn.extra
        for arg_name, arg_val in conn.extra_dejson.items():
            if arg_name in ['sslmode', 'sslcert', 'sslkey',
                            'sslrootcert', 'sslcrl']:
                conn_args[arg_name] = arg_val
        psycopg2_conn = psycopg2.connect(**conn_args)
        # NOTE(review): the autocommit flag is only enabled for servers
        # reporting a version below 7.4; kept as-is, confirm intent.
        if psycopg2_conn.server_version < 70400:
            self.supports_autocommit = True
        return psycopg2_conn

    def run(self, sql, autocommit=False, parameters=None):
        """
        Runs a command or a list of commands. Pass a list of sql
        statements to the sql parameter to get them to execute
        sequentially

        Each statement is committed right after it executes. With
        ``fail_silently`` enabled a failing statement is rolled back and
        logged, and execution continues with the next one. The cursor and
        connection are always closed, even when a statement raises.

        :param sql: the sql statement to be executed (str) or a list of
            sql statements to execute
        :type sql: str or list
        :param autocommit: What to set the connection's autocommit setting to
            before executing the query.
        :type autocommit: bool
        :param parameters: The parameters to render the SQL query with.
        :type parameters: mapping or iterable
        """
        if isinstance(sql, str):
            sql = [sql]
        conn = self.get_conn()
        try:
            self.set_autocommit(conn, autocommit)
            cur = conn.cursor()
            try:
                for statement in sql:
                    logging.info(statement)
                    self._execute_statement(conn, cur, statement, parameters)
            finally:
                cur.close()
        finally:
            conn.close()

    def _execute_statement(self, conn, cur, statement, parameters):
        """Execute and commit one statement, honouring ``fail_silently``."""
        # Only pass parameters when given, so literal '%' in plain SQL is
        # not interpreted as a placeholder by the driver.
        if parameters is None:
            args = (statement,)
        else:
            args = (statement, parameters)
        if not self.fail_silently:
            cur.execute(*args)
            conn.commit()
            return
        try:
            cur.execute(*args)
            conn.commit()
        except Exception as e:
            # Roll back so the connection is usable for the next statement.
            conn.rollback()
            logging.exception(e)
class PostgresOperator(PostgresOperatorBase):
    """PostgreSQL operator which uses PostgresHook

    :param database: database to connect to, forwarded to the hook.
    :param fail_silently: forwarded to the hook; when true (the default for
        this operator) failing statements are logged instead of raised.
    """

    @apply_defaults
    def __init__(self, database=None, fail_silently=True, *args, **kwargs):
        super(PostgresOperator, self).__init__(*args, **kwargs)
        self.schema = database
        self.fail_silently = fail_silently

    def pre_execute(self, context):
        """Build the module's ``PostgresHook`` used by :meth:`execute`."""
        hook_kwargs = dict(
            postgres_conn_id=self.postgres_conn_id,
            database=self.schema,
            fail_silently=self.fail_silently,
        )
        self.hook = PostgresHook(**hook_kwargs)

    def execute(self, context):
        """Run ``self.sql`` through the hook created in :meth:`pre_execute`."""
        message = 'Executing: ' + str(self.sql)
        logging.info(message)
        self.hook.run(self.sql, self.autocommit, parameters=self.parameters)
class CreateDatabase(PostgresOperatorStatic):
    """Operator which creates database in PostgreSQL.

    If ``params.company`` is set, the database name is prefixed with the
    lower-cased company name. When ``params.user`` differs from the
    connection login, the user is created and configured before privileges
    are granted. A failing ``CREATE DATABASE`` (e.g. the database already
    exists) is logged and ignored; all following statements must succeed.
    """

    _sql = [
        "CREATE DATABASE {{ params.database_name }};",  # keep create db at top
        "GRANT ALL PRIVILEGES ON DATABASE {{ params.database_name }} "
        "TO {{ params.user }};",  # set user in pre_execute if not in params
    ]

    # Statements inserted right after CREATE DATABASE when the target user
    # is not the connection's login.
    _create_user_sql = [
        "CREATE USER {{ params.user }} "
        "WITH PASSWORD '{{ params.password }}';",
        "ALTER ROLE {{ params.user }} "
        "SET client_encoding TO 'utf8';",
        "ALTER ROLE {{ params.user }} "
        "SET default_transaction_isolation TO 'read committed';",
        "ALTER ROLE {{ params.user }} SET timezone TO 'UTC';",
    ]

    def pre_execute(self, context):
        """Resolve database name and user, build ``self.sql`` and render it."""
        params = context['params']
        company = params.get('company')
        if company is not None:
            db_name = params['database_name']
            self.params['database_name'] = company.lower() + '_' + db_name

        # Work on a per-instance copy: inserting into the class-level
        # ``_sql`` would leak the user statements into every later run and
        # every other instance of this operator.
        sql = list(self._sql)
        hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
        conn = hook.get_connection(self.postgres_conn_id)
        if conn is not None:
            user = params.get('user', conn.login)
            if user == conn.login:
                self.params['user'] = user
            else:
                # Keep CREATE DATABASE first; user setup goes before GRANT.
                sql[1:1] = self._create_user_sql
        self.sql = sql
        context['ti'].render_templates()

    def execute(self, context):
        """Run CREATE DATABASE tolerantly, then the remaining statements."""
        create_sql, other_sql = self.sql[0:1], self.sql[1:]
        logging.info('Executing: ' + str(self.sql))
        # NOTE(review): PostgreSQL refuses CREATE DATABASE inside a
        # transaction block; confirm ``self.autocommit`` is True here.
        self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id,
                                 fail_silently=True)  # fails if db exists
        self.hook.run(create_sql, self.autocommit, parameters=self.parameters)
        self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id,
                                 fail_silently=False)  # should not fail
        if other_sql:
            self.hook.run(other_sql, self.autocommit,
                          parameters=self.parameters)
class DropDatabase(PostgresOperatorStatic):
    """Drop database operator.

    Drops ``<company|lower>_<database_name>``. The hook runs with
    ``fail_silently=True``, so a failing DROP is logged rather than raised.

    NOTE(review): the name is always prefixed with ``params.company|lower``;
    unlike CreateDatabase there is no unprefixed fallback when ``company`` is
    absent — confirm callers always pass it.
    """

    sql = [
        "DROP DATABASE {{ params.company|lower }}_{{ params.database_name }};",
    ]

    def execute(self, context):
        """Execute the rendered DROP statement, swallowing errors."""
        statements = self.sql
        logging.info('Executing: ' + str(statements))
        self.hook = PostgresHook(
            postgres_conn_id=self.postgres_conn_id,
            fail_silently=True,
        )
        self.hook.run(statements, self.autocommit, parameters=self.parameters)
class ChangeDatabaseName(PostgresOperatorStatic):
    """Rename a database by appending the run timestamp to its name.

    ``<company|lower>_<database_name>`` is renamed to
    ``<company|lower>_<database_name>_<ts_nodash[:15]>``. Runs with
    autocommit on and ``fail_silently=False``, so errors propagate.
    """

    sql = [
        "ALTER DATABASE {{ params.company|lower }}_{{ params.database_name }} "
        "RENAME TO {{ params.company|lower }}_{{ params.database_name }}"
        "_{{ ts_nodash[:15] }};",
    ]

    def execute(self, context):
        """Execute the rendered ALTER DATABASE ... RENAME statement."""
        statements = self.sql
        logging.info('Executing: ' + str(statements))
        self.hook = PostgresHook(
            postgres_conn_id=self.postgres_conn_id,
            fail_silently=False,
        )
        self.hook.run(statements, autocommit=True, parameters=self.parameters)
class CreateTableWithColumns(PostgresOperator):
    """Create table with columns.

    Drops ``params.table_name`` if it exists and recreates it with
    ``params.table_columns``. When ``table_columns`` is not given, the
    column names are read from the header of ``params.csv_file`` using
    csvkit's ``csvcut -n`` and every column is typed ``TEXT``. Optional
    ``params.extra`` carries extra csvcut arguments (list or string).
    """

    _sql = [
        "DROP TABLE IF EXISTS {{ params.table_name }};",
        "CREATE TABLE {{ params.table_name }} ({{ params.table_columns }});"
    ]

    # Lowercase names matching this pattern are safe to emit unquoted.
    _plain_identifier = re.compile(r'[^\W\d]\w*\Z')

    @apply_defaults
    def __init__(self, *args, **kwargs):
        super(CreateTableWithColumns, self).__init__(sql=self._sql,
                                                     *args, **kwargs)

    @classmethod
    def _parse_extra_args(cls, args):
        """Normalise extra csvcut arguments into a flat list of tokens.

        A non-list ``args`` is split on whitespace. Options written as
        ``-x=value`` / ``--opt=value`` become two tokens, matching outer
        quotes (single or double, repeatedly) are stripped from every token,
        and tokens that end up empty are dropped.
        """
        parsed = []
        if not args:
            return parsed
        if not isinstance(args, list):
            args = args.strip().split()

        def add_stripped(*values):
            for val in values:
                if not val:
                    continue
                while (len(val) > 1 and val[0] == val[-1]
                       and val[0] in ['"', "'"]):
                    val = val[1:-1]
                if val:
                    parsed.append(val)

        for arg in args:
            if re.match(r'-[-\w]+=', arg):
                # Split only on the first '=' so values may contain '='.
                opt, _, val = arg.partition('=')
                add_stripped(opt, val)
            else:
                add_stripped(arg)
        return parsed

    @classmethod
    def _quote_identifier(cls, name):
        """Return ``name`` unquoted if it is a plain lowercase identifier,
        otherwise double-quoted with embedded double quotes doubled."""
        if name == name.lower() and cls._plain_identifier.match(name):
            return name
        return '"{}"'.format(name.replace('"', '""'))

    @classmethod
    def _get_table_columns(cls, csv_file_path, extra=None):
        """Return SQL-ready column names from the header of a CSV file."""
        from io import StringIO
        from csvkit.utilities.csvcut import CSVCut

        output = StringIO()
        extra_args = cls._parse_extra_args(extra) if extra else []
        args = ['-n', *extra_args, csv_file_path]
        logging.debug('csvcut args: %s', args)
        csvcut = CSVCut(args=args)
        csvcut.output_file = output
        csvcut.run()
        # Each output line is expected to be "<index>: <name>"; split only on
        # the first separator so names containing ": " are kept whole.
        csv_columns = [
            line.split(': ', 1)[1]
            for line in output.getvalue().splitlines()
        ]
        return [cls._quote_identifier(col) for col in csv_columns]

    def pre_execute(self, context):
        """Derive ``table_columns`` from the CSV when missing, then delegate."""
        if context['params'].get('table_columns') is None:
            columns = self._get_table_columns(context['params']['csv_file'],
                                              context['params'].get('extra'))
            self.params['table_columns'] = ', '.join(
                '{} TEXT'.format(col) for col in columns)
            # Params changed after the templates were (presumably) rendered;
            # reset to the raw templates and render again.
            self.sql = self._sql
            context['ti'].render_templates()
        super(CreateTableWithColumns, self).pre_execute(context)