I have been trying to write a Python script that parses some fairly complex SQL and returns a DataFrame of every column, table, and database being queried, along with a section indicator showing where each column appears in the query (for example, the name of the temp table being created). This should cover not only the final selected columns but also columns in subqueries, SELECT statements inside WHERE conditions, join conditions, COALESCE expressions, etc.

I am having trouble mapping the tables and databases for columns in the WHERE conditions. The column names come back, but their table and database values are null. I am also having trouble returning only the actual table names and excluding any aliases.

import sqlglot
from sqlglot import expressions as exp
import pandas as pd

def extract_sql_metadata(sql: str) -> pd.DataFrame:
    parsed_statements = sqlglot.parse(sql)

    columns = []
    seen = set()
    table_registry = {}  # alias -> (table_name, db)
    current_section = None

    def record_column(col_expr: exp.Column):
        nonlocal current_section
        col = col_expr.name
        table_alias = col_expr.table

        resolved_table, db = table_registry.get(table_alias, (table_alias, None))

        key = (col, resolved_table, db, current_section)
        if key not in seen:
            seen.add(key)
            columns.append({
                "column": col,
                "table": resolved_table,
                "database": db,
                "query_section": current_section
            })

    def safe_traverse(val, context):
        if isinstance(val, exp.Expression):
            traverse(val, context)
        elif isinstance(val, list):
            for v in val:
                if isinstance(v, exp.Expression):
                    traverse(v, context)

    def traverse(node, context=None):
        nonlocal current_section

        if not isinstance(node, exp.Expression):
            return

        if isinstance(node, exp.CTE):
            name = node.alias_or_name
            current_section = name
            traverse(node.this, context=name)
            current_section = context

        elif isinstance(node, exp.Subquery):
            alias = node.alias_or_name
            if alias:
                table_registry[alias] = (f"subquery_{alias}", None)
                current_section = alias
            traverse(node.this, context=alias)
            current_section = context

        elif isinstance(node, exp.Table):
            table_name = node.name
            alias = node.alias_or_name or table_name
            db_expr = node.args.get("db")
            db = db_expr.name if isinstance(db_expr, exp.Identifier) else None
            table_registry[alias] = (table_name, db)

        elif isinstance(node, exp.Create):
            table_name = node.this.name
            current_section = table_name
            if node.expression:
                traverse(node.expression, context=table_name)
            current_section = context

        elif isinstance(node, exp.Insert):
            current_section = "final_select"
            traverse(node.expression, context=current_section)
            current_section = context

        elif isinstance(node, exp.Select):
            for proj in node.expressions:
                if isinstance(proj, exp.Alias) and isinstance(proj.this, exp.Column):
                    record_column(proj.this)
                elif isinstance(proj, exp.Column):
                    record_column(proj)

        elif isinstance(node, exp.Column):
            record_column(node)
            return  # avoid recursing into its children again

        # Safely traverse other children
        for key, child in node.args.items():
            # Skip strings or identifiers to avoid str.args error
            if isinstance(child, (exp.Expression, list)):
                safe_traverse(child, context)

    for stmt in parsed_statements:
        traverse(stmt)

    return pd.DataFrame(columns)

Sample of a complex SQL statement to be parsed:

CREATE TABLE TRD AS (
SELECT 
   TR.REQUEST_ID
   ,P17.THIS_WORKING
   ,P17.REQUEST_FIELD_VAL AS "AUTHORIZATION"
   ,P20.REQUEST_FIELD_VAL AS "CONTRACT PD/AOR"
FROM ADW_VIEWS_FSICS.FSICS_IPSS_TRAINING_REQUESTS TR
LEFT JOIN ADW_VIEWS_FSICS.FSICS_IPSS_TRNG_REQUESTS_DET P17 
   ON TR.REQUEST_ID = P17.REQUEST_ID 
   AND P17.REQUEST_FIELD_EXTERNAL_ID = 'IPSS_MD_PROPERTY_17'
LEFT JOIN ADW_VIEWS_FSICS.FSICS_IPSS_TRNG_REQUESTS_DET P20 
   ON TR.REQUEST_ID = P20.REQUEST_ID 
   AND P20.REQUEST_FIELD_EXTERNAL_ID = 'IPSS_MD_PROPERTY_20'
WHERE TR.REQUEST_ID IN (
   SELECT REQUEST_ID
   FROM ADW_VIEWS_FSICS.MY_TNG_REQUESTS
   WHERE EVENT_TYPE = 'BASIC'
   )
); 

Given the above function and example SQL, I would expect the following results:

column                     table                         database         query_section
REQUEST_ID                 FSICS_IPSS_TRAINING_REQUESTS  ADW_VIEWS_FSICS  TRD
THIS_WORKING               FSICS_IPSS_TRNG_REQUESTS_DET  ADW_VIEWS_FSICS  TRD
REQUEST_FIELD_VAL          FSICS_IPSS_TRNG_REQUESTS_DET  ADW_VIEWS_FSICS  TRD
REQUEST_ID                 FSICS_IPSS_TRNG_REQUESTS_DET  ADW_VIEWS_FSICS  TRD
REQUEST_FIELD_EXTERNAL_ID  FSICS_IPSS_TRNG_REQUESTS_DET  ADW_VIEWS_FSICS  TRD
REQUEST_ID                 MY_TNG_REQUESTS               ADW_VIEWS_FSICS  TRD
EVENT_TYPE                 MY_TNG_REQUESTS               ADW_VIEWS_FSICS  TRD

1 Answer

I was way overthinking it. In case anyone else stumbles across this, here is how I did it:

import pandas as pd
from sqlglot import exp, parse
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.scope import traverse_scope

def parse_query(sql_query, dialect='tsql'):

    df_list = []
    for ast in parse(sql_query, read=dialect):
        # Attach a table qualifier to unqualified columns where sqlglot can infer one
        ast = qualify_columns(ast, schema=None)
        # Section label: the statement's target (e.g. the table being created),
        # falling back to the statement type (e.g. SELECT)
        section = str(ast.this).upper() if ast.this else str(ast.key).upper()
        physical_columns = []
        # Each scope (outer query, subquery, CTE) has its own alias -> source mapping
        for scope in traverse_scope(ast):
            for c in scope.columns:
                table = scope.sources.get(c.table)
                if isinstance(table, exp.Table):
                    # Physical table: report its real name and database, not the alias
                    physical_columns.append((section, table.db, table.name, c.name))
                else:
                    # Source is a subquery/CTE or could not be resolved
                    physical_columns.append((section, None, None, c.name))

        df = pd.DataFrame(physical_columns, columns=['section', 'database', 'table', 'columns'])
        df = df.drop_duplicates()
        df_list.append(df)

    return pd.concat(df_list, ignore_index=True)
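
For anyone wondering why resolving columns per scope helps with both problems in the question, here is a minimal pure-Python model of the idea. It is only a sketch and does not call sqlglot: `Scope`, `resolve`, and `collect` are hypothetical names made up for illustration, and the scope contents are transcribed by hand from the sample query rather than parsed. The assumption is that each query level knows its own alias-to-table mapping, so an alias is looked up where it was defined, and an unqualified column in a single-table subquery can be attributed to that one table.

```python
from typing import NamedTuple


class Scope(NamedTuple):
    # alias (or bare table name) -> (database, physical table name)
    sources: dict
    # (qualifier, column name); qualifier is "" when the SQL gave none
    columns: list


def resolve(scope, qualifier, name):
    """Map one column reference to (database, table, column) within its own scope."""
    if qualifier in scope.sources:
        db, table = scope.sources[qualifier]
    elif not qualifier and len(scope.sources) == 1:
        # Unqualified column with exactly one candidate source: attribute it there.
        db, table = next(iter(scope.sources.values()))
    else:
        db, table = None, None  # ambiguous or unknown, leave unresolved
    return (db, table, name)


def collect(section, scopes):
    """Resolve every column in every scope, dropping duplicates but keeping order."""
    rows = {}
    for scope in scopes:
        for qualifier, name in scope.columns:
            rows[(section, *resolve(scope, qualifier, name))] = None
    return list(rows)


# Hand-transcribed from the sample CREATE TABLE TRD statement (not parsed).
DB = "ADW_VIEWS_FSICS"
outer = Scope(
    sources={
        "TR": (DB, "FSICS_IPSS_TRAINING_REQUESTS"),
        "P17": (DB, "FSICS_IPSS_TRNG_REQUESTS_DET"),
        "P20": (DB, "FSICS_IPSS_TRNG_REQUESTS_DET"),
    },
    columns=[
        ("TR", "REQUEST_ID"),
        ("P17", "THIS_WORKING"),
        ("P17", "REQUEST_FIELD_VAL"),
        ("P20", "REQUEST_FIELD_VAL"),
        ("P17", "REQUEST_ID"),
        ("P17", "REQUEST_FIELD_EXTERNAL_ID"),
        ("P20", "REQUEST_ID"),
        ("P20", "REQUEST_FIELD_EXTERNAL_ID"),
    ],
)
inner = Scope(
    sources={"MY_TNG_REQUESTS": (DB, "MY_TNG_REQUESTS")},
    columns=[("", "REQUEST_ID"), ("", "EVENT_TYPE")],
)

rows = collect("TRD", [outer, inner])
for row in rows:
    print(row)
```

In this model the two `P17`/`P20` aliases collapse onto the same physical table once duplicates are dropped, and the subquery's unqualified `REQUEST_ID` and `EVENT_TYPE` pick up `MY_TNG_REQUESTS` because it is the only source in that scope. In the real function, `qualify_columns` and `traverse_scope` are what supply the qualifiers and the per-scope sources.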