Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run streamlit dashboard on docker container #1571

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
database_prefix: Optional[str] = None,
database_args: Optional[Dict[str, Any]] = None,
database_check_revision: bool = True,
host: Optional[str] = None,
):
connection_parameters = {
"account": account,
Expand All @@ -56,6 +57,7 @@ def __init__(
"schema": schema,
"warehouse": warehouse,
"role": role,
**({"host": host} if host else {}),
}

if snowpark_session is None:
Expand Down
4 changes: 3 additions & 1 deletion src/dashboard/trulens/dashboard/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def run_dashboard(
address: Optional[str] = None,
force: bool = False,
_dev: Optional[Path] = None,
spcs_runtime: Optional[bool] = False,
_watch_changes: bool = False,
) -> Process:
"""Run a streamlit dashboard to view logged results and apps.
Expand Down Expand Up @@ -120,7 +121,8 @@ def run_dashboard(
"--database-prefix",
session.connector.db.table_prefix,
]

if spcs_runtime:
args.append("--spcs-runtime")
proc = subprocess.Popen(
args,
stdout=subprocess.PIPE,
Expand Down
35 changes: 32 additions & 3 deletions src/dashboard/trulens/dashboard/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ class FeedbackDisplay(BaseModel):
icon: str


def get_spcs_login_token():
    """Return the login token that Snowflake supplies to the container.

    The token lives in ``/snowflake/session/token`` and is short lived, so it
    should be re-read immediately before each new connection is created
    instead of being cached.
    """
    with open("/snowflake/session/token", "r") as token_file:
        token = token_file.read()
    return token


def init_from_args():
"""Parse command line arguments and initialize Tru with them.

Expand All @@ -43,6 +52,7 @@ def init_from_args():
parser.add_argument(
"--database-prefix", default=core_db.DEFAULT_DATABASE_PREFIX
)
# Boolean switch: run_dashboard() appends a bare "--spcs-runtime" with no value,
# so this has to be a store_true flag. Without an action argparse would demand
# an argument for the option and reject the command line the launcher builds.
parser.add_argument("--spcs-runtime", action="store_true")

try:
args = parser.parse_args()
Expand All @@ -54,9 +64,28 @@ def init_from_args():
# so we have to do a hard exit.
sys.exit(e.code)

core_session.TruSession(
database_url=args.database_url, database_prefix=args.database_prefix
)
if args.spcs_runtime:
import os

from snowflake.snowpark import Session
from trulens.connectors.snowflake import SnowflakeConnector

connection_params = {
"account": os.environ.get("SNOWFLAKE_ACCOUNT"),
"host": os.getenv("SNOWFLAKE_HOST"),
"authenticator": "oauth",
"token": get_spcs_login_token(),
"warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),
"database": os.environ.get("SNOWFLAKE_DATABASE"),
"schema": os.environ.get("SNOWFLAKE_SCHEMA"),
}
snowpark_session = Session.builder.configs(connection_params).create()
connector = SnowflakeConnector(snowpark_session=snowpark_session)
core_session.TruSession(connector=connector)
else:
core_session.TruSession(
database_url=args.database_url, database_prefix=args.database_prefix
)


def trulens_leaderboard(app_ids: Optional[List[str]] = None):
Expand Down
12 changes: 12 additions & 0 deletions tools/snowflake/spcs_dashboard/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Base image; can be overridden at build time with --build-arg BASE_IMAGE=...
ARG BASE_IMAGE=python:3.11.9-slim-bullseye
FROM $BASE_IMAGE

# Copy the whole build context into the image; the pip steps below expect
# requirements.txt and the two wheel files to be part of it.
COPY ./ /trulens_dashboard/

WORKDIR /trulens_dashboard

# --no-cache-dir keeps pip's download/wheel cache out of the image layers.
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install --no-cache-dir trulens_connectors_snowflake-1.0.1-py3-none-any.whl
RUN pip install --no-cache-dir trulens_dashboard-1.0.1-py3-none-any.whl
Comment on lines +9 to +10
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is temporary. Once the next version of trulens is released, this can go in requirements.txt.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check with corey whether requirements.txt can be done using poetry instead.


CMD ["python", "run_dashboard.py"]
9 changes: 9 additions & 0 deletions tools/snowflake/spcs_dashboard/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
python-dotenv
pydantic
snowflake[ml]
snowflake-connector-python
snowflake-sqlalchemy
trulens
trulens-connectors-snowflake
# trulens-dashboard
# trulens-feedback
145 changes: 145 additions & 0 deletions tools/snowflake/spcs_dashboard/run_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from argparse import ArgumentParser

from snowflake.snowpark import Session

# get args from command line
# Command-line interface: a single opt-in flag, --build-docker, which enables
# the local docker build and push step further down in this script.
parser = ArgumentParser(description="Run container script")
parser.add_argument(
"--build-docker",
action="store_true",
help="Build and push the Docker container",
)
args = parser.parse_args()

# Create a Snowpark session from Session.builder with no explicit configs
# (presumably it picks up the default connection settings; confirm).
session = Session.builder.create()
# Capture the session's current context. Most of these values are substituted
# into the CREATE SERVICE specification later in the script.
# NOTE(review): `user` is not referenced anywhere else in this script.
account = session.get_current_account()
user = session.get_current_user()
database = session.get_current_database()
schema = session.get_current_schema()
warehouse = session.get_current_warehouse()
role = session.get_current_role()


def run_sql_command(command: str):
    """Run one SQL statement on the module-level session and echo the outcome.

    Args:
        command: SQL text to execute.

    Returns:
        The rows returned by ``collect()`` for the statement.
    """
    print(f"Running SQL command: {command}")
    rows = session.sql(command).collect()
    print(f"Result: {rows}")
    return rows


# Make sure the image repository that holds the dashboard image exists,
# creating it only when SHOW IMAGE REPOSITORIES does not already list it.
repository_name = "TRULENS_REPOSITORY"
images = session.sql("SHOW IMAGE REPOSITORIES").collect()
repository_exists = repository_name in [image["name"] for image in images]

if repository_exists:
    print(f"Image repository {repository_name} already exists.")
else:
    session.sql(f"CREATE IMAGE REPOSITORY {repository_name}").collect()
    print(f"Image repository {repository_name} created.")

# Look up the repository URL; it is the prefix used for docker tags and pushes.
repository_rows = (
    session.sql(f"SHOW IMAGE REPOSITORIES LIKE '{repository_name}'")
    .select('"repository_url"')
    .collect()
)
repository_url = repository_rows[0]["repository_url"]

image_name = "trulens_dashboard"
image_tag = "latest"
app_name = "trulens_dashboard"
container_name = f"{app_name}_container"
if args.build_docker:
    # Opt-in local build: requires a working docker CLI on this machine.
    import subprocess

    image_ref = f"{repository_url}/{image_name}:{image_tag}"
    build_cmd = [
        "docker",
        "build",
        "--platform",
        "linux/amd64",
        "-t",
        image_ref,
        ".",
    ]
    push_cmd = ["docker", "push", image_ref]
    # check=True aborts the script on the first failing docker command, so a
    # failed build is never followed by a push.
    for docker_cmd in (build_cmd, push_cmd):
        subprocess.run(docker_cmd, check=True)


# Create compute pool if it does not exist
# Prompt for the target compute pool. Surrounding whitespace is stripped because
# the name is compared against SHOW COMPUTE POOLS output and interpolated into
# SQL statements below; a blank answer is rejected up front instead of yielding
# malformed statements later.
compute_pool_name = input("Enter compute pool name: ").strip()
if not compute_pool_name:
    raise SystemExit("A compute pool name is required.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using input() for getting the compute pool name is not suitable for automated scripts or production environments. Consider using environment variables or command-line arguments instead.

# List the existing compute pools and look for the requested one. The entered
# name is upper-cased before it is compared with each row's "name" column.
compute_pools = session.sql("SHOW COMPUTE POOLS").collect()
compute_pool_exists = any(
pool["name"] == compute_pool_name.upper() for pool in compute_pools
)
if compute_pool_exists:
print(f"Compute pool {compute_pool_name} already exists")
else:
# Create a single-node pool (MIN_NODES = MAX_NODES = 1) in the CPU_X64_M
# instance family.
session.sql(
f"CREATE COMPUTE POOL {compute_pool_name} MIN_NODES = 1 MAX_NODES = 1 INSTANCE_FAMILY = CPU_X64_M"
).collect()
# Describe the pool; the returned rows are discarded.
session.sql(f"DESCRIBE COMPUTE POOL {compute_pool_name}").collect()

# Egress network rule for 0.0.0.0 on ports 443 and 80, named after the pool.
network_rule_name = f"{compute_pool_name}_allow_http_https"
create_network_rule_sql = (
    f"CREATE OR REPLACE NETWORK RULE {network_rule_name} "
    "TYPE = 'HOST_PORT' MODE = 'EGRESS' "
    "VALUE_LIST = ('0.0.0.0:443','0.0.0.0:80')"
)
session.sql(create_network_rule_sql).collect()
session.sql("SHOW NETWORK RULES").collect()

# External access integration that enables the rule above; the service
# created later in this script references it.
access_integration_name = f"{compute_pool_name}_access_integration"
create_integration_sql = (
    f"CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION {access_integration_name} "
    f"ALLOWED_NETWORK_RULES = ({network_rule_name}) ENABLED = true"
)
session.sql(create_integration_sql).collect()
session.sql("SHOW EXTERNAL ACCESS INTEGRATIONS").collect()

# The service is named after the compute pool it runs in.
service_name = compute_pool_name + "_trulens_dashboard"
# Create the service from an inline specification: one container using the
# image stored in the image repository, the current session context passed in
# as SNOWFLAKE_* environment variables, and a public endpoint on port 8484.
# NOTE(review): the image path is formatted from app_name with a literal
# "latest" tag, while the docker build step tags with image_name/image_tag;
# the values currently coincide but are not tied together.
# NOTE(review): plain CREATE SERVICE (no IF NOT EXISTS) — presumably this fails
# when the service already exists; confirm before re-running the script.
session.sql(
"""
CREATE SERVICE {service_name}
IN COMPUTE POOL {compute_pool_name}
EXTERNAL_ACCESS_INTEGRATIONS = ({access_integration_name})
FROM SPECIFICATION $$
spec:
containers:
- name: trulens-dashboard
image: /{database}/{schema}/{repository_name}/{app_name}:latest
env:
SNOWFLAKE_ACCOUNT: "{account}"
SNOWFLAKE_DATABASE: "{database}"
SNOWFLAKE_SCHEMA: "{schema}"
SNOWFLAKE_WAREHOUSE: "{warehouse}"
SNOWFLAKE_ROLE: "{role}"
RUN_DASHBOARD: "1"
endpoints:
- name: trulens-demo-dashboard-endpoint
port: 8484
public: true
$$
""".format(
# Values substituted into the {placeholders} of the statement above.
service_name=service_name,
compute_pool_name=compute_pool_name,
access_integration_name=access_integration_name,
repository_name=repository_name,
account=account,
database=database,
schema=schema,
warehouse=warehouse,
role=role,
app_name=app_name,
)
).collect()

# Show services and get their status
endpoints_query = f"SHOW ENDPOINTS IN SERVICE {service_name}"
status_query = f"CALL SYSTEM$GET_SERVICE_STATUS('{service_name}')"
# The status call is issued twice back-to-back, after the endpoint listing.
for query in (endpoints_query, status_query, status_query):
    run_sql_command(query)

# Release the Snowpark session now that the deployment steps are done.
session.close()
33 changes: 33 additions & 0 deletions tools/snowflake/spcs_dashboard/run_dashboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os

from snowflake.snowpark import Session
from trulens.connectors.snowflake import SnowflakeConnector
from trulens.core import TruSession
from trulens.dashboard import run_dashboard


def get_login_token():
    """Return the login token that Snowflake supplies to the container.

    The token is read from ``/snowflake/session/token``. It is short lived, so
    it should be read right before each new connection is created.
    """
    with open("/snowflake/session/token", "r") as token_file:
        token = token_file.read()
    return token


# OAuth connection settings: the account, host, warehouse, database and schema
# come from SNOWFLAKE_* environment variables, and the token is read from the
# session token file at the moment the dict is built.
connection_params = {
    "account": os.getenv("SNOWFLAKE_ACCOUNT"),
    "host": os.getenv("SNOWFLAKE_HOST"),
    "authenticator": "oauth",
    "token": get_login_token(),
    "warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"),
    "database": os.getenv("SNOWFLAKE_DATABASE"),
    "schema": os.getenv("SNOWFLAKE_SCHEMA"),
}
session_builder = Session.builder.configs(connection_params)
snowpark_session = session_builder.create()

# Wrap the Snowpark session in a TruLens connector and session.
connector = SnowflakeConnector(snowpark_session=snowpark_session)
tru_session = TruSession(connector=connector)
# Query records and feedback once before launching; the result is discarded.
tru_session.get_records_and_feedback()

# Serve the dashboard on port 8484 with the SPCS runtime flag enabled.
run_dashboard(tru_session, port=8484, spcs_runtime=True)
Loading