Merge pull request #2498 from moj-analytical-services/improve_compare_two_records

Improve compare two records
RobinL authored Nov 13, 2024
2 parents 2f67811 + 5e9a69b commit fff3433
Showing 7 changed files with 769 additions and 51 deletions.
7 changes: 4 additions & 3 deletions splink/internals/accuracy.py
@@ -20,6 +20,7 @@

 if TYPE_CHECKING:
     from splink.internals.linker import Linker
+    from splink.internals.settings import Settings


 def truth_space_table_from_labels_with_predictions_sqls(
@@ -289,8 +290,8 @@ def truth_space_table_from_labels_with_predictions_sqls(
     return sqls


-def _select_found_by_blocking_rules(linker: "Linker") -> str:
-    brs = linker._settings_obj._blocking_rules_to_generate_predictions
+def _select_found_by_blocking_rules(settings_obj: "Settings") -> str:
+    brs = settings_obj._blocking_rules_to_generate_predictions

     if brs:
         br_strings = [
@@ -425,7 +426,7 @@ def predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename):
     )

     sqls.extend(sqls_2)
-    br_col = _select_found_by_blocking_rules(linker)
+    br_col = _select_found_by_blocking_rules(linker._settings_obj)

     sql = f"""
     select *, {br_col}
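The upshot of the accuracy.py change is that the helper which builds the found_by_blocking_rules column now depends only on a Settings object rather than a full Linker, so it can be reused by code that never constructs a Linker (see the new realtime module below). A minimal sketch of the decoupled call, assuming an existing `linker`; the exact SQL snippet the helper returns is internal to splink:

```py
# Hedged sketch: `linker` is assumed to be an already-configured splink Linker.
# The helper returns a SQL select expression which can be appended to a
# prediction query, as the diff above does:
br_col = _select_found_by_blocking_rules(linker._settings_obj)
sql = f"select *, {br_col} from __splink__df_predict"
```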
110 changes: 73 additions & 37 deletions splink/internals/linker_components/inference.py
@@ -4,6 +4,7 @@
 import time
 from typing import TYPE_CHECKING, Any

+from splink.internals.accuracy import _select_found_by_blocking_rules
 from splink.internals.blocking import (
     BlockingRule,
     block_using_rules_sqls,
@@ -639,16 +640,31 @@ def find_matches_to_new_records(
         return predictions

     def compare_two_records(
-        self, record_1: dict[str, Any], record_2: dict[str, Any]
+        self,
+        record_1: dict[str, Any] | AcceptableInputTableType,
+        record_2: dict[str, Any] | AcceptableInputTableType,
+        include_found_by_blocking_rules: bool = False,
     ) -> SplinkDataFrame:
         """Use the linkage model to compare and score a pairwise record comparison
-        based on the two input records provided
+        based on the two input records provided.
+
+        If your inputs contain multiple rows, scores for the cartesian product of
+        the two inputs will be returned.
+
+        If your inputs contain hardcoded term frequency columns (e.g.
+        a tf_first_name column), then these values will be used instead of any
+        provided term frequency lookup tables, or term frequency values derived
+        from the input data.

         Args:
             record_1 (dict): dictionary representing the first record. Column names
                 and data types must be the same as the columns in the settings object
             record_2 (dict): dictionary representing the second record. Column names
                 and data types must be the same as the columns in the settings object
+            include_found_by_blocking_rules (bool, optional): If True, outputs a column
+                indicating whether the record pair would have been found by any of the
+                blocking rules specified in
+                settings.blocking_rules_to_generate_predictions. Defaults to False.

         Examples:
             ```py
@@ -683,30 +699,39 @@ def compare_two_records(
             SplinkDataFrame: Pairwise comparison with scored prediction
         """

-        cache = self._linker._intermediate_table_cache
+        linker = self._linker
+
+        retain_matching_columns = linker._settings_obj._retain_matching_columns
+        retain_intermediate_calculation_columns = (
+            linker._settings_obj._retain_intermediate_calculation_columns
+        )
+        linker._settings_obj._retain_matching_columns = True
+        linker._settings_obj._retain_intermediate_calculation_columns = True
+
+        cache = linker._intermediate_table_cache

         uid = ascii_uid(8)

         # Check if input is a DuckDB relation without importing DuckDB
         if isinstance(record_1, dict):
-            to_register_left = [record_1]
+            to_register_left: AcceptableInputTableType = [record_1]
         else:
             to_register_left = record_1

         if isinstance(record_2, dict):
-            to_register_right = [record_2]
+            to_register_right: AcceptableInputTableType = [record_2]
         else:
             to_register_right = record_2

-        df_records_left = self._linker.table_management.register_table(
+        df_records_left = linker.table_management.register_table(
             to_register_left,
             f"__splink__compare_two_records_left_{uid}",
             overwrite=True,
         )

         df_records_left.templated_name = "__splink__compare_two_records_left"

-        df_records_right = self._linker.table_management.register_table(
+        df_records_right = linker.table_management.register_table(
             to_register_right,
             f"__splink__compare_two_records_right_{uid}",
             overwrite=True,
@@ -719,7 +744,9 @@
         nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
         pipeline.append_input_dataframe(nodes_with_tf)

-        for tf_col in self._linker._settings_obj._term_frequency_columns:
+        tf_cols = linker._settings_obj._term_frequency_columns
+
+        for tf_col in tf_cols:
             tf_table_name = colname_to_tf_tablename(tf_col)
             if tf_table_name in cache:
                 tf_table = cache.get_with_logging(tf_table_name)
@@ -734,67 +761,76 @@
             )

         sql_join_tf = _join_new_table_to_df_concat_with_tf_sql(
-            self._linker, "__splink__compare_two_records_left"
+            linker, "__splink__compare_two_records_left", df_records_left
         )

         pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_left_with_tf")

         sql_join_tf = _join_new_table_to_df_concat_with_tf_sql(
-            self._linker, "__splink__compare_two_records_right"
+            linker, "__splink__compare_two_records_right", df_records_right
         )

         pipeline.enqueue_sql(sql_join_tf, "__splink__compare_two_records_right_with_tf")

-        source_dataset_ic = (
-            self._linker._settings_obj.column_info_settings.source_dataset_input_column
-        )
-        uid_ic = self._linker._settings_obj.column_info_settings.unique_id_input_column

         pipeline = add_unique_id_and_source_dataset_cols_if_needed(
-            self._linker,
+            linker,
             df_records_left,
             pipeline,
             in_tablename="__splink__compare_two_records_left_with_tf",
             out_tablename="__splink__compare_two_records_left_with_tf_uid_fix",
             uid_str="_left",
         )
         pipeline = add_unique_id_and_source_dataset_cols_if_needed(
-            self._linker,
+            linker,
             df_records_right,
             pipeline,
             in_tablename="__splink__compare_two_records_right_with_tf",
             out_tablename="__splink__compare_two_records_right_with_tf_uid_fix",
             uid_str="_right",
         )

-        sqls = block_using_rules_sqls(
-            input_tablename_l="__splink__compare_two_records_left_with_tf_uid_fix",
-            input_tablename_r="__splink__compare_two_records_right_with_tf_uid_fix",
-            blocking_rules=[BlockingRule("1=1")],
-            link_type=self._linker._settings_obj._link_type,
-            source_dataset_input_column=source_dataset_ic,
-            unique_id_input_column=uid_ic,
-        )
-        pipeline.enqueue_list_of_sqls(sqls)
-
-        sqls = compute_comparison_vector_values_from_id_pairs_sqls(
-            self._linker._settings_obj._columns_to_select_for_blocking,
-            self._linker._settings_obj._columns_to_select_for_comparison_vector_values,
-            input_tablename_l="__splink__compare_two_records_left_with_tf_uid_fix",
-            input_tablename_r="__splink__compare_two_records_right_with_tf_uid_fix",
-            source_dataset_input_column=source_dataset_ic,
-            unique_id_input_column=uid_ic,
-        )
-        pipeline.enqueue_list_of_sqls(sqls)
+        cols_to_select = self._linker._settings_obj._columns_to_select_for_blocking
+
+        select_expr = ", ".join(cols_to_select)
+        sql = f"""
+        select {select_expr}, 0 as match_key
+        from __splink__compare_two_records_left_with_tf_uid_fix as l
+        cross join __splink__compare_two_records_right_with_tf_uid_fix as r
+        """
+        pipeline.enqueue_sql(sql, "__splink__compare_two_records_blocked")
+
+        cols_to_select = (
+            linker._settings_obj._columns_to_select_for_comparison_vector_values
+        )
+        select_expr = ", ".join(cols_to_select)
+        sql = f"""
+        select {select_expr}
+        from __splink__compare_two_records_blocked
+        """
+        pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

         sqls = predict_from_comparison_vectors_sqls_using_settings(
-            self._linker._settings_obj,
-            sql_infinity_expression=self._linker._infinity_expression,
+            linker._settings_obj,
+            sql_infinity_expression=linker._infinity_expression,
         )
         pipeline.enqueue_list_of_sqls(sqls)

-        predictions = self._linker._db_api.sql_pipeline_to_splink_dataframe(
+        if include_found_by_blocking_rules:
+            br_col = _select_found_by_blocking_rules(linker._settings_obj)
+            sql = f"""
+            select *, {br_col}
+            from __splink__df_predict
+            """
+
+            pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules")
+
+        predictions = linker._db_api.sql_pipeline_to_splink_dataframe(
             pipeline, use_cache=False
         )

+        linker._settings_obj._retain_matching_columns = retain_matching_columns
+        linker._settings_obj._retain_intermediate_calculation_columns = (
+            retain_intermediate_calculation_columns
+        )

         return predictions
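To make the new surface area concrete, here is a hedged sketch of calling the updated method. The linker construction is assumed (a standard, trained splink 4 Linker, not shown in this diff) and the column names are illustrative:

```py
# Hedged sketch: assumes `linker` is an already-trained splink Linker whose
# settings define comparisons over first_name and surname (illustrative names).
rec_1 = {"unique_id": 1, "first_name": "John", "surname": "Smith"}
rec_2 = {"unique_id": 2, "first_name": "Jon", "surname": "Smith"}

df = linker.inference.compare_two_records(
    rec_1, rec_2, include_found_by_blocking_rules=True
)
# One scored row per pair (the cartesian product if either input had
# multiple rows), including the found-by-blocking-rules indicator column.
print(df.as_record_dict())
```

Note that the method now temporarily forces _retain_matching_columns and _retain_intermediate_calculation_columns on, restoring the caller's original values before returning, so the output always carries the per-comparison detail.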
142 changes: 142 additions & 0 deletions splink/internals/realtime.py
@@ -0,0 +1,142 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

from splink.internals.accuracy import _select_found_by_blocking_rules
from splink.internals.database_api import AcceptableInputTableType, DatabaseAPISubClass
from splink.internals.misc import ascii_uid
from splink.internals.pipeline import CTEPipeline
from splink.internals.predict import (
    predict_from_comparison_vectors_sqls_using_settings,
)
from splink.internals.settings_creator import SettingsCreator
from splink.internals.splink_dataframe import SplinkDataFrame


class SQLCache:
    def __init__(self):
        self._cache: dict[int, tuple[str, str | None]] = {}

    def get(self, settings_id: int, new_uid: str) -> str | None:
        if settings_id not in self._cache:
            return None

        sql, cached_uid = self._cache[settings_id]
        if cached_uid:
            sql = sql.replace(cached_uid, new_uid)
        return sql

    def set(self, settings_id: int, sql: str | None, uid: str | None) -> None:
        if sql is not None:
            self._cache[settings_id] = (sql, uid)


_sql_cache = SQLCache()


def compare_records(
    record_1: dict[str, Any] | AcceptableInputTableType,
    record_2: dict[str, Any] | AcceptableInputTableType,
    settings: SettingsCreator | dict[str, Any] | Path | str,
    db_api: DatabaseAPISubClass,
    use_sql_from_cache: bool = True,
    include_found_by_blocking_rules: bool = False,
) -> SplinkDataFrame:
    """Compare two records and compute similarity scores without requiring a Linker.

    Assumes any required term frequency values are provided in the input records.

    Args:
        record_1 (dict): First record to compare
        record_2 (dict): Second record to compare
        settings (SettingsCreator | dict | Path | str): The model settings, or a
            path to a saved model
        db_api (DatabaseAPISubClass): Database API to use for computations
        use_sql_from_cache (bool, optional): If True, reuse the SQL generated by
            an earlier call with the same settings object, substituting in the
            new table uids. Defaults to True.
        include_found_by_blocking_rules (bool, optional): If True, outputs a column
            indicating whether the record pair would have been found by any of the
            blocking rules in settings.blocking_rules_to_generate_predictions.
            Defaults to False.

    Returns:
        SplinkDataFrame: Comparison results
    """
    global _sql_cache

    uid = ascii_uid(8)

    if isinstance(record_1, dict):
        to_register_left: AcceptableInputTableType = [record_1]
    else:
        to_register_left = record_1

    if isinstance(record_2, dict):
        to_register_right: AcceptableInputTableType = [record_2]
    else:
        to_register_right = record_2

    df_records_left = db_api.register_table(
        to_register_left,
        f"__splink__compare_records_left_{uid}",
        overwrite=True,
    )
    df_records_left.templated_name = "__splink__compare_records_left"

    df_records_right = db_api.register_table(
        to_register_right,
        f"__splink__compare_records_right_{uid}",
        overwrite=True,
    )
    df_records_right.templated_name = "__splink__compare_records_right"

    settings_id = id(settings)
    if use_sql_from_cache:
        if cached_sql := _sql_cache.get(settings_id, uid):
            return db_api._sql_to_splink_dataframe(
                cached_sql,
                templated_name="__splink__realtime_compare_records",
                physical_name=f"__splink__realtime_compare_records_{uid}",
            )

    if not isinstance(settings, SettingsCreator):
        settings_creator = SettingsCreator.from_path_or_dict(settings)
    else:
        settings_creator = settings

    settings_obj = settings_creator.get_settings(db_api.sql_dialect.sql_dialect_str)

    settings_obj._retain_matching_columns = True
    settings_obj._retain_intermediate_calculation_columns = True

    pipeline = CTEPipeline([df_records_left, df_records_right])

    cols_to_select = settings_obj._columns_to_select_for_blocking

    select_expr = ", ".join(cols_to_select)
    sql = f"""
    select {select_expr}, 0 as match_key
    from __splink__compare_records_left as l
    cross join __splink__compare_records_right as r
    """
    pipeline.enqueue_sql(sql, "__splink__compare_two_records_blocked")

    cols_to_select = settings_obj._columns_to_select_for_comparison_vector_values
    select_expr = ", ".join(cols_to_select)
    sql = f"""
    select {select_expr}
    from __splink__compare_two_records_blocked
    """
    pipeline.enqueue_sql(sql, "__splink__df_comparison_vectors")

    sqls = predict_from_comparison_vectors_sqls_using_settings(
        settings_obj,
        sql_infinity_expression=db_api.sql_dialect.infinity_expression,
    )
    pipeline.enqueue_list_of_sqls(sqls)

    if include_found_by_blocking_rules:
        br_col = _select_found_by_blocking_rules(settings_obj)
        sql = f"""
        select *, {br_col}
        from __splink__df_predict
        """

        pipeline.enqueue_sql(sql, "__splink__found_by_blocking_rules")

    predictions = db_api.sql_pipeline_to_splink_dataframe(pipeline)
    _sql_cache.set(settings_id, predictions.sql_used_to_create, uid)

    return predictions
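Taken together with SQLCache above, repeated calls with the same settings object skip SQL generation entirely: the cached SQL has the previous run's random uid swapped for the new one, so it points at the freshly registered input tables, and because the cache is keyed on id(settings), reuse only happens when the very same settings object (or string) is passed again. A hedged usage sketch; the saved-model path and record fields are illustrative assumptions, not part of this commit:

```py
# Hedged sketch of the new Linker-free entry point. Assumes a pre-trained
# model saved at ./model.json (e.g. via linker.misc.save_model_to_json).
from splink import DuckDBAPI
from splink.internals.realtime import compare_records

db_api = DuckDBAPI()

rec_1 = {"unique_id": 1, "first_name": "John", "surname": "Smith"}
rec_2 = {"unique_id": 2, "first_name": "Jon", "surname": "Smith"}

# The first call generates and caches the SQL; later calls that pass the
# same `settings` object reuse it with the new table uids substituted.
scores = compare_records(
    rec_1,
    rec_2,
    settings="./model.json",
    db_api=db_api,
    include_found_by_blocking_rules=True,
)
print(scores.as_record_dict())
```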