Channels
In this section we will add basic channel support:
- Joining and leaving channels
- Sending messages to channels
We will also cover automatically decoding the payload data based on method argument type-hints.
See resulting code on GitHub
Shared code
Let's add a channel property to the Message class. It will contain the name of the channel the message is intended for.
class Message:
    ...
    channel: Optional[str] = None
Server side
Data-classes
We will add functionality to store the channel state. Add the following fields to the ChatData class:
from asyncio import Queue
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Set
from weakref import WeakSet


@dataclass(frozen=True)
class ChatData:
    ...
    channel_users: Dict[str, Set[SessionId]] = field(default_factory=lambda: defaultdict(WeakSet))
    channel_messages: Dict[str, Queue] = field(default_factory=lambda: defaultdict(Queue))
In the channel_users dict, the keys are channel names, and the value is a set of user session ids. A WeakSet is used to automatically remove logged-out users.
In the channel_messages dict, the keys are the channel names, and the value is a Queue of messages sent by users to the channel.
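To make the defaultdict(WeakSet) behaviour concrete, here is a minimal standalone sketch. It does not use the tutorial's ChatData or SessionId; the Session class is a hypothetical stand-in, chosen because WeakSet members must be weak-referenceable (a plain str is not).

```python
import gc
from collections import defaultdict
from weakref import WeakSet


class Session:
    """Hypothetical stand-in for an object tied to a logged-in user."""

    def __init__(self, username: str):
        self.username = username


# Accessing a missing key creates an empty WeakSet for that channel.
channel_users = defaultdict(WeakSet)

alice = Session('alice')
bob = Session('bob')
channel_users['channel1'].add(alice)
channel_users['channel1'].add(bob)
members_before = sorted(s.username for s in channel_users['channel1'])

# Dropping the last strong reference removes the member from the WeakSet.
del bob
gc.collect()  # Usually unnecessary on CPython; it makes the cleanup explicit.
members_after = sorted(s.username for s in channel_users['channel1'])
```

In this sketch nothing removes bob from the channel explicitly; the entry disappears once no other reference to the object remains, which is the effect the tutorial relies on for logged-out users.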
Helper methods
Next, we will define some helper methods for managing channel messages:
- ensure_channel_exists: initializes the data for a new channel if it doesn't exist.
- channel_message_delivery: an asyncio task which delivers channel messages to all the users in a channel.
def ensure_channel_exists(channel_name: str):
    if channel_name not in chat_data.channel_users:
        chat_data.channel_users[channel_name] = WeakSet()
        chat_data.channel_messages[channel_name] = Queue()
        asyncio.create_task(channel_message_delivery(channel_name))
If the channel doesn't exist yet (the membership check on channel_users), it will be added to the channel_users and channel_messages dictionaries.
The asyncio.create_task call starts an asyncio task (described below) which delivers messages sent to the channel to the channel's users.
async def channel_message_delivery(channel_name: str):
    while True:
        try:
            message = await chat_data.channel_messages[channel_name].get()
            for session_id in chat_data.channel_users[channel_name]:
                # Include the channel name so clients can show where the message came from
                user_specific_message = Message(user=message.user,
                                                content=message.content,
                                                channel=channel_name)
                chat_data.user_session_by_id[session_id].messages.put_nowait(user_specific_message)
        except Exception as exception:
            logging.error(str(exception), exc_info=True)
The above method loops forever, waiting on the channel_messages queue of the specified channel (the awaited get() call). Upon receiving a message, it delivers a copy to every user in the channel (the for loop over channel_users).
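The fan-out pattern described above can be tried in isolation. The following is a minimal sketch using only asyncio; the names deliver, demo and user_queues are hypothetical and stand in for the tutorial's chat_data structures.

```python
import asyncio
from typing import Dict, List


async def deliver(channel_queue: asyncio.Queue,
                  user_queues: Dict[str, asyncio.Queue]) -> None:
    """Forward every message from the channel queue to each user's queue."""
    while True:
        message = await channel_queue.get()
        for queue in user_queues.values():
            queue.put_nowait(message)


async def demo() -> List[str]:
    channel_queue: asyncio.Queue = asyncio.Queue()
    user_queues = {'user1': asyncio.Queue(), 'user2': asyncio.Queue()}

    task = asyncio.create_task(deliver(channel_queue, user_queues))
    await channel_queue.put('hello channel')

    # Each user receives their own copy of the message.
    received = [await queue.get() for queue in user_queues.values()]

    task.cancel()  # The delivery loop never ends on its own.
    return received
```

Running asyncio.run(demo()) returns one copy of the message per user. Because the loop has no exit condition, the sketch cancels the task explicitly once it is done.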
The final helper will look up username by session id:
def find_username_by_session(session_id: SessionId) -> Optional[str]:
    session = chat_data.user_session_by_id.get(session_id)
    if session is None:
        return None
    return session.username
Join/Leave Channel
Now let's add request-response endpoints for joining and leaving channels.
class ChatUserSession:

    def router_factory(self):
        router = RequestRouter()

        @router.response('channel.join')
        async def join_channel(payload: Payload) -> Awaitable[Payload]:
            channel_name = payload.data.decode('utf-8')
            ensure_channel_exists(channel_name)
            chat_data.channel_users[channel_name].add(self._session.session_id)
            return create_response()

        @router.response('channel.leave')
        async def leave_channel(payload: Payload) -> Awaitable[Payload]:
            channel_name = payload.data.decode('utf-8')
            chat_data.channel_users[channel_name].discard(self._session.session_id)
            return create_response()
Send channel message
Next we add the ability to send channel messages. We will modify the send_message method:
class ChatUserSession:

    def router_factory(self):
        router = RequestRouter()

        @router.response('message')
        async def send_message(payload: Payload) -> Awaitable[Payload]:
            message = Message(**json.loads(payload.data))
            logging.info('Received message for user: %s, channel: %s', message.user, message.channel)

            target_message = Message(self._session.username, message.content, message.channel)

            if message.channel is not None:
                await chat_data.channel_messages[message.channel].put(target_message)
            elif message.user is not None:
                session = find_session_by_username(message.user)
                await session.messages.put(target_message)

            return create_response()
The if/elif block decides whether it is a channel message or a private message, and adds it to the relevant queue. A message with a channel set is always treated as a channel message.
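The precedence rule can be checked with a small standalone sketch. The Message dataclass below is an assumed stand-in with the three fields used in this section, and select_target is a hypothetical helper that only names the destination instead of touching any queue.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Message:
    # Assumed field order, matching Message(username, content, channel) above.
    user: Optional[str] = None
    content: Optional[str] = None
    channel: Optional[str] = None


def select_target(message: Message) -> Optional[str]:
    """Return a label for the queue the message would be placed on."""
    if message.channel is not None:
        return f'channel:{message.channel}'
    if message.user is not None:
        return f'user:{message.user}'
    return None  # Neither a channel nor a user was specified.
```

A message with both fields set goes to the channel, and a message with neither is not delivered anywhere in this sketch.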
List channels
class ChatUserSession:

    def router_factory(self):
        router = RequestRouter()

        @router.stream('channels')
        async def get_channels() -> Publisher:
            count = len(chat_data.channel_messages)
            generator = ((Payload(ensure_bytes(channel)), index == count) for (index, channel) in
                         enumerate(chat_data.channel_messages.keys(), 1))
            return StreamFromGenerator(lambda: generator)
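The get_channels method defines an endpoint for getting a list of channels. It uses the StreamFromGenerator helper. Each generated item pairs a Payload with a flag that is true only for the final element (index == count). Note that the argument to this class is a factory method for the generator, not the generator itself.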
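The shape of the generated items can be explored without rsocket. This sketch uses plain bytes instead of Payload, and the names channel_name_generator and generator_factory are hypothetical. Unlike the tutorial's lambda, which returns an already-created generator, the factory here builds a fresh generator on every call; that is one possible variation, not the library's requirement.

```python
from typing import Dict, Iterator, Tuple


def channel_name_generator(channels: Dict[str, object]) -> Iterator[Tuple[bytes, bool]]:
    """Yield (encoded channel name, is_last) tuples."""
    count = len(channels)
    # enumerate(..., 1) makes the final index equal to count.
    return ((name.encode('utf-8'), index == count)
            for index, name in enumerate(channels.keys(), 1))


channels = {'channel1': None, 'channel2': None}


def generator_factory() -> Iterator[Tuple[bytes, bool]]:
    # A factory: each call builds a new generator over the current channels.
    return channel_name_generator(channels)


items = list(generator_factory())
```

Only the last tuple carries True, which is how the consumer can tell that the stream is complete.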
Get channel users
class ChatUserSession:

    def router_factory(self):
        router = RequestRouter()

        @router.stream('channel.users')
        async def get_channel_users(payload: Payload) -> Publisher:
            channel_name = utf8_decode(payload.data)

            if channel_name not in chat_data.channel_users:
                return EmptyStream()

            count = len(chat_data.channel_users[channel_name])
            generator = ((Payload(ensure_bytes(find_username_by_session(session_id))), index == count) for
                         (index, session_id) in
                         enumerate(chat_data.channel_users[channel_name], 1))
            return StreamFromGenerator(lambda: generator)
The get_channel_users method defines an endpoint for getting a list of users in a given channel. The find_username_by_session helper method is used to convert the session ids to usernames.
If the channel does not exist, the EmptyStream helper can be used as a response.
Simplify route methods
Up until now, all routed methods received the Payload as an argument and extracted and decoded the data property.
There is a simpler method for doing this. By passing a payload_mapper to the RequestRouter and specifying a type-hint on the route method argument, the method arguments will be automatically decoded.
We will modify the code and add this functionality.
In the shared module add the following helper method, which will be used as the payload_mapper:
from rsocket.payload import Payload
from rsocket.helpers import utf8_decode


def decode_payload(cls, payload: Payload):
    data = payload.data

    if cls is bytes:
        return data

    if cls is str:
        return utf8_decode(data)

    return decode_dataclass(data, cls)
This method converts the payload data according to the type hint. It assumes that the payload data can be converted to one of:
- bytes, in which case no transformation is applied
- str, in which case a UTF-8 decode is applied
- A dataclass, which is decoded using a previously defined helper (see Private Message)
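The three cases in the list above can be exercised without rsocket. In this sketch, decode_by_type takes raw bytes rather than a Payload, and decode_dataclass_sketch is a hypothetical JSON-based stand-in for the tutorial's decode_dataclass helper; the Message fields are assumed from their use in this section.

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class Message:
    # Assumed fields, based on how Message is used in this section.
    user: Optional[str] = None
    content: Optional[str] = None
    channel: Optional[str] = None


def decode_dataclass_sketch(data: bytes, cls):
    # Hypothetical stand-in: assumes the dataclass was serialized as JSON.
    return cls(**json.loads(data.decode('utf-8')))


def decode_by_type(cls, data: bytes):
    if cls is bytes:
        return data  # No transformation.
    if cls is str:
        return data.decode('utf-8')
    return decode_dataclass_sketch(data, cls)


raw = json.dumps({'channel': 'channel1', 'content': 'hello'}).encode('utf-8')
decoded = decode_by_type(Message, raw)
```

Any type other than bytes or str falls through to the dataclass branch, so the sketch (like the helper it mirrors) would fail on, say, int; whether that matters depends on which type hints the routes use.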
Next we will pass this method as an argument to the RequestRouter:
router = RequestRouter(payload_mapper=decode_payload)
Once this is done, the signature of the routed methods may be simplified. For example the join_channel method can be simplified to:
@router.response('channel.join')
async def join_channel(channel_name: str) -> Awaitable[Payload]:
    # channel_name = payload.data.decode('utf-8')  # this can be removed
    ...
and the send_message method can be simplified to:
@router.response('message')
async def send_message(message: Message) -> Awaitable[Payload]:
    # message = Message(**json.loads(payload.data))  # this can be removed
    logging.info('Received message for user: %s, channel: %s', message.user, message.channel)
    ...
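For intuition about how a type hint can drive decoding, here is a rough sketch built on typing.get_type_hints. It is not how RequestRouter is implemented; call_with_decoded_payload and simple_mapper are hypothetical names, the handler is synchronous for brevity, and the mapper receives raw bytes instead of a Payload.

```python
import inspect
from typing import Any, Callable, get_type_hints


def call_with_decoded_payload(handler: Callable,
                              data: bytes,
                              payload_mapper: Callable[[type, bytes], Any]):
    """Decode data according to the handler's first argument type, then call it."""
    hints = get_type_hints(handler)
    hints.pop('return', None)  # Only argument hints are relevant here.
    parameter_names = list(inspect.signature(handler).parameters)
    if not parameter_names:
        return handler()  # Nothing to decode for argument-less handlers.
    # Fall back to raw bytes when the argument has no type hint.
    target_type = hints.get(parameter_names[0], bytes)
    return handler(payload_mapper(target_type, data))


def simple_mapper(cls, data: bytes):
    return data.decode('utf-8') if cls is str else data


def join_channel(channel_name: str) -> str:
    return f'joined {channel_name}'


def raw_handler(data) -> int:
    return len(data)
```

The handler never sees the undecoded bytes unless its argument is unannotated or annotated as bytes, which mirrors the simplification shown above.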
Client side
Channel requests
We will add methods to the ChatClient to interact with the new server functionality:
from typing import List

from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket
from rsocket.extensions.helpers import composite, route
from rsocket.frame_helpers import ensure_bytes
from rsocket.payload import Payload
from rsocket.helpers import utf8_decode

from shared import Message, encode_dataclass


class ChatClient:

    async def join(self, channel_name: str):
        request = Payload(ensure_bytes(channel_name), composite(route('channel.join')))
        await self._rsocket.request_response(request)
        return self

    async def leave(self, channel_name: str):
        request = Payload(ensure_bytes(channel_name), composite(route('channel.leave')))
        await self._rsocket.request_response(request)
        return self

    async def channel_message(self, channel: str, content: str):
        print(f'Sending {content} to channel {channel}')
        await self._rsocket.request_response(Payload(encode_dataclass(Message(channel=channel, content=content)),
                                                     composite(route('message'))))

    async def list_channels(self) -> List[str]:
        # The 'channels' route takes no payload data, only routing metadata
        request = Payload(metadata=composite(route('channels')))
        response = await AwaitableRSocket(self._rsocket).request_stream(request)
        return list(map(lambda _: utf8_decode(_.data), response))

    async def get_users(self, channel_name: str) -> List[str]:
        request = Payload(ensure_bytes(channel_name), composite(route('channel.users')))
        users = await AwaitableRSocket(self._rsocket).request_stream(request)
        return [utf8_decode(user.data) for user in users]
The join and leave methods are both simple routed request_response calls, with the channel name as the payload data.
The channel_message method sends a Message with the channel set, using the same message route as private messages.
The list_channels method uses the AwaitableRSocket adapter to simplify getting the response stream as a list.
The get_users method lists a channel's users.
Update the print_message method to include the channel:
def print_message(data: bytes):
    message = Message(**json.loads(data))
    print(f'{self._username}: from {message.user} ({message.channel}): {message.content}')
Test the new functionality
Let's test the new functionality using the following code:
async def messaging_example(user1: ChatClient, user2: ChatClient):
    user1.listen_for_messages()
    user2.listen_for_messages()

    await user1.join('channel1')
    await user2.join('channel1')

    print(f'Channels: {await user1.list_channels()}')

    await user1.private_message('user2', 'private message from user1')
    await user1.channel_message('channel1', 'channel message from user1')

    await asyncio.sleep(1)

    user1.stop_listening_for_messages()
    user2.stop_listening_for_messages()
Call the example method from the main method and pass it the two chat clients:
user1 = ChatClient(client1)
user2 = ChatClient(client2)
await user1.login('user1')
await user2.login('user2')
await messaging_example(user1, user2)