Eat Study Love

먹고 공부하고 사랑하라

Data Science/Research

SQL2NL 모델 추가 실험(VectorDB, Embedding)[1]

eatplaylove 2025. 8. 4. 17:56

https://eglife.tistory.com/361

 

SQL2NL 모델 실험진행(2)

https://eglife.tistory.com/360 SQL2NL 모델 실험진행https://eglife.tistory.com/359 SQL2NL 모델 개선방안에 대한 검토 - Prompt & RAGhttps://eglife.tistory.com/358 연구주제 고찰(2)MCP를 많이 파봤지만, 일단 연구주제는 SQL2NL

eglife.tistory.com

아무래도 DB를 중점적으로 다루는 교수님을 만났다 보니까 모든 연구에 DB가 묻어있기를 바라시는 거 같다.

아래, 내 Flow도 SQL2NL은 NL2SQL 연구의 '검증' 용으로 쓰일 예정이며 이 때문에 LLM의 자연어반환은 거의 99%에 가까운 정확도를 보여야 한다는 것이다.

1. SQL2NL Schema Linking 유효성 확인 KCC 
2. SQL2NL Few-shot 유효성 확인 中  
 - 어떤 기준으로 Few-shot을 선택할 지에 대해서 Novelty 

------ Future Work --------
3. SQL2NL VectorDB + RAG 적합성 확인
 - SQL<->NL Pair Data를 어떤 기준으로 Vector Embedding 할 까? : Novelty
4. SQL2NL Model Fine-tuning 기법 검토
Finally : SQL2NL Model 만들고 배포

 

일단 내 연구 Flow는 위와 같다. 보통 NL2SQL이 BIRD, SPIDER 등 특정 Benchmark Domain 內에서 실행될 터이니, 죽이되는 밥이되는 연구법을 갈고 닦아서 특정 Domain 안에서는 NL변환 정확도를 최적화하라는 교수님의 Order이다.

 

일단, 나는 나의 갈 길을 가야하니 SQL<->NL Pair Data를 VectorDB에 어떤 기준으로 Embedding 할 지를 생각해봐야겠다.

 

방법은 크게 2가지이다.

 

1. 총 Data가 많을 때( ex) 1K 개 있다고 가정 ) 그 중 대표가 되는 일부 200ea를 추출하여 얘네 기준으로 Embedding을 한다. 추후 Test Query가 입력되면 얘도 200차원으로 임베딩해서 총 Data중 유사한 녀석 Top-K를 추출한다.

+++ ) VectorDB Embedding Vector의 차원을 200차원(N차원)으로 관리할 수 있다.

- - - ) 어떤 놈들을 어떤 기준으로 대표로 뽑을 지 고민해야 한다.

 

2. 총 Data를 특정 Feature N개에 맞춰 Embedding을 진행한다. 그 후 Process는 1번과 동일

+++ ) VectorDB Embedding Vector의 차원을 N차원으로 관리할 수 있다. 솔직히, SQL은 Structure가 어느 정도 정해져있기에 SQLGlot의 Parsing 기준으로만 Feature를 나눠도 유요할 것이라고 예상한다.


SQLGlot을 이용한다면, 사실 1번보단 2번이 더 나을 거 같다.(물론 보다보면 1,2번이 비슷한 말을 하고 있는 거 같음)

그리고, Python<->SQL 세상에선 이 오픈소스 라이브러리가 꽤나 많이 활용되는 거 같다.

 

SQLGlot의 AST(Abstract Syntax Tree) Parsing 기준을 보고 그것을 Feature로 만들어서 다량의 SQL<->Text Data를 VectorDB에 Embedding해보자. 만약 기준을 10개 잡았다면 위 10개 기준으로 모든 SQL<->NL Pair Data를 Embedding하는 Embedding Function을 만들어야 할 것이고, Test용 Input Data도 임베딩을 한다.

 

이미 SQLGlot기반 Top - K fewshot은 꽤나 유효하다는 것을 K-DS 연구로 검토완료한 상황이다. 추가로 확인할 것은, 대량의 SQL-NL Bigdata를 다루기 위한 이 VectorDB & RAG Framework가 SQL2NL에도 쓰일 수 있냐는 것이고, 그 방법론으로 내가 제시한 Structure Aware(가칭) 기반 Embedding이 유효한 영향을 끼치냐이다.

 

다음주엔 얼른 Embedding Function을 만들고 Test까지 진행해봐야겠다.

 

사실, Test는 금방할 것으로 예상되지만,, 이 Embedding Feature를 만드는 것이 시간이 좀 걸릴 거 같다.

그리고, 뭐든 Test할때 Schema Linking은 필수적으로 해줘야 할 거 같다. 이 정도 Basic도 없으면 정확도가 좀 떨어진다.


SQLGlot_parser.py
0.31MB
diff(SQLGlot_diff_function).py
0.02MB

1) SQL Vector Embedding Feature를 설정하기 위해 SQLGlot의 parsing 구조를 파악한다.

 

위 코드를 자체적으로 분석해보려고 했는데 코드가 너무 길다.. 그래서 GPT를 통해 같이 분석하기로 한다.

 

연구목표는 SQLGlot의 SQL문 Parsing 기준을 Feature로 설정해서 SQL Data를 Embedding 하는 기준으로 삼으려는 것이다.

 

이렇게 해서 Embedding이 비슷한 위치에 위치한 친구들은, SQLGlot의 diff function을 이용했을 때 Similarity가 높게 측정되었으면 한다.

# ───────────────────────────────
# STEP 1. AST diff 기반 유사도 계산
# ───────────────────────────────
def diff_similarity(sql1: str, sql2: str) -> float:
    try:
        tree1 = parse_one(sanitize_sql(sql1))
        tree2 = parse_one(sanitize_sql(sql2))
    except Exception as e:
        print(f"❌ Parse failed:\n{sql1}\n→ {e}")
        return 0.0
    edits = diff(tree1, tree2)
    total = len(edits)
    keep = sum(isinstance(e, Keep) for e in edits)
    return keep / total if total > 0 else 1.0

 

근데 이 방법에서, 내가 생각했던 단점을 GPT도 정확히 짚어냈다.

 

허점이 너무나도 많이 보이지만, 어찌 어찌 Feature를 뽑아내서 Embedding을 진행시켜보자.

 

- SQLGlot Parse.py 기준으로 Feature를 뽑는다면,

SQL을 Parse.py 코드를 통과시켜 아래 항목들에 대해서 check 후 임베딩하는 것.

Feature 이름 뽑은 이유 / 설명
distinct_keyword DISTINCT_TOKENS에 정의. SELECT 행 집합을 유일화하는 의미-변별적 키워드.
select_star PRIMARY_PARSERS에서 TokenType.STAR 별도 처리. 컬럼 전체 선택 여부를 캡처.
where_clause QUERY_MODIFIER_PARSERS에 WHERE 토큰 매핑. 행 필터 존재 여부.
group_by_clause GROUP_BY 토큰 매핑. 집계 그룹 지정이 있으면 1.
having_clause HAVING 토큰 매핑. 집계 결과 필터 존재 여부.
qualify_clause QUALIFY 토큰 매핑. 윈도 함수 결과 필터링 사용.
window_clause WINDOW 토큰 매핑. 명시적 WINDOW 정의(윈도 명 재사용)를 식별.
order_by_clause ORDER_BY 토큰 매핑. 결과 정렬 여부.
limit_fetch_offset LIMIT / FETCH / OFFSET 토큰 매핑. 행 수 제한·오프셋 존재.
table_sample_clause TABLE_SAMPLE / USING 토큰 매핑. 샘플링 절 사용 여부.
cluster_distribute_sort CLUSTER_BY / DISTRIBUTE_BY / SORT_BY 토큰 매핑. 빅데이터 계열 분산·정렬 힌트.
connect_by_clause CONNECT_BY / START_WITH 토큰 매핑. 계층 쿼리(Oracle style) 여부.
join_method JOIN_METHODS 집합(A SOF/NATURAL/POSITIONAL). 조인 알고리즘 특성.
join_side JOIN_SIDES 집합(LEFT/RIGHT/FULL). 외부 조인 방향.
join_kind JOIN_KINDS 집합(INNER/OUTER/CROSS/SEMI/ANTI/STRAIGHT_JOIN). 조인 종류.
set_operation SET_OPERATIONS 집합(UNION/INTERSECT/EXCEPT). 다중 SELECT 결과 병합 여부.
subquery_predicate SUBQUERY_PREDICATES 집합(ANY/ALL/EXISTS/SOME). 서브쿼리 비교 연산 사용.
boolean_operator CONJUNCTION AND, DISJUNCTION OR. 복합 조건 결합 존재.
equality_operator EQUALITY 딕셔너리(=, !=, <=>). 등가 비교 사용.
comparison_operator COMPARISON 딕셔너리(>, ≥, <, ≤). 비등가 비교 사용.
arithmetic_operator TERM (+ - %) & FACTOR (* / DIV) 딕셔너리. 산술 연산 존재.
bitwise_operator BITWISE 딕셔너리(&

 

- SQLGlot diff.py 기준으로 Feature를 뽑는다면,

SQL을 diff 함수에서 중점적으로 보는 아래 항목들을 기준으로 임베딩 시키는 것

Feature ID 추출 규칙 (AST 기준) 값 타입 / 예시 왜 필요한가? (diff 보존 근거)
F1 : Root Op tree.key → Select, Insert, Create … One-hot 루트가 다르면 diff 수정량이 대부분 ‘Replace’ → 가장 큰 결정 요인
F2 : Projection Sig. len(tree.expressions) + Star 여부 정수 & 이진 SELECT * ↔ SELECT a,b 는 diff 에서 Star vs Column 대량 교체 발생
F3 : Aggregate Funcs AST 후행 탐색에서 exp.Function & is_agg 집합(함수명) 그룹화 의미 변화는 Plan 변동이 크고 diff 에서도 대량 Replace
F4 : Scalar Funcs FUNCTIONS 딕트 키와 매칭되는 노드 집합 집합(함수명) 문자열·수치 변환 등 SELECT식이 유지되면 Keep ↑
F5 : Tables Used Table.this 식별자 리스트 멀티-핫 FROM/JOIN 의 테이블 세트가 같으면 diff 에서 대부분 Keep
F6 : Join Topology 각 Join.kind, Join.on 유형(=, >, LIKE) 카운트/원-핫 INNER↔LEFT 같은 변경은 diff 의 핵심 엣지
F7 : Predicate Pattern WHERE 트리의 논리 연산자 깊이 + Comparison 연산자 집합 (depth, {=,>,BETWEEN…}) 조건식 구조가 유사하면 Keep 비율 높음
F8 : Literal Buckets 숫자·문자·날짜 리터럴 개수, 타입별 (n_int, n_str, n_date) 값이 달라도 노드 타입은 동일 → Keep 비율 확보
F9 : Window/Analytic Window 노드 여부 + over.clause 해시 이진 & 집합 윈도우 함수 유무가 diff 를 크게 갈라놓음
F10 : Grouping/Ordering GROUP BY 컬럼 수, ORDER BY 존재 (int, bool) 집계·정렬 변동도 AST 수정량에 민감
F11 : Limit/Offset limit.expression 존재 및 값 구간 (None / small / big) 이진+범주 결과 제어 절은 쿼리 의미엔 작지만 AST 차이엔 크게 반영
F12 : Subquery Pattern 서브쿼리 개수 + EXISTS / IN / ANY 구분 (count, one-hot) 서브쿼리 구조 변경은 diff 에서 노드 대거 이동
F13 : CTE Signature WITH 유무 + CTE 개수 이진+정수 CTE 단위가 같으면 트리 상위가 거의 Keep
F14 : Data-Type Hints TYPE_TOKENS 등장 카운트 멀티-핫 CAST/DDL 구문 비교에 필요
F15 : Expression Depth Stats 전체 트리 깊이, 평균·최대 서브트리 깊이 정수·실수 복잡도 유사성이 diff 유사도와 양의 상관

 

- 두 번째 방식 기준으로 예시를 만등러 보았다.

 

 

2) SQL Vector Embedding Feature 기준으로 Embedding 진행

Feature는 약 70ea 뽑아놨는데, 제대로 Embedding을 시킨 건지 모르겄다.

import re
import numpy as np
from typing import List, Dict
from sklearn import tree
from sqlglot import parse_one, exp

# 1. Feature 키 리스트 정의 (총 72개)
FEATURE_KEYS: List[str] = [
    "is_select", "has_cte", "has_subquery",
    "n_select_expr", "has_star", "n_literals", "n_columns",
    "n_agg_funcs", "has_sum", "has_avg", "has_count", "has_min", "has_max",
    "n_tables", "n_joins", "has_inner_join", "has_left_join", "has_right_join", "has_full_join",
    "has_where", "n_predicates", "has_between", "has_in", "has_like", "has_and", "has_or",
    "has_group_by", "n_group_cols", "has_having",
    "has_order_by", "has_order_asc", "has_order_desc", "has_limit",
    "n_windows",
    "n_scalar_funcs", "has_date_func", "has_string_func",
    "has_distinct", "select_has_alias", "has_union", "has_intersect", "has_except",
    "has_case", "has_cast", "has_coalesce", "has_exists", "has_json_extract", "has_between_dates",
    "has_int_literal", "has_float_literal", "has_string_literal",
    "query_len_char", "query_depth",
    "has_placeholder",
    "has_rank", "has_row_number",
    "has_is_null", "has_is_not_null",
    "has_any", "has_all",
    "n_eq", "n_neq", "n_gt_lt", "n_arithmetic_ops",
    "has_concat", "has_interval",
    "has_star_except", "has_star_replace",
    "has_array_construct", "has_unnest"
]

# 2. key → index 매핑
FEATURE_INDEX: Dict[str, int] = {k: i for i, k in enumerate(FEATURE_KEYS)}

# 3. SQL → feature vector 함수 정의
def extract_sql_features(sql: str) -> np.ndarray:
    vec = np.zeros(len(FEATURE_KEYS), dtype=float)
    
    try:
        tree = parse_one(sql)
    except Exception:
        return vec  # 파싱 실패 시 0벡터 반환

    def mark(feat: str, value: float = 1.0):
        vec[FEATURE_INDEX[feat]] = value

    def inc(feat: str, value: int = 1):
        vec[FEATURE_INDEX[feat]] += value

    mark("is_select", isinstance(tree, exp.Select))
    mark("has_cte", bool(tree.find(exp.With)))
    mark("has_subquery", any(isinstance(n, exp.Subquery) for n in tree.find_all(exp.Subquery)))

    # SELECT
    select_exprs = tree.expressions or []
    mark("n_select_expr", len(select_exprs))
    if any(isinstance(e, exp.Star) for e in select_exprs):
        mark("has_star")
        if any(e.args.get("except") for e in select_exprs if isinstance(e, exp.Star)):
            mark("has_star_except")
        if any(e.args.get("replace") for e in select_exprs if isinstance(e, exp.Star)):
            mark("has_star_replace")

    # Column & Literal
    for n in tree.walk():
        if isinstance(n, exp.Column):
            inc("n_columns")
        elif isinstance(n, exp.Literal):
            inc("n_literals")
            if n.is_string:
                mark("has_string_literal")
            elif re.match(r"^\d+$", str(n.this)):
                mark("has_int_literal")
            else:
                mark("has_float_literal")

    # Aggregation
    agg_map = {
        exp.Sum: "has_sum",
        exp.Avg: "has_avg",
        exp.Count: "has_count",
        exp.Min: "has_min",
        exp.Max: "has_max"
    }

    # 날짜/문자열 함수 집합
    scalar_date_funcs = {"date", "datediff", "strftime", "date_add", "date_sub"}
    scalar_string_funcs = {"substr", "substring", "upper", "lower", "concat", "concat_ws", "length"}

    # ─────────────────────────────────────────
    # ✅ 1. 집계 함수: exp.AggFunc 서브클래스만 확정적으로 처리
    # ─────────────────────────────────────────
    for f in tree.find_all(exp.AggFunc):  # 이 노드들은 is_aggregate=True가 보장됨
        fname = f.name.upper()
        inc("n_agg_funcs")
        for k, flag in agg_map.items():
            if isinstance(f, k):
                mark(flag)

    # ─────────────────────────────────────────
    # ✅ 2. 일반 함수: 전체 Expression 중 name 존재하는 것만 추출
    # ─────────────────────────────────────────
    for node in tree.walk():
        # 함수가 아닌 연산자 or 구조 노드는 제외
        if not isinstance(node, exp.Expression):
            continue
        if not hasattr(node, "name"):  # name 속성이 없는 연산자들 제외
            continue
        if not isinstance(node, exp.Func):  # exp.Func 기반 함수만
            continue
        if getattr(node, "is_aggregate", False):  # 이미 처리한 agg는 skip
            continue

        fname = node.name.upper()
        inc("n_scalar_funcs")

        if fname.lower() in scalar_date_funcs:
            mark("has_date_func")
        if fname.lower() in scalar_string_funcs:
            mark("has_string_func")
        if fname == "COALESCE":
            mark("has_coalesce")
        if fname.startswith("JSON_") or isinstance(node, (exp.JSONExtract, exp.JSONExtractScalar)):
            mark("has_json_extract")
        if fname in {"CONCAT", "CONCAT_WS"}:
            mark("has_concat")
        if fname == "CAST" or isinstance(node, exp.Cast):
            mark("has_cast")
        if fname == "RANK":
            mark("has_rank")
        if fname == "ROW_NUMBER":
            mark("has_row_number")
        if fname == "UNNEST":
            mark("has_unnest")

    # Tables / Joins
    mark("n_tables", len(list(tree.find_all(exp.Table))))
    joins = list(tree.find_all(exp.Join))
    mark("n_joins", len(joins))
    for j in joins:
        kind = (j.args.get("kind") or "").upper()
        if kind == "INNER":
            mark("has_inner_join")
        elif kind == "LEFT":
            mark("has_left_join")
        elif kind == "RIGHT":
            mark("has_right_join")
        elif kind == "FULL":
            mark("has_full_join")

    # WHERE
    where = tree.find(exp.Where)
    if where:
        mark("has_where")
        preds = list(where.find_all(exp.Predicate))
        mark("n_predicates", len(preds))
        if where.find(exp.Between):
            mark("has_between")
            if any(re.search(r"\d{4}", str(l.this)) for l in where.find_all(exp.Literal)):
                mark("has_between_dates")
        if where.find(exp.In): mark("has_in")
        if where.find(exp.Like): mark("has_like")
        if where.find(exp.And): mark("has_and")
        if where.find(exp.Or): mark("has_or")
        if where.find(exp.Is):
            if any(getattr(p, "negated", False) for p in where.find_all(exp.Is)):
                mark("has_is_not_null")
            else:
                mark("has_is_null")
        if where.find(exp.Exists): mark("has_exists")

    # GROUP BY / HAVING
    group = tree.find(exp.Group)
    if group:
        mark("has_group_by")
        mark("n_group_cols", len(group.expressions or []))
    if tree.find(exp.Having):
        mark("has_having")

    # ORDER / LIMIT
    order = tree.find(exp.Order)
    if order:
        mark("has_order_by")
        for o in order.expressions or []:
            direction = (o.args.get("desc") and "DESC") or "ASC"
            if direction == "ASC":
                mark("has_order_asc")
            else:
                mark("has_order_desc")
    if tree.find(exp.Limit):
        mark("has_limit")

    # Windows
    mark("n_windows", len(list(tree.find_all(exp.Window))))

    # Misc
    if tree.args.get("distinct"): mark("has_distinct")
    if tree.find(exp.Union): mark("has_union")
    if tree.find(exp.Intersect): mark("has_intersect")
    if tree.find(exp.Except): mark("has_except")
    mark("query_len_char", len(sql))

    def _depth(node: exp.Expression, d=0) -> int:
        if not isinstance(node, exp.Expression):
            return d
        return max([d] + [_depth(c, d + 1) for c in node.args.values() if isinstance(c, exp.Expression) or isinstance(c, list)])
    mark("query_depth", _depth(tree))

    if re.search(r":\w+|\?\d*", sql):
        mark("has_placeholder")
    if tree.find(exp.Interval): mark("has_interval")
    if tree.find(exp.Array): mark("has_array_construct")

    inc("n_eq", len(list(tree.find_all(exp.EQ))))
    inc("n_neq", len(list(tree.find_all(exp.NEQ))))
    inc("n_gt_lt", len(list(tree.find_all(exp.GT))) + len(list(tree.find_all(exp.LT))))

    arithmetic_ops = {exp.Add, exp.Sub, exp.Mul, exp.Div, exp.Mod}
    n_ops = sum(1 for n in tree.walk() if type(n) in arithmetic_ops)
    mark("n_arithmetic_ops", n_ops)

    return vec

 

일단 위 코드 기준으로, 어찌어찌 70차원짜리 Data Embedding은 完

 

다음에는 실제 이 Vector DB로 RAG를 돌려보고 정확도를 측정해보도록 해야겠다.