View Javadoc
1   /*
2    * Copyright 2016–2021 Michael Osipov
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *     http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  package net.sf.michaelo.activedirectory;
17  
18  import java.util.Arrays;
19  import java.util.Hashtable;
20  import java.util.NoSuchElementException;
21  import java.util.Objects;
22  import java.util.Scanner;
23  import java.util.concurrent.ThreadLocalRandom;
24  import java.util.logging.Level;
25  import java.util.logging.Logger;
26  
27  import javax.naming.Context;
28  import javax.naming.InvalidNameException;
29  import javax.naming.NameNotFoundException;
30  import javax.naming.NamingEnumeration;
31  import javax.naming.NamingException;
32  import javax.naming.directory.Attribute;
33  import javax.naming.directory.Attributes;
34  import javax.naming.directory.DirContext;
35  import javax.naming.directory.InitialDirContext;
36  
37  import org.apache.commons.lang3.Validate;
38  
39  /**
40   * A locator for various Active Directory services like LDAP, Global Catalog, Kerberos, etc. via DNS
41   * SRV resource records. This is a lightweight implementation of
42   * <a href="https://www.rfc-editor.org/rfc/rfc2782.html">RFC 2782</a> for the resource records
43   * depicted
44   * <a href="https://technet.microsoft.com/en-us/library/cc759550%28v=ws.10%29.aspx">here</a>, but
45   * with the limitation that only TCP is queried and the {@code _msdcs} subdomain is ignored. The
46   * server selection algorithm for failover is fully implemented.
47   * <p>
48   * Here is a minimal example how to create a {@code ActiveDirectoryDnsLocator} with the supplied
49   * builder:
50   *
51   * <pre>
52   * ActiveDirectoryDnsLocator.Builder builder = new ActiveDirectoryDnsLocator.Builder();
53   * ActiveDirectoryDnsLocator locator = builder.build();
54   * HostPort[] servers = locator.locate("ldap", "ad.example.com");
55   * </pre>
56   *
57   * An {@code ActiveDirectoryDnsLocator} object will be initially preconfigured by its builder for
58   * you:
59   * <ol>
60   * <li>The context factory is set by default to {@code com.sun.jndi.dns.DnsContextFactory}.</li>
61   * </ol>
62   *
63   * A complete overview of all {@code DirContext} properties can be found
64   * <a href= "https://docs.oracle.com/javase/8/docs/technotes/guides/jndi/jndi-dns.html">here</a>.
65   * Make sure that you pass reasonable/valid values only otherwise the behavior is undefined.
66   *
67   * <p>
68   * <b>Note</b>: This class uses JULI to print log messages, enable at least level {@code FINE}
69   * to see output.
70   */
71  public class ActiveDirectoryDnsLocator {
72  
73  	private static class SrvRecord implements Comparable<SrvRecord> {
74  
75  		static final String UNAVAILABLE_SERVICE = ".";
76  
77  		private int priority;
78  		private int weight;
79  		private int sum;
80  		private int port;
81  		private String target;
82  
83  		public SrvRecord(int priority, int weight, int port, String target) {
84  			Validate.inclusiveBetween(0, 0xFFFF, priority, "priority must be between 0 and 65535");
85  			Validate.inclusiveBetween(0, 0xFFFF, weight, "weight must be between 0 and 65535");
86  			Validate.inclusiveBetween(0, 0xFFFF, port, "port must be between 0 and 65535");
87  			Validate.notEmpty(target, "target cannot be null or empty");
88  
89  			this.priority = priority;
90  			this.weight = weight;
91  			this.port = port;
92  			this.target = target;
93  		}
94  
95  		@Override
96  		public boolean equals(Object obj) {
97  			if (obj == null || !(obj instanceof SrvRecord))
98  				return false;
99  
100 			SrvRecord that = (SrvRecord) obj;
101 
102 			return priority == that.priority && weight == that.weight && port == that.port
103 					&& target.equals(that.target);
104 		}
105 
106 		@Override
107 		public int hashCode() {
108 			return Objects.hash(priority, weight, port, target);
109 		}
110 
111 		@Override
112 		public String toString() {
113 			StringBuilder builder = new StringBuilder("SRV RR: ");
114 			builder.append(priority).append(' ');
115 			builder.append(weight).append(' ');
116 			if (sum != 0)
117 				builder.append('(').append(sum).append(") ");
118 			builder.append(port).append(' ');
119 			builder.append(target);
120 			return builder.toString();
121 		}
122 
123 		@Override
124 		public int compareTo(SrvRecord that) {
125 			// Comparing according to the RFC
126 			if (priority > that.priority) {
127 				return 1;
128 			} else if (priority < that.priority) {
129 				return -1;
130 			} else if (weight == 0 && that.weight != 0) {
131 				return -1;
132 			} else if (weight != 0 && that.weight == 0) {
133 				return 1;
134 			} else {
135 				return 0;
136 			}
137 		}
138 
139 	}
140 
141 	/**
142 	 * A mere container for a host along with a port.
143 	 */
144 	public static class HostPort {
145 
146 		private String host;
147 		private int port;
148 
149 		public HostPort(String host, int port) {
150 			Validate.notEmpty(host, "host cannot be null or empty");
151 			Validate.inclusiveBetween(0, 0xFFFF, port, "port must be between 0 and 65535");
152 
153 			this.host = host;
154 			this.port = port;
155 		}
156 
157 		public String getHost() {
158 			return host;
159 		}
160 
161 		public int getPort() {
162 			return port;
163 		}
164 
165 		@Override
166 		public boolean equals(Object obj) {
167 			if (obj == null || !(obj instanceof HostPort))
168 				return false;
169 
170 			HostPort that = (HostPort) obj;
171 
172 			return host.equals(that.host) && port == that.port;
173 		}
174 
175 		@Override
176 		public int hashCode() {
177 			return Objects.hash(host, port);
178 		}
179 
180 		@Override
181 		public String toString() {
182 			StringBuilder builder = new StringBuilder();
183 			builder.append(host).append(':').append(port);
184 			return builder.toString();
185 		}
186 
187 	}
188 
189 	private static final String SRV_RR_FORMAT = "_%s._tcp.%s";
190 	private static final String SRV_RR_WITH_SITES_FORMAT = "_%s._tcp.%s._sites.%s";
191 
192 	private static final String SRV_RR = "SRV";
193 	private static final String[] SRV_RR_ATTR = new String[] { SRV_RR };
194 
195 	private static final Logger LOGGER = Logger
196 			.getLogger(ActiveDirectoryDnsLocator.class.getName());
197 
198 	private final Hashtable<String, Object> env;
199 
200 	private ActiveDirectoryDnsLocator(Builder builder) {
201 		env = new Hashtable<String, Object>();
202 		env.put(Context.INITIAL_CONTEXT_FACTORY, builder.contextFactory);
203 		env.putAll(builder.additionalProperties);
204 	}
205 
206 	/**
207 	 * A builder to construct an {@link ActiveDirectoryDnsLocator} with a fluent interface.
208 	 *
209 	 * <p>
210 	 * <strong>Notes:</strong>
211 	 * <ol>
212 	 * <li>This class is not thread-safe. Configure the builder in your main thread, build the
213 	 * object and pass it on to your forked threads.</li>
214 	 * <li>An {@code IllegalStateException} is thrown if a property is modified after this builder
215 	 * has already been used to build an {@code ActiveDirectoryDnsLocator}, simply create a new
216 	 * builder in this case.</li>
217 	 * <li>All passed arrays will be defensively copied and null/empty values will be skipped except
218 	 * when all elements are invalid, an exception will be raised.</li>
219 	 * </ol>
220 	 */
221 	public static final class Builder {
222 
223 		// Builder properties
224 		private String contextFactory;
225 		private Hashtable<String, Object> additionalProperties;
226 
227 		private boolean done;
228 
229 		/**
230 		 * Constructs a new builder for {@link ActiveDirectoryDnsLocator}.
231 		 */
232 		public Builder() {
233 			// Initialize default values first as mentioned in the class' Javadoc
234 			contextFactory("com.sun.jndi.dns.DnsContextFactory");
235 			additionalProperties = new Hashtable<String, Object>();
236 		}
237 
238 		/**
239 		 * Sets the context factory for this service locator.
240 		 *
241 		 * @param contextFactory
242 		 *            the context factory class name
243 		 * @throws NullPointerException
244 		 *             if {@code contextFactory} is null
245 		 * @throws IllegalArgumentException
246 		 *             if {@code contextFactory} is empty
247 		 * @return this builder
248 		 */
249 		public Builder contextFactory(String contextFactory) {
250 			check();
251 			this.contextFactory = validateAndReturnString("contextFactory", contextFactory);
252 			return this;
253 		}
254 
255 		/**
256 		 * Sets an additional property not available through the builder interface.
257 		 *
258 		 * @param name
259 		 *            name of the property
260 		 * @param value
261 		 *            value of the property
262 		 * @throws NullPointerException
263 		 *             if {@code name} is null
264 		 * @throws IllegalArgumentException
265 		 *             if {@code value} is empty
266 		 * @return this builder
267 		 */
268 		public Builder additionalProperty(String name, Object value) {
269 			check();
270 			Validate.notEmpty(name, "additional property's name cannot be null or empty");
271 			this.additionalProperties.put(name, value);
272 			return this;
273 		}
274 
275 		/**
276 		 * Builds an {@code ActiveDirectoryDnsLocator} and marks this builder as non-modifiable for
277 		 * future use. You may call this method as often as you like, it will return a new
278 		 * {@code ActiveDirectoryDnsLocator} instance on every call.
279 		 *
280 		 * @throws IllegalStateException
281 		 *             if a combination of necessary attributes is not set
282 		 * @return an {@code ActiveDirectoryDnsLocator} object
283 		 */
284 		public ActiveDirectoryDnsLocator build() {
285 
286 			ActiveDirectoryDnsLocator serviceLocator = new ActiveDirectoryDnsLocator(this);
287 			done = true;
288 
289 			return serviceLocator;
290 		}
291 
292 		private void check() {
293 			if (done)
294 				throw new IllegalStateException("cannot modify an already used builder");
295 		}
296 
297 		private String validateAndReturnString(String name, String value) {
298 			return Validate.notEmpty(value, "property '%s' cannot be null or empty", name);
299 		}
300 
301 	}
302 
303 	private SrvRecord[] lookUpSrvRecords(DirContext context, String name) throws NamingException {
304 		Attributes attrs = null;
305 
306 		try {
307 			attrs = context.getAttributes(name, SRV_RR_ATTR);
308 		} catch (InvalidNameException e) {
309 			NamingException ne = new NamingException("name '" + name + "' is invalid");
310 			ne.initCause(ne);
311 			throw ne;
312 		} catch (NameNotFoundException e) {
313 			return null;
314 		}
315 
316 		Attribute srvAttr = attrs.get(SRV_RR);
317 		if (srvAttr == null)
318 			return null;
319 
320 		NamingEnumeration<?> records = null;
321 
322 		SrvRecord[] srvRecords = new SrvRecord[srvAttr.size()];
323 
324 		try {
325 			records = srvAttr.getAll();
326 
327 			int recordCnt = 0;
328 			while (records.hasMoreElements()) {
329 				String record = (String) records.nextElement();
330 				try (Scanner scanner = new Scanner(record)) {
331 					scanner.useDelimiter(" ");
332 
333 					int priority = scanner.nextInt();
334 					int weight = scanner.nextInt();
335 					int port = scanner.nextInt();
336 					String target = scanner.next();
337 					SrvRecord srvRecord = new SrvRecord(priority, weight, port, target);
338 
339 					srvRecords[recordCnt++] = srvRecord;
340 				}
341 			}
342 		} catch (NoSuchElementException e) {
343 			throw new IllegalStateException("The supplied SRV RR is invalid", e);
344 		} finally {
345 			if (records != null)
346 				try {
347 					records.close();
348 				} catch (NamingException e) {
349 					; // ignore
350 				}
351 		}
352 
353 		/*
354 		 * No servers returned or explicitly indicated by the DNS server that this service is not
355 		 * provided as described by the RFC.
356 		 */
357 		if (srvRecords.length == 0 || srvRecords.length == 1
358 				&& srvRecords[0].target.equals(SrvRecord.UNAVAILABLE_SERVICE))
359 			return null;
360 
361 		return srvRecords;
362 	}
363 
364 	private HostPort[] sortByRfc2782(SrvRecord[] srvRecords) {
365 		if (srvRecords == null)
366 			return null;
367 
368 		// Apply the server selection algorithm
369 		Arrays.sort(srvRecords);
370 
371 		HostPort[] sortedServers = new HostPort[srvRecords.length];
372 		for (int i = 0, start = -1, end = -1, hp = 0; i < srvRecords.length; i++) {
373 
374 			start = i;
375 			while (i + 1 < srvRecords.length
376 					&& srvRecords[i].priority == srvRecords[i + 1].priority) {
377 				i++;
378 			}
379 			end = i;
380 
381 			for (int repeat = 0; repeat < (end - start) + 1; repeat++) {
382 				int sum = 0;
383 				for (int j = start; j <= end; j++) {
384 					if (srvRecords[j] != null) {
385 						sum += srvRecords[j].weight;
386 						srvRecords[j].sum = sum;
387 					}
388 				}
389 
390 				int r = sum == 0 ? 0 : ThreadLocalRandom.current().nextInt(sum + 1);
391 				for (int k = start; k <= end; k++) {
392 					SrvRecord srvRecord = srvRecords[k];
393 
394 					if (srvRecord != null && srvRecord.sum >= r) {
395 						String host = srvRecord.target.substring(0, srvRecord.target.length() - 1);
396 						sortedServers[hp++] = new HostPort(host, srvRecord.port);
397 						srvRecords[k] = null;
398 					}
399 				}
400 			}
401 		}
402 
403 		return sortedServers;
404 	}
405 
406 	/**
407 	 * Locates a desired service via DNS within an Active Directory site and domain, sorted and
408 	 * selected according to RFC 2782.
409 	 *
410 	 * @param service
411 	 *            the service to be located
412 	 * @param site
413 	 *            the Active Directory site the client resides in
414 	 * @param domainName
415 	 *            the desired domain name. Can be any naming context name.
416 	 * @return the located servers or null if none found
417 	 * @throws NullPointerException
418 	 *             if service or domain name is null
419 	 * @throws IllegalArgumentException
420 	 *             if service or domain name is empty
421 	 * @throws NamingException
422 	 *             if an error has occurred while creating or querying the DNS directory context
423 	 * @throws IllegalStateException
424 	 *             if any of the DNS returned RRs not adhere to the RFC
425 	 */
426 	public HostPort[] locate(String service, String site, String domainName)
427 			throws NamingException {
428 		Validate.notEmpty(service, "service cannot be null or empty");
429 		Validate.notEmpty(domainName, "domainName cannot be null or empty");
430 
431 		DirContext context = null;
432 		try {
433 			context = new InitialDirContext(env);
434 		} catch (NamingException e) {
435 			NamingException ne = new NamingException("failed to create DNS directory context");
436 			ne.initCause(e);
437 			throw ne;
438 		}
439 
440 		SrvRecord[] srvRecords = null;
441 
442 		String lookupName;
443 		if (site != null && !site.isEmpty())
444 			lookupName = String.format(SRV_RR_WITH_SITES_FORMAT, service, site, domainName);
445 		else
446 			lookupName = String.format(SRV_RR_FORMAT, service, domainName);
447 
448 		try {
449 			LOGGER.log(Level.FINE, "Looking up SRV RRs for ''{0}''", lookupName);
450 
451 			srvRecords = lookUpSrvRecords(context, lookupName);
452 		} catch (NamingException e) {
453 			LOGGER.log(Level.SEVERE,
454 					String.format("Failed to look up SRV RRs for '%s'", lookupName), e);
455 
456 			throw e;
457 		} finally {
458 			try {
459 				context.close();
460 			} catch (NamingException e) {
461 				; // ignore
462 			}
463 		}
464 
465 		if (srvRecords == null)
466 			LOGGER.log(Level.FINE, "No SRV RRs for ''{0}'' found", lookupName);
467 		else {
468 			if (LOGGER.isLoggable(Level.FINER))
469 				LOGGER.log(Level.FINER, "Found {0} SRV RRs for ''{1}'': {2}", new Object[] {
470 						srvRecords.length, lookupName, Arrays.toString(srvRecords) });
471 			else
472 				LOGGER.log(Level.FINE, "Found {0} SRV RRs for ''{1}''",
473 						new Object[] { srvRecords.length, lookupName });
474 		}
475 
476 		HostPort[] servers = sortByRfc2782(srvRecords);
477 
478 		if (servers == null)
479 			return null;
480 
481 		if (LOGGER.isLoggable(Level.FINER))
482 			LOGGER.log(Level.FINER, "Selected {0} servers for ''{1}'': {2}",
483 					new Object[] { servers.length, lookupName, Arrays.toString(servers) });
484 		else
485 			LOGGER.log(Level.FINE, "Selected {0} servers for ''{1}''",
486 					new Object[] { servers.length, lookupName });
487 
488 		return servers;
489 	}
490 
491 	/**
492 	 * Locates a desired service via DNS within an Active Directory domain, sorted and selected
493 	 * according to RFC 2782.
494 	 *
495 	 * @param service
496 	 *            the service to be located
497 	 * @param domainName
498 	 *            the desired domain name. Can be any naming context name.
499 	 * @return the located servers or null if none found
500 	 * @throws NullPointerException
501 	 *             if service or domain name is null
502 	 * @throws IllegalArgumentException
503 	 *             if service or domain name is empty
504 	 * @throws NamingException
505 	 *             if an error has occurred while creating or querying the DNS directory context
506 	 * @throws IllegalStateException
507 	 *             if any of the DNS returned RRs not adhere to the RFC
508 	 */
509 	public HostPort[] locate(String service, String domainName) throws NamingException {
510 		return locate(service, null, domainName);
511 	}
512 
513 }