cab33cf787
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
144 lines
5.4 KiB
Python
144 lines
5.4 KiB
Python
from __future__ import annotations
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import pandas as pd
|
|
from sqlalchemy import create_engine, text, MetaData, Table, inspect
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
|
|
from .config import ImportConfig, SheetConfig
|
|
from .reader import ExcelReader
|
|
from .schema import build_columns
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _split_table(target_table: str) -> tuple[str | None, str]:
|
|
"""Split 'schema.table' into (schema, table). Returns (None, table) if no dot."""
|
|
if "." in target_table:
|
|
schema, table = target_table.split(".", 1)
|
|
return schema, table
|
|
return None, target_table
|
|
|
|
|
|
class Importer:
    """Load sheets of an Excel workbook into database tables.

    For every configured sheet the data is read through ``ExcelReader``, the
    target schema/table is created when missing, and the rows are written
    according to the sheet's ``mode``: ``"replace"``, ``"upsert"``, or
    anything else, which is treated as append.
    """

    def __init__(self, config: ImportConfig):
        """Keep *config* and create a SQLAlchemy engine from ``config.dsn``."""
        self.config = config
        self.engine = create_engine(config.dsn)

    def run(self, excel_path: str | Path) -> dict[str, int]:
        """Import all configured sheets.

        Args:
            excel_path: Workbook handed to ``ExcelReader``.

        Returns:
            ``{target_table: rows_imported}``. When several sheets write to
            the same target table their row counts are added together.
        """
        reader = ExcelReader(excel_path)
        results: dict[str, int] = {}
        for sheet_cfg in self.config.sheets:
            rows = self._import_sheet(reader, sheet_cfg)
            # Accumulate instead of assigning so a second sheet aimed at the
            # same table does not hide the rows written by the first one.
            results[sheet_cfg.target_table] = results.get(sheet_cfg.target_table, 0) + rows
        return results

    def _import_sheet(self, reader: ExcelReader, cfg: SheetConfig) -> int:
        """Read one sheet and write it to ``cfg.target_table``.

        All database work for the sheet happens inside one
        ``engine.begin()`` block. Returns the number of rows written, or 0
        when the sheet is empty.
        """
        df = reader.read(cfg)
        if df.empty:
            logger.warning("Sheet %r is empty, skipping.", cfg.sheet)
            return 0

        logger.info("Read %d rows from sheet %r -> table %r", len(df), cfg.sheet, cfg.target_table)

        with self.engine.begin() as conn:
            self._ensure_table(conn, df, cfg)

            if cfg.mode == "replace":
                dialect = self.engine.dialect.name
                # NOTE(review): on Oracle, TRUNCATE is DDL and commits
                # implicitly, so "replace" is presumably not atomic there;
                # confirm whether that is acceptable.
                truncate_sql = (
                    f"DELETE FROM {cfg.target_table}"
                    if dialect == "sqlite"
                    else f"TRUNCATE TABLE {cfg.target_table}"  # schema.table is valid SQL here
                )
                conn.execute(text(truncate_sql))
                rows = self._bulk_insert(conn, df, cfg.target_table)
            elif cfg.mode == "upsert":
                rows = self._upsert(conn, df, cfg)
            else:  # append
                rows = self._bulk_insert(conn, df, cfg.target_table)

        logger.info("Imported %d rows into %r (mode=%s)", rows, cfg.target_table, cfg.mode)
        return rows

    def _ensure_table(self, conn, df: pd.DataFrame, cfg: SheetConfig):
        """Create the target schema and table when they do not exist yet.

        Column definitions come from ``build_columns`` using the sheet's
        column config and the importer's default varchar length.
        """
        schema, table_name = _split_table(cfg.target_table)
        insp = inspect(conn)
        if schema and schema not in insp.get_schema_names():
            conn.execute(text(f"CREATE SCHEMA {schema}"))
            logger.info("Created schema %r", schema)
        if not insp.has_table(table_name, schema=schema):
            meta = MetaData()
            cols = build_columns(df, cfg.columns, self.config.default_varchar_length)
            # Constructing the Table registers it on ``meta``; the object
            # itself is not needed afterwards.
            Table(table_name, meta, *cols, schema=schema)
            meta.create_all(conn)
            logger.info("Created table %r", cfg.target_table)

    @staticmethod
    def _reflect_table(conn, target_table: str) -> Table:
        """Reflect ``target_table`` (optionally ``schema.table``) from the database."""
        schema, tname = _split_table(target_table)
        meta = MetaData()
        meta.reflect(bind=conn, schema=schema, only=[tname])
        # MetaData keys schema-qualified tables as "schema.table".
        key = f"{schema}.{tname}" if schema else tname
        return meta.tables[key]

    def _bulk_insert(self, conn, df: pd.DataFrame, table_name: str) -> int:
        """Insert every row of *df* into *table_name*; return the row count."""
        records = _df_to_records(df)
        if not records:
            return 0
        table = self._reflect_table(conn, table_name)
        conn.execute(table.insert(), records)
        return len(records)

    def _upsert(self, conn, df: pd.DataFrame, cfg: SheetConfig) -> int:
        """Insert-or-update the rows of *df* keyed on ``cfg.upsert_keys``.

        Returns:
            The number of records submitted (not the number actually changed).

        Raises:
            ValueError: ``cfg.upsert_keys`` is empty.
            NotImplementedError: the engine dialect is neither PostgreSQL
                nor Oracle.
        """
        dialect = self.engine.dialect.name
        records = _df_to_records(df)
        if not records:
            return 0
        if not cfg.upsert_keys:
            # Without keys both the ON CONFLICT target and the MERGE ON
            # clause would be empty, which is invalid SQL.
            raise ValueError(f"Upsert into {cfg.target_table!r} requires at least one upsert key")

        table = self._reflect_table(conn, cfg.target_table)

        if dialect == "postgresql":
            stmt = pg_insert(table).values(records)
            update_cols = {c.key: stmt.excluded[c.key] for c in table.columns if c.key not in cfg.upsert_keys}
            if update_cols:
                stmt = stmt.on_conflict_do_update(index_elements=cfg.upsert_keys, set_=update_cols)
            else:
                # Every column is a conflict key, so there is nothing to
                # update; an empty ``set_`` is rejected by SQLAlchemy.
                stmt = stmt.on_conflict_do_nothing(index_elements=cfg.upsert_keys)
            conn.execute(stmt)
        elif dialect == "oracle":
            # Oracle MERGE via raw SQL
            for record in records:
                _oracle_merge(conn, table, record, cfg.upsert_keys)
        else:
            raise NotImplementedError(f"Upsert not implemented for dialect: {dialect}")

        return len(records)
|
|
|
|
|
|
def _df_to_records(df: pd.DataFrame) -> list[dict]:
|
|
# Replace pandas NA/NaT with None so SQLAlchemy handles nulls correctly
|
|
return [
|
|
{k: (None if pd.isna(v) else v) for k, v in row.items()}
|
|
for row in df.to_dict(orient="records")
|
|
]
|
|
|
|
|
|
def _oracle_merge(conn, table: Table, record: dict, keys: list[str]):
    """Upsert one *record* into *table* with an Oracle ``MERGE`` statement.

    Args:
        conn: Connection used to execute the statement.
        table: Target table; its ``schema`` (when set) and ``name`` form
            the MERGE target.
        record: Column name -> value mapping. The names are used both as
            column identifiers and as bind-parameter names.
        keys: Column names compared in the ``ON`` clause.
    """
    all_cols = list(record.keys())
    non_keys = [c for c in all_cols if c not in keys]

    # Qualify the target with its schema; the bare table name would resolve
    # against the session's default schema instead of the configured one.
    target = f"{table.schema}.{table.name}" if table.schema else table.name

    key_clauses = " AND ".join(f"t.{k} = s.{k}" for k in keys)
    select_parts = ", ".join(f":{c} AS {c}" for c in all_cols)
    insert_cols = ", ".join(all_cols)
    insert_vals = ", ".join(f"s.{c}" for c in all_cols)

    # When every column is a key there is nothing to update, and an empty
    # "UPDATE SET" is invalid SQL, so the WHEN MATCHED clause is left out.
    matched_clause = ""
    if non_keys:
        update_parts = ", ".join(f"t.{c} = s.{c}" for c in non_keys)
        matched_clause = f"WHEN MATCHED THEN UPDATE SET {update_parts}"

    sql = f"""
MERGE INTO {target} t
USING (SELECT {select_parts} FROM dual) s
ON ({key_clauses})
{matched_clause}
WHEN NOT MATCHED THEN INSERT ({insert_cols}) VALUES ({insert_vals})
"""
    conn.execute(text(sql), record)
|