Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[project]
name = "postgresql-charms-single-kernel"
description = "Shared and reusable code for PostgreSQL-related charms"
version = "16.1.0"
version = "16.1.1"
readme = "README.md"
license = "Apache-2.0"
authors = [
Expand Down
84 changes: 76 additions & 8 deletions single_kernel_postgresql/utils/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def __init__(self, message: Optional[str] = None):
self.message = message


class PostgreSQLUpdateUserError(PostgreSQLBaseError):
"""Exception raised when creating a user fails."""


class PostgreSQLUndefinedHostError(PostgreSQLBaseError):
"""Exception when host is not set."""

Expand Down Expand Up @@ -146,6 +150,10 @@ class PostgreSQLGetPostgreSQLVersionError(PostgreSQLBaseError):
"""Exception raised when retrieving PostgreSQL version fails."""


class PostgreSQLListDatabasesError(PostgreSQLBaseError):
"""Exception raised when retrieving the databases."""


class PostgreSQLListAccessibleDatabasesForUserError(PostgreSQLBaseError):
"""Exception raised when retrieving the accessible databases for a user fails."""

Expand Down Expand Up @@ -439,24 +447,36 @@ def _adjust_user_definition(
Returns:
A tuple containing the adjusted user definition and a list of additional statements.
"""
db_roles, connect_statements = self._adjust_user_roles(user, roles, database)
if db_roles:
str_roles = [f'"{role}"' for role in db_roles]
user_definition += f" IN ROLE {', '.join(str_roles)}"
return user_definition, connect_statements

def _adjust_user_roles(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split off from _adjust_user_definition() so that we can get just the roles.

self, user: str, roles: Optional[List[str]], database: Optional[str]
) -> Tuple[List[str], List[str]]:
"""Adjusts the user definition to include additional statements.

Returns:
A tuple containing the adjusted user definition and a list of additional statements.
"""
db_roles = []
connect_statements = []
if database:
if roles is not None and not any(
True
for role in roles
if role in [ROLE_STATS, ROLE_READ, ROLE_DML, ROLE_BACKUP, ROLE_DBA]
role in [ROLE_STATS, ROLE_READ, ROLE_DML, ROLE_BACKUP, ROLE_DBA] for role in roles
):
user_definition += f' IN ROLE "charmed_{database}_admin", "charmed_{database}_dml"'
db_roles.append(f"charmed_{database}_admin")
db_roles.append(f"charmed_{database}_dml")
else:
connect_statements.append(
SQL("GRANT CONNECT ON DATABASE {} TO {};").format(
Identifier(database), Identifier(user)
)
)
if roles is not None and any(
True
for role in roles
if role
role
in [
ROLE_STATS,
ROLE_READ,
Expand All @@ -466,14 +486,15 @@ def _adjust_user_definition(
ROLE_ADMIN,
ROLE_DATABASES_OWNER,
]
for role in roles
):
for system_database in ["postgres", "template1"]:
connect_statements.append(
SQL("GRANT CONNECT ON DATABASE {} TO {};").format(
Identifier(system_database), Identifier(user)
)
)
return user_definition, connect_statements
return db_roles, connect_statements

def _process_extra_user_roles(
self, user: str, extra_user_roles: Optional[List[str]] = None
Expand Down Expand Up @@ -1841,3 +1862,50 @@ def drop_hba_triggers(self) -> None:
finally:
if connection:
connection.close()

def list_databases(self, prefix: Optional[str] = None) -> List[str]:
"""List non-system databases starting with prefix."""
prefix_stmt = (
SQL(" AND datname LIKE {}").format(Literal(prefix + "%")) if prefix else SQL("")
)
try:
with self._connect_to_database() as connection, connection.cursor() as cursor:
cursor.execute(
SQL(
"SELECT datname FROM pg_database WHERE datistemplate = false AND datname <>'postgres'{};"
).format(prefix_stmt)
)
return [row[0] for row in cursor.fetchall()]
except psycopg2.Error as e:
raise PostgreSQLListDatabasesError() from e
finally:
if connection:
connection.close()

def add_user_to_databases(
self, user: str, databases: List[str], extra_user_roles: Optional[List[str]] = None
) -> None:
Comment on lines +1885 to +1887
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we don't alter the user, just add roles and connect privs.

"""Grant user access to database."""
try:
roles, _ = self._process_extra_user_roles(user, extra_user_roles)
connect_stmt = []
for database in databases:
db_roles, db_connect_stmt = self._adjust_user_roles(user, roles, database)
roles += db_roles
connect_stmt += db_connect_stmt
with self._connect_to_database() as connection, connection.cursor() as cursor:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated and out of curiosity, why not add commit control on the context manager, as a parameter of connection.cursor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likely copy paste momentum. We should overhaul and move to psycopg3 in general.

cursor.execute(SQL("RESET ROLE;"))
cursor.execute(SQL("BEGIN;"))
cursor.execute(SQL("SET LOCAL log_statement = 'none';"))
cursor.execute(SQL("COMMIT;"))

# Add extra user roles to the new user.
for role in roles:
cursor.execute(
SQL("GRANT {} TO {};").format(Identifier(role), Identifier(user))
)
for statement in connect_stmt:
cursor.execute(statement)
except psycopg2.Error as e:
logger.error(f"Failed to create user: {e}")
raise PostgreSQLUpdateUserError() from e
109 changes: 109 additions & 0 deletions tests/unit/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
PostgreSQLCreateUserError,
PostgreSQLDatabasesSetupError,
PostgreSQLGetLastArchivedWALError,
PostgreSQLListDatabasesError,
PostgreSQLUndefinedHostError,
PostgreSQLUndefinedPasswordError,
PostgreSQLUpdateUserError,
ROLE_DATABASES_OWNER,
)
from single_kernel_postgresql.config.literals import Substrates
Expand Down Expand Up @@ -813,3 +815,110 @@ def test_set_up_database_k8s_skips_change_owner_and_chmod(harness):
# On K8S substrate we must not attempt to change ownership or chmod the path.
_change_owner.assert_not_called()
_chmod.assert_not_called()


def test_list_databases():
with patch(
"single_kernel_postgresql.utils.postgresql.PostgreSQL._connect_to_database",
) as _connect_to_database:
pg = PostgreSQL(
Substrates.VM, "primary", "current", "operator", "password", "postgres", None
)
execute = _connect_to_database.return_value.__enter__.return_value.cursor.return_value.__enter__.return_value.execute

# No prefix
pg.list_databases()
execute.assert_called_once_with(
Composed([
SQL(
"SELECT datname FROM pg_database WHERE datistemplate = false AND datname <>'postgres'"
),
SQL(""),
SQL(";"),
])
)
execute.reset_mock()

# With prefix
pg.list_databases(prefix="test")
execute.assert_called_once_with(
Composed([
SQL(
"SELECT datname FROM pg_database WHERE datistemplate = false AND datname <>'postgres'"
),
Composed([SQL(" AND datname LIKE "), Literal("test%")]),
SQL(";"),
])
)
execute.reset_mock()

# Exception
execute.side_effect = psycopg2.Error
with pytest.raises(PostgreSQLListDatabasesError):
pg.list_databases()
assert False


def test_add_user_to_databases():
with (
patch(
"single_kernel_postgresql.utils.postgresql.PostgreSQL._connect_to_database"
) as _connect_to_database,
patch(
"single_kernel_postgresql.utils.postgresql.PostgreSQL._process_extra_user_roles",
return_value=([], []),
),
):
pg = PostgreSQL(
Substrates.VM, "primary", "current", "operator", "password", "postgres", None
)
execute = _connect_to_database.return_value.__enter__.return_value.cursor.return_value.__enter__.return_value.execute

pg.add_user_to_databases("test-user", ["db1", "db2"])
assert execute.call_count == 8
execute.assert_any_call(SQL("RESET ROLE;"))
execute.assert_any_call(SQL("BEGIN;"))
execute.assert_any_call(SQL("SET LOCAL log_statement = 'none';"))
execute.assert_any_call(SQL("COMMIT;"))
execute.assert_any_call(
Composed([
SQL("GRANT "),
Identifier("charmed_db1_admin"),
SQL(" TO "),
Identifier("test-user"),
SQL(";"),
])
)
execute.assert_any_call(
Composed([
SQL("GRANT "),
Identifier("charmed_db1_dml"),
SQL(" TO "),
Identifier("test-user"),
SQL(";"),
])
)
execute.assert_any_call(
Composed([
SQL("GRANT "),
Identifier("charmed_db2_admin"),
SQL(" TO "),
Identifier("test-user"),
SQL(";"),
])
)
execute.assert_any_call(
Composed([
SQL("GRANT "),
Identifier("charmed_db2_dml"),
SQL(" TO "),
Identifier("test-user"),
SQL(";"),
])
)

# Exception
execute.side_effect = psycopg2.Error
with pytest.raises(PostgreSQLUpdateUserError):
pg.add_user_to_databases("test-user", ["db1", "db2"])
assert False
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.