Skip to content

Trino experiments #864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions ohsome_quality_api/api/request_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Field,
field_validator,
)
from pydantic.json_schema import SkipJsonSchema

from ohsome_quality_api.attributes.definitions import AttributeEnum
from ohsome_quality_api.topics.definitions import TopicEnum, get_topic_preset
Expand Down Expand Up @@ -68,6 +69,8 @@ class IndicatorRequest(BaseBpolys):
alias="topic",
)
include_figure: bool = True
# feature flag to use SQL queries against Trino instead of ohsome API
trino: SkipJsonSchema[bool] = False

@field_validator("topic")
@classmethod
Expand Down
1 change: 1 addition & 0 deletions ohsome_quality_api/api/response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class BaseResponse(BaseConfig):
attribution: dict[str, str] = {"url": ATTRIBUTION_URL}


# TODO: add sql_filter
class TopicMetadata(BaseConfig):
name: str
description: str
Expand Down
3 changes: 2 additions & 1 deletion ohsome_quality_api/attributes/attributes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ car-roads:
name: Road Name
description: TODO
filter: name=*
filter_sql: element_at (contributions.tags, 'name') IS NOT NULL
sidewalk:
name: Sidewalk
description: TODO
Expand Down Expand Up @@ -398,4 +399,4 @@ public-transport-stops:
departures-board:
name: Departures Board
description: TODO
filter: departures_board=*
filter: departures_board=*
53 changes: 41 additions & 12 deletions ohsome_quality_api/attributes/definitions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from enum import Enum
from functools import singledispatch
from typing import List

import yaml
Expand Down Expand Up @@ -52,19 +53,47 @@ def get_attribute_preset(topic_key: str) -> List[Attribute]:
) from error


def build_attribute_filter(attribute_key: List[str] | str, topic_key: str) -> str:
"""Build attribute filter for ohsome API query."""
attributes = get_attributes()
try:
if isinstance(attribute_key, str):
return get_topic_preset(topic_key).filter + " and (" + attribute_key + ")"
@singledispatch
def build_attribute_filter(attributes: str | list, topic_key: str, trino: bool) -> str:
raise NotImplementedError


@build_attribute_filter.register
def _(
attributes: list,
topic_key: str,
trino: bool = False,
) -> str:
"""Build attribute filter from attributes keys."""
if trino:
filter = get_topic_preset(topic_key).sql_filter
else:
filter = get_topic_preset(topic_key).filter

all_attributes = get_attributes()
for key in attributes:
if trino:
filter += " AND (" + all_attributes[topic_key][key].filter_sql + ")"
else:
attribute_filter = get_topic_preset(topic_key).filter
for key in attribute_key:
attribute_filter += " and (" + attributes[topic_key][key].filter + ")"
return attribute_filter
except KeyError as error:
raise KeyError("Invalid topic or attribute key(s).") from error
filter += " and (" + all_attributes[topic_key][key].filter + ")"

return filter


@build_attribute_filter.register
def _(
attributes: str,
topic_key: str,
trino: bool = False,
) -> str:
"""Build attribute filter from user given attribute filter."""
if trino:
topic_filter = get_topic_preset(topic_key).sql_filter
filter = topic_filter + " AND (" + attributes + ")"
else:
topic_filter = get_topic_preset(topic_key).filter
filter = topic_filter + " and (" + attributes + ")"
return filter


attribute_keys = {
Expand Down
1 change: 1 addition & 0 deletions ohsome_quality_api/attributes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class Attribute(BaseModel):
filter: str
name: str
description: str
filter_sql: str | None = None # TODO: Should not be optional
model_config = ConfigDict(
extra="forbid",
frozen=True,
Expand Down
81 changes: 65 additions & 16 deletions ohsome_quality_api/indicators/attribute_completeness/indicator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import logging
import os
from datetime import datetime, timezone
from string import Template

import dateutil.parser
import dateutil
import plotly.graph_objects as go
from geojson import Feature

Expand All @@ -11,7 +14,11 @@
)
from ohsome_quality_api.indicators.base import BaseIndicator
from ohsome_quality_api.ohsome import client as ohsome_client
from ohsome_quality_api.topics.models import BaseTopic as Topic
from ohsome_quality_api.topics.models import TopicDefinition as Topic
from ohsome_quality_api.trino import client as trino_client
from ohsome_quality_api.utils.helper_geo import get_bounding_box

WORKING_DIR = os.path.dirname(os.path.abspath(__file__))


class AttributeCompleteness(BaseIndicator):
Expand Down Expand Up @@ -46,20 +53,22 @@ def __init__(
attribute_keys: list[str] | None = None,
attribute_filter: str | None = None,
attribute_title: str | None = None,
trino: bool = False, # Feature flag to use SQL instead of ohsome API queries
) -> None:
super().__init__(topic=topic, feature=feature)
super().__init__(topic=topic, feature=feature, trino=trino)
self.threshold_yellow = 0.75
self.threshold_red = 0.25
self.attribute_keys = attribute_keys
self.attribute_filter = attribute_filter
self.attribute_title = attribute_title
self.absolute_value_1 = None
self.absolute_value_2 = None
self.absolute_value_1: int | None = None
self.absolute_value_2: int | None = None
self.description = None
if self.attribute_keys:
self.attribute_filter = build_attribute_filter(
self.attribute_keys,
self.topic.key,
self.trino,
)
self.attribute_title = ", ".join(
[
Expand All @@ -71,20 +80,60 @@ def __init__(
self.attribute_filter = build_attribute_filter(
self.attribute_filter,
self.topic.key,
self.trino,
)

async def preprocess(self) -> None:
# Get attribute filter
response = await ohsome_client.query(
self.topic,
self.feature,
attribute_filter=self.attribute_filter,
)
timestamp = response["ratioResult"][0]["timestamp"]
self.result.timestamp_osm = dateutil.parser.isoparse(timestamp)
self.result.value = response["ratioResult"][0]["ratio"]
self.absolute_value_1 = response["ratioResult"][0]["value"]
self.absolute_value_2 = response["ratioResult"][0]["value2"]
if self.trino:
file_path = os.path.join(WORKING_DIR, "query.sql")
with open(file_path, "r") as file:
sql_template = file.read()

bounding_box = get_bounding_box(self.feature)
geometry = json.dumps(self.feature["geometry"])

sql = sql_template.format(
bounding_box=bounding_box,
geometry=geometry,
filter=self.topic.sql_filter,
)
query = await trino_client.query(sql)
results = await trino_client.fetch(query)
# TODO: Check for None
self.absolute_value_1 = results[0][0]

sql = sql_template.format(
bounding_box=bounding_box,
geometry=geometry,
filter=self.attribute_filter,
)
query = await trino_client.query(sql)
results = await trino_client.fetch(query)
self.absolute_value_2 = results[0][0]

if self.absolute_value_1 is None and self.absolute_value_2 is None:
self.result.value = None
elif self.absolute_value_1 is None:
raise ValueError("Unreachable code.")
elif self.absolute_value_2 is None:
self.result.value = 0
else:
self.result.value = self.absolute_value_2 / self.absolute_value_1

# TODO: Query Trino for Timestamp
self.result.timestamp_osm = datetime.now(timezone.utc)
else:
# Get attribute filter
response = await ohsome_client.query(
self.topic,
self.feature,
attribute_filter=self.attribute_filter,
)
timestamp = response["ratioResult"][0]["timestamp"]
self.result.timestamp_osm = dateutil.parser.isoparse(timestamp)
self.result.value = response["ratioResult"][0]["ratio"]
self.absolute_value_1 = response["ratioResult"][0]["value"]
self.absolute_value_2 = response["ratioResult"][0]["value2"]

def calculate(self) -> None:
# result (ratio) can be NaN if no features matching filter1
Expand Down
20 changes: 20 additions & 0 deletions ohsome_quality_api/indicators/attribute_completeness/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- Sum the length of all latest contributions that match the given filter and
-- intersect the given boundary polygon.
-- The geometry, filter and bounding_box placeholders are substituted by the
-- caller before the statement is executed.
-- The boundary polygon is parsed once from its GeoJSON string.
WITH bpoly AS (
SELECT
to_geometry (from_geojson_geometry ('{geometry}')) AS geometry
)
SELECT
Sum(
-- Contributions fully inside the polygon use the length column as is;
-- all others use the length of their intersection with the polygon,
-- cast to integer.
CASE WHEN ST_Within (ST_GeometryFromText (contributions.geometry), bpoly.geometry) THEN
length
ELSE
Cast(ST_Length (ST_Intersection (ST_GeometryFromText (contributions.geometry), bpoly.geometry)) AS integer)
END) AS length
FROM
bpoly,
sotm2024_iceberg.geo_sort.contributions AS contributions
WHERE
status = 'latest'
AND ST_Intersects (bpoly.geometry, ST_GeometryFromText (contributions.geometry))
AND {filter}
-- Bounding box pre-filter: the bbox of a contribution has to lie completely
-- inside the requested bounding box.
-- NOTE(review): this looks like it drops contributions that only partially
-- overlap the bounding box, which the ELSE branch above seems meant to
-- handle - confirm this is intended.
AND (bbox.xmax <= {bounding_box.lon_max} AND bbox.xmin >= {bounding_box.lon_min})
AND (bbox.ymax <= {bounding_box.lat_max} AND bbox.ymin >= {bounding_box.lat_min})
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import re


def translate_filter_to_sql(filter_str, tags_column="contributions.tags"):
    """Translate an ohsome filter string into an SQL filter expression.

    The filter is tokenized into parentheses, the lowercase boolean operators
    ``and``/``or`` and the expressions in between. Parentheses and operators
    are kept as they are; every other token is converted by
    ``substring_converter``.

    Args:
        filter_str: The ohsome filter, e.g. ``"highway=* and type:way"``.
        tags_column: SQL reference of the tags column. It is forwarded to
            ``substring_converter`` for every converted token.

    Returns:
        The SQL fragments joined by single spaces. An empty string is returned
        if the filter contains no tokens (e.g. an empty filter).

    Raises:
        ValueError: If a token can not be converted by ``substring_converter``.
    """
    # Split on parentheses first (keeping them as tokens), then split the
    # remaining text on the boolean operators (also kept as tokens).
    tokens = []
    for part in re.split(r"(\(|\))", filter_str):
        if part in ("(", ")"):
            tokens.append(part)
            continue
        pieces = re.split(r"\b(and|or)\b", part)
        tokens.extend(piece.strip() for piece in pieces if piece.strip())

    structural_tokens = ("(", ")", "and", "or")
    sql_parts = [
        token
        if token in structural_tokens
        else substring_converter(token, tags_column)
        for token in tokens
    ]

    # Move the `osm_type = ...` clauses to the front. Build new lists instead
    # of calling remove()/insert() on the list being iterated: the latter left
    # a repeated clause in place because remove() deletes the first equal
    # element. The last clause found ends up first, as before.
    osm_type_parts = [part for part in sql_parts if "osm_type = " in part]
    other_parts = [part for part in sql_parts if "osm_type = " not in part]
    sql_parts = osm_type_parts[::-1] + other_parts

    # Moving clauses to the front can leave a dangling operator at the end.
    # The emptiness check avoids an IndexError for filters without any token.
    if sql_parts and sql_parts[-1] in ("and", "or"):
        sql_parts = sql_parts[:-1]

    return " ".join(sql_parts)


def substring_converter(substring, tags_column="contributions.tags"):
    """Convert a single ohsome filter token into an SQL fragment.

    The token is checked against the supported forms in this order; the first
    one that matches determines the result:

    1. ``key in``      -> tag exists and ``<tags>['key'] IN `` (the value list
       is expected to follow as a separate token)
    2. ``type:x``      -> ``osm_type = 'x' AND ``
    3. ``geometry:x``  -> ``osm_type=x AND ``
    4. ``key=value``   -> tag exists and equals the value
    5. ``key=*``       -> tag exists
    6. ``key!=*``      -> tag does not exist
    7. ``key!=value``  -> tag does not equal the value
    8. ``a, b, c``     -> quoted value list ``'a', 'b', 'c'``

    Args:
        substring: One token of an ohsome filter (no parentheses, no boolean
            operators).
        tags_column: SQL reference of the tags column.

    Returns:
        The SQL fragment for the token. Trailing spaces and trailing ``AND``
        are part of some fragments and are kept as they are.

    Raises:
        ValueError: If the token matches none of the supported forms.
    """
    # TODO: for all patterns check if last group can be removed (artifact)

    # key IN
    in_match = re.search(r"([\w\S]+)\s+in(?:\s|$)", substring)
    if in_match:
        key = in_match.group(1)
        return (
            f"element_at({tags_column},'{key}') IS NOT NULL AND "
            f"{tags_column}['{key}'] IN "
        )

    type_match = re.search(r"type:([\w:]+)", substring)
    if type_match:
        key = type_match.group(1)
        return f"osm_type = '{key}' AND "

    geometry_match = re.search(r"geometry:([\w:]+)", substring)
    if geometry_match:
        key = geometry_match.group(1)
        # NOTE(review): unlike the `type:` branch the value is not quoted and
        # there are no spaces around "=" - confirm that this is intended.
        return f"osm_type={key} AND "

    # Match key=value expressions.
    # The pattern has three groups, so only groups 1 and 2 are read here.
    # Unpacking `match.groups()` into two names raised a ValueError.
    equal_match = re.search(
        r"(\w+)=([\w:]+)(?:\s+|\)\s+)?(\w+)?", substring, re.IGNORECASE
    )
    if equal_match:
        key, value = equal_match.group(1, 2)
        return (
            f"element_at({tags_column},'{key}') IS NOT NULL AND "
            f"{tags_column}['{key}'] = '{value}'"
        )

    # Match key=*
    exists_match = re.search(
        r"(\w+)=\*(?:\s+|\)\s+)?(\w+)?", substring, re.IGNORECASE
    )
    if exists_match:
        key = exists_match.group(1)
        return f"element_at({tags_column}, '{key}') IS NOT NULL "

    # Match key!=*
    not_exists_match = re.search(
        r"(\w+)!=\*(?:\s+|\)\s+)?(\w+)?", substring, re.IGNORECASE
    )
    if not_exists_match:
        key = not_exists_match.group(1)
        return f"element_at({tags_column}, '{key}') IS NULL"

    # Match key!=value expressions (same three-group pattern shape as above).
    not_equal_match = re.search(
        r"(\w+)!=([\w:]+)(?:\s+|\)\s+)?(\w+)?", substring, re.IGNORECASE
    )
    if not_equal_match:
        key, value = not_equal_match.group(1, 2)
        return f"{tags_column}['{key}'] != '{value}'"

    # list of values for "key IN (...)"
    if re.fullmatch(r"^[^,]+(?:\s*,\s*[^,]+)*$", substring):
        values = re.split(r",\s*", substring)
        return ", ".join(f"'{value}'" for value in values)

    raise ValueError(f"Unsupported filter expression: {substring!r}")
11 changes: 10 additions & 1 deletion ohsome_quality_api/indicators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
self,
topic: Topic,
feature: Feature,
trino: bool = False,
) -> None:
self.metadata: IndicatorMetadata = get_indicator(
camel_to_hyphen(type(self).__name__)
Expand All @@ -39,6 +40,11 @@ def __init__(
self.result: Result = Result(
description=self.templates.label_description["undefined"],
)
self.trino: bool = trino
if self.trino and self.topic.sql_filter is None:
raise ValueError(
"No SQL query found to run against Trino for topic: " + self.topic.name
)
self._get_default_figure()

def as_dict(self, include_data: bool = False, exclude_label: bool = False) -> dict:
Expand All @@ -50,7 +56,10 @@ def as_dict(self, include_data: bool = False, exclude_label: bool = False) -> di
"metadata": self.metadata.model_dump(by_alias=True),
"topic": self.topic.model_dump(
by_alias=True,
exclude={"ratio_filter"},
exclude={
"ratio_filter",
"sql_filter",
}, # TODO: do not exclude SQL filter
),
"result": result,
**self.feature.properties,
Expand Down
Loading