naming
Your naming is not consistent: sometimes you use letter and sometimes w for the same thing. I generally avoid one-letter variable names, but if you do use them, be consistent.
dict.setdefault
Using dict.setdefault, you can simplify your Trie.add method significantly:
    def add(self, word):
        curr = self.root
        for letter in word:
            # Reuse the child for this letter, or create and store a new one.
            curr = curr.children.setdefault(letter, TrieNode())
        curr.end_of_word = True
Since setdefault creates the missing nodes itself, TrieNode.children can then be an ordinary dict instead of a defaultdict.
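To make the behaviour of dict.setdefault concrete, here is a minimal standalone sketch. The names children, first and second are only illustrative and are not part of your code.

```python
# Standalone illustration of dict.setdefault; `children` is a stand-in name.
children = {}

# Key is missing: the default is stored under the key and returned.
first = children.setdefault('a', ['created'])

# Key is present: the stored value is returned and the new default is discarded.
second = children.setdefault('a', ['ignored'])

print(first is second)  # True
print(children)         # {'a': ['created']}
```

One thing to be aware of: the default argument is evaluated before the call, so add constructs a TrieNode for every letter even when the child already exists. For a node class this small, that cost is usually negligible.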
string representation
For debugging, it can be handy to have a string representation of a Node:
    def __repr__(self):
        return f'TrieNode(end_of_word={self.end_of_word}, children={tuple(self.children)})'
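As a rough illustration of where such a representation shows up, consider the sketch below. Node is a hypothetical stand-in for TrieNode whose children are filled in by hand. Because it defines no __str__, print falls back to __repr__, and containers use __repr__ for their elements.

```python
class Node:
    """Hypothetical stand-in for TrieNode, with children filled in by hand."""

    def __init__(self, end_of_word, children):
        self.end_of_word = end_of_word
        self.children = children

    def __repr__(self):
        # Iterating a dict yields its keys, so tuple(...) lists the child letters.
        return f'Node(end_of_word={self.end_of_word}, children={tuple(self.children)})'


leaf = Node(True, {})
node = Node(False, {'b': leaf, 'f': leaf})

print(node)    # Node(end_of_word=False, children=('b', 'f'))
print([leaf])  # [Node(end_of_word=True, children=())]
```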
getting a node
Currently, there is no way in your Trie to retrieve the node for a given word or prefix. Having such a method would simplify the rest of the implementation:
    def __getitem__(self, word):
        curr = self.root
        for letter in word:
            curr = curr.children[letter]
        return curr

    def get(self, word):
        return self[word]
If you changed TrieNode.children from a defaultdict to a dict, this raises a KeyError for a word that is not in the Trie. If you left it as a defaultdict, the lookup silently creates the missing nodes and returns an empty TrieNode, 'TrieNode(end_of_word=False, children=())'. In that case, check for the empty node and raise the KeyError yourself:
    def __getitem__(self, word):
        curr = self.root
        for letter in word:
            curr = curr.children[letter]
        if not (curr.children or curr.end_of_word):
            raise KeyError(f'{word} not in Trie')
        return curr

    >>> trie['foo']
    TrieNode(end_of_word=True, children=('b', 'f'))
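The difference between the two container types is easy to check in isolation. The following sketch uses the illustrative names plain and auto, which are not part of your code, and shows that indexing a defaultdict with a missing key inserts that key instead of raising.

```python
from collections import defaultdict

plain = {}
auto = defaultdict(dict)

# An ordinary dict raises KeyError for a missing key.
try:
    plain['x']
except KeyError:
    plain_raised = True
else:
    plain_raised = False

# A defaultdict raises nothing; the lookup itself inserts 'x' -> {}.
auto['x']

print(plain_raised)  # True
print('x' in plain)  # False
print('x' in auto)   # True
```

This is why a failed lookup on a defaultdict-based Trie leaves empty nodes behind, while the dict-based version leaves the Trie untouched.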
Search
With the method to get a Node, Search becomes as trivial as:
    def search(self, word):
        try:
            return self[word].end_of_word
        except KeyError:
            return False
words starting with prefix
This name can be shortened to starts_with.
Here, I would move the iteration that finds the 'child-words' to the TrieNode, and recursively descend through the nodes:
    def child_words(self, prefix=''):
        if self.end_of_word:
            yield prefix
        for letter, node in self.children.items():
            word = prefix + letter
            yield from node.child_words(word)
Trie.starts_with becomes simply:
    def starts_with(self, prefix):
        try:
            node = self[prefix]
        except KeyError:
            raise KeyError(f"Prefix `{prefix}` not in Trie")
        return node.child_words(prefix)
This returns a generator that yields the matching words:
    >>> list(trie.starts_with('foo'))
    ['foo', 'foob', 'foobar', 'foof']
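If you want to, you can even add an inclusive boolean flag that controls whether the prefix itself is yielded when it is a complete word:
    def child_words(self, prefix='', inclusive=True):
        if inclusive and self.end_of_word:
            yield prefix
        for letter, node in self.children.items():
            word = prefix + letter
            # Longer words are always included; the flag only affects the prefix itself.
            yield from node.child_words(word, inclusive=True)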
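A small sketch of what the flag changes in practice. Node is a minimal stand-in for TrieNode, and insert is a hypothetical helper used only to build the example; neither name comes from your code.

```python
class Node:
    """Minimal stand-in for TrieNode, just enough to exercise child_words."""

    def __init__(self):
        self.children = {}
        self.end_of_word = False

    def child_words(self, prefix='', inclusive=True):
        if inclusive and self.end_of_word:
            yield prefix
        for letter, node in self.children.items():
            # Descendants are always reported; only the starting node is optional.
            yield from node.child_words(prefix + letter, inclusive=True)


def insert(root, word):
    # Hypothetical helper mirroring Trie.add, so the example is self-contained.
    curr = root
    for letter in word:
        curr = curr.children.setdefault(letter, Node())
    curr.end_of_word = True


root = Node()
for w in ('foo', 'foob', 'foobar'):
    insert(root, w)

# Walk down to the node for the prefix 'foo'.
foo = root
for letter in 'foo':
    foo = foo.children[letter]

with_prefix = list(foo.child_words('foo'))
without_prefix = list(foo.child_words('foo', inclusive=False))

print(with_prefix)     # ['foo', 'foob', 'foobar']
print(without_prefix)  # ['foob', 'foobar']
```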
full code
    class TrieNode:
        def __init__(self):
            self.children = dict()
            self.end_of_word = False

        def __repr__(self):
            return f'TrieNode(end_of_word={self.end_of_word},' \
                   f' children={tuple(self.children)})'

        def child_words(self, prefix='', inclusive=True):
            if inclusive and self.end_of_word:
                yield prefix
            for letter, node in self.children.items():
                word = prefix + letter
                yield from node.child_words(word, inclusive=True)


    class Trie_Maarten:
        def __init__(self):
            self.root = TrieNode()

        def add(self, word):
            curr = self.root
            for letter in word:
                curr = curr.children.setdefault(letter, TrieNode())
            curr.end_of_word = True

        def __getitem__(self, word):
            curr = self.root
            for letter in word:
                curr = curr.children[letter]
            return curr

        def get(self, word):
            return self[word]

        def search(self, word):
            try:
                return self[word].end_of_word
            except KeyError:
                return False

        def starts_with(self, prefix, inclusive=True):
            try:
                node = self[prefix]
            except KeyError:
                raise KeyError(f"Prefix `{prefix}` not in Trie")
            return node.child_words(prefix, inclusive)